Skip to content

Commit 3627ee6

Browse files
Authored by: byeongjokim, byeongjo-kim, maaquib, agunapal
Set model status using torchserve api (#1878)
* using torchserve api when set model.ready=True * add customized describing model api if want to check handler's status * modify env name * Update TorchserveModel.py lint failure * Reformatted TorchserveModel.py --------- Co-authored-by: byeongjo-kim <[email protected]> Co-authored-by: Aaqib <[email protected]> Co-authored-by: Ankith Gunapal <[email protected]>
1 parent a797f8c commit 3627ee6

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

kubernetes/kserve/kserve_wrapper/TorchserveModel.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
""" The torchserve side inference end-points request are handled to
22
return a KServe side response """
33
import logging
4+
import os
45
import pathlib
6+
import time
57
from enum import Enum
68
from typing import Dict, Union
79

810
import grpc
911
import inference_pb2_grpc
1012
import kserve
13+
import requests
1114
from gprc_utils import from_ts_grpc, to_ts_grpc
1215
from inference_pb2 import PredictionResponse
1316
from kserve.errors import ModelMissingError
@@ -25,6 +28,7 @@
2528
EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
2629
REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
2730
UNREGISTER_URL_FORMAT = "{0}/models/{1}"
31+
READINESS_URL_FORMAT = "{0}/models/{1}?customized={2}"
2832

2933

3034
class PredictorProtocol(Enum):
@@ -150,5 +154,60 @@ def load(self) -> bool:
150154
]
151155
if len(existing_paths) == 0:
152156
raise ModelMissingError(model_path)
153-
self.ready = True
157+
158+
num_try = 0
159+
model_load_customized = os.environ.get("MODEL_LOAD_CUSTOMIZED", "false")
160+
model_load_max_try = int(os.environ.get("MODEL_LOAD_MAX_TRY", 10))
161+
model_load_delay = int(os.environ.get("MODEL_LOAD_DELAY", 30))
162+
model_load_timeout = int(os.environ.get("MODEL_LOAD_TIMEOUT", 5))
163+
while num_try < model_load_max_try and not self.ready:
164+
num_try = num_try + 1
165+
logging.info(
166+
f"Loading {self.name} .. {num_try} of {model_load_max_try} tries.."
167+
)
168+
169+
try:
170+
response = requests.get(
171+
READINESS_URL_FORMAT.format(
172+
self.management_address, self.name, model_load_customized
173+
),
174+
timeout=model_load_timeout,
175+
).json()
176+
177+
default_verison = response[0]
178+
179+
workers = default_verison["workers"]
180+
workers_status = [
181+
worker["id"] for worker in workers if worker["status"] == "READY"
182+
]
183+
184+
worker_ready = False
185+
if len(workers_status) > 0:
186+
worker_ready = True
187+
188+
self.ready = (
189+
worker_ready
190+
if model_load_customized == "false"
191+
else worker_ready and "customizedMetadata" in default_verison
192+
)
193+
194+
except (
195+
requests.ConnectionError,
196+
requests.Timeout,
197+
requests.ConnectTimeout,
198+
requests.ReadTimeout,
199+
) as e:
200+
logging.info(f"The model {self.name} is not ready")
201+
202+
except Exception as e:
203+
logging.info(e)
204+
logging.info(f"Failed loading model {self.name}")
205+
break
206+
207+
logging.info(f"Sleep {model_load_delay} seconds for load {self.name}..")
208+
time.sleep(model_load_delay)
209+
210+
if self.ready:
211+
logging.info(f"The model {self.name} is ready")
212+
154213
return self.ready

kubernetes/kserve/kserve_wrapper/__main__.py

+6
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ def parse_config():
115115
# By default model.load() is called on first request. Enabling load all
116116
# model in TS config.properties, all models are loaded at start and the
117117
# below method sets status to true for the models.
118+
# However, even if all preparations related to loading the model (e.g.,
119+
# download pretrained models using online storage) are not completed in
120+
# torchserve handler, if model.ready=true is set, there may be problems.
121+
# Therefore, the ready status is determined using the api provided by
122+
# torchserve.
123+
118124
model.load()
119125
models.append(model)
120126

0 commit comments

Comments (0)