@@ -1,13 +1,16 @@
 """The TorchServe-side inference endpoint requests are handled to
 return a KServe-side response."""
 import logging
+import os
 import pathlib
+import time
 from enum import Enum
 from typing import Dict, Union

 import grpc
 import inference_pb2_grpc
 import kserve
+import requests
 from gprc_utils import from_ts_grpc, to_ts_grpc
 from inference_pb2 import PredictionResponse
 from kserve.errors import ModelMissingError
@@ -25,6 +28,7 @@
 EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
 REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
 UNREGISTER_URL_FORMAT = "{0}/models/{1}"
+READINESS_URL_FORMAT = "{0}/models/{1}?customized={2}"


 class PredictorProtocol(Enum):
@@ -150,5 +154,60 @@ def load(self) -> bool:
         ]
         if len(existing_paths) == 0:
             raise ModelMissingError(model_path)
-        self.ready = True
+
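+        # Readiness polling knobs, all overridable via environment variables:
+        # max attempts, delay between attempts, and per-request timeout.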
+        num_try = 0
+        model_load_customized = os.environ.get("MODEL_LOAD_CUSTOMIZED", "false")
+        model_load_max_try = int(os.environ.get("MODEL_LOAD_MAX_TRY", 10))
+        model_load_delay = int(os.environ.get("MODEL_LOAD_DELAY", 30))
+        model_load_timeout = int(os.environ.get("MODEL_LOAD_TIMEOUT", 5))
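+        # Poll the management API until a worker for this model reports READY
+        # or the retry budget is exhausted.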
+        while num_try < model_load_max_try and not self.ready:
+            num_try += 1
+            logging.info(
+                f"Loading {self.name}: try {num_try} of {model_load_max_try}.."
+            )
+
+            try:
+                response = requests.get(
+                    READINESS_URL_FORMAT.format(
+                        self.management_address, self.name, model_load_customized
+                    ),
+                    timeout=model_load_timeout,
+                ).json()
+
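+                # Assumed describe-model response shape (a JSON list whose
+                # first entry is the default model version), e.g.:
+                # [{"modelName": ..., "workers": [{"id": "9000", "status": "READY"}]}]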
+                default_version = response[0]
+
+                workers = default_version["workers"]
+                ready_workers = [
+                    worker["id"] for worker in workers if worker["status"] == "READY"
+                ]
+
+                worker_ready = len(ready_workers) > 0
+
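+                # With MODEL_LOAD_CUSTOMIZED enabled, additionally require the
+                # customizedMetadata field before reporting ready.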
+                self.ready = (
+                    worker_ready
+                    if model_load_customized == "false"
+                    else worker_ready and "customizedMetadata" in default_version
+                )
+
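+            # Connection errors and timeouts are transient: TorchServe may
+            # still be starting, so fall through and retry.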
+            except (requests.ConnectionError, requests.Timeout):
+                logging.info(f"The model {self.name} is not ready")
+
+            except Exception as e:
+                logging.info(e)
+                logging.info(f"Failed to load model {self.name}")
+                break
+
+            if not self.ready:
+                logging.info(f"Sleeping {model_load_delay}s before retrying {self.name}..")
+                time.sleep(model_load_delay)
+
+        if self.ready:
+            logging.info(f"The model {self.name} is ready")
+
         return self.ready