diff --git a/kubernetes/kserve/kserve_wrapper/TorchserveModel.py b/kubernetes/kserve/kserve_wrapper/TorchserveModel.py
index 69c98b46cc..acf93a851d 100644
--- a/kubernetes/kserve/kserve_wrapper/TorchserveModel.py
+++ b/kubernetes/kserve/kserve_wrapper/TorchserveModel.py
@@ -1,13 +1,16 @@
 """ The torchserve side inference end-points request are handled to
 return a KServe side response """
 import logging
+import os
 import pathlib
+import time
 from enum import Enum
 from typing import Dict, Union
 
 import grpc
 import inference_pb2_grpc
 import kserve
+import requests
 from gprc_utils import from_ts_grpc, to_ts_grpc
 from inference_pb2 import PredictionResponse
 from kserve.errors import ModelMissingError
@@ -25,6 +28,7 @@
 EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
 REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
 UNREGISTER_URL_FORMAT = "{0}/models/{1}"
+READINESS_URL_FORMAT = "{0}/models/{1}?customized={2}"
 
 
 class PredictorProtocol(Enum):
@@ -150,5 +154,60 @@ def load(self) -> bool:
         ]
         if len(existing_paths) == 0:
             raise ModelMissingError(model_path)
-        self.ready = True
+
+        num_try = 0
+        model_load_customized = os.environ.get("MODEL_LOAD_CUSTOMIZED", "false")
+        model_load_max_try = int(os.environ.get("MODEL_LOAD_MAX_TRY", 10))
+        model_load_delay = int(os.environ.get("MODEL_LOAD_DELAY", 30))
+        model_load_timeout = int(os.environ.get("MODEL_LOAD_TIMEOUT", 5))
+        while num_try < model_load_max_try and not self.ready:
+            num_try = num_try + 1
+            logging.info(
+                f"Loading {self.name}.. attempt {num_try} of {model_load_max_try}.."
+            )
+
+            try:
+                response = requests.get(
+                    READINESS_URL_FORMAT.format(
+                        self.management_address, self.name, model_load_customized
+                    ),
+                    timeout=model_load_timeout,
+                ).json()
+
+                default_version = response[0]
+
+                workers = default_version["workers"]
+                workers_status = [
+                    worker["id"] for worker in workers if worker["status"] == "READY"
+                ]
+
+                worker_ready = False
+                if len(workers_status) > 0:
+                    worker_ready = True
+
+                self.ready = (
+                    worker_ready
+                    if model_load_customized == "false"
+                    else worker_ready and "customizedMetadata" in default_version
+                )
+
+            except (
+                requests.ConnectionError,
+                requests.Timeout,
+                requests.ConnectTimeout,
+                requests.ReadTimeout,
+            ) as e:
+                logging.info(f"The model {self.name} is not ready")
+
+            except Exception as e:
+                logging.info(e)
+                logging.info(f"Failed loading model {self.name}")
+                break
+
+            logging.info(f"Sleeping {model_load_delay} seconds for {self.name} to load..")
+            time.sleep(model_load_delay)
+
+        if self.ready:
+            logging.info(f"The model {self.name} is ready")
+
         return self.ready
diff --git a/kubernetes/kserve/kserve_wrapper/__main__.py b/kubernetes/kserve/kserve_wrapper/__main__.py
index f67e6de107..eee4e76259 100644
--- a/kubernetes/kserve/kserve_wrapper/__main__.py
+++ b/kubernetes/kserve/kserve_wrapper/__main__.py
@@ -115,6 +115,12 @@ def parse_config():
     # By default model.load() is called on first request. Enabling load all
     # model in TS config.properties, all models are loaded at start and the
     # below method sets status to true for the models.
+    # However, if model.ready is set to True before the TorchServe handler has
+    # finished all load-time preparation (e.g., downloading pretrained weights
+    # from remote storage), requests may fail. Therefore, the ready status is
+    # determined by polling the describe-model API provided by the TorchServe
+    # management endpoint.
+
     model.load()
     models.append(model)
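
For context, the retry loop added above polls TorchServe's describe-model management API (via READINESS_URL_FORMAT) and inspects the workers reported for the default model version. Below is a minimal sketch of that readiness check, assuming a hypothetical management address of http://localhost:8081 and a model named mnist; the actual wrapper takes these values from its KServe model configuration.

```python
import requests

# Hypothetical values; the KServe wrapper reads these from its model config
# and from the MODEL_LOAD_* environment variables introduced in the diff.
management_address = "http://localhost:8081"
model_name = "mnist"

# Same request the load() retry loop issues. With customized=true, TorchServe
# also includes the handler's custom metadata as "customizedMetadata".
response = requests.get(
    f"{management_address}/models/{model_name}?customized=true", timeout=5
).json()

# The describe API returns a list of model versions; the wrapper only looks
# at the first (default) entry.
default_version = response[0]

# At least one worker with status "READY" means the model can serve traffic;
# with MODEL_LOAD_CUSTOMIZED=true the wrapper also requires "customizedMetadata".
ready_workers = [w["id"] for w in default_version["workers"] if w["status"] == "READY"]
print(len(ready_workers) > 0, "customizedMetadata" in default_version)
```

If no worker ever reports READY, the loop gives up after MODEL_LOAD_MAX_TRY attempts spaced MODEL_LOAD_DELAY seconds apart, with each request timing out after MODEL_LOAD_TIMEOUT seconds, so the worst-case startup delay is bounded by those environment variables.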