Skip to content

Commit c207cd2

Browse files
authored
Improve kserve protocol version handling (#2957)
* fix(kserve): ensure there's a default protocol configured The current implementation retrieves the protocol to use from the "PROTOCOL_VERSION" environment variable however there's no default value which will trigger an error when served through Kserve as the base class does a protocol check in the predict method that will fail with None. The default protocol uses the same value as the base class. * fix(kserve): ensure the protocol version configured is a valid value * feat(kserve): make configuration file path configurable This will allow to make the wrapper easier to test beside making it possible to to change where the file should be looked for. * test: add KServe wrapper test * test(kserve_wrapper): add protobuf code generation * fix(kserse): add None handling to the wrapper In case None is passed, keep the default value set in the base __init__. This makes TorchserveModel behave in the same fashion as the base class. * refactor(test): rewrite wrapper test to be more complete It now validates that the protocol version is passed properly as well as failure if an invalid protocol is given. * test: remove kserve pytest tests The overhead required to run the tests outweighs it benefits. The thing to keep in mind is that the fix allows to run models through kserve deployments prior to 0.11.1.
1 parent 047b91e commit c207cd2

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

kubernetes/kserve/kserve_wrapper/TorchserveModel.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def __init__(
7474
self.inference_address = inference_address
7575
self.management_address = management_address
7676
self.model_dir = model_dir
77-
self.protocol = protocol
77+
78+
# Validate the protocol value passed otherwise, the default value will be used
79+
if protocol is not None:
80+
self.protocol = PredictorProtocol(protocol).value
7881

7982
if self.protocol == PredictorProtocol.GRPC_V2.value:
8083
self.predictor_host = grpc_inference_address

kubernetes/kserve/kserve_wrapper/__main__.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import kserve
77
from kserve.model_server import ModelServer
8-
from TorchserveModel import TorchserveModel
8+
from TorchserveModel import PredictorProtocol, TorchserveModel
99
from TSModelRepository import TSModelRepository
1010

1111
logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL)
@@ -15,7 +15,7 @@
1515
DEFAULT_GRPC_INFERENCE_PORT = "7070"
1616

1717
DEFAULT_MODEL_STORE = "/mnt/models/model-store"
18-
CONFIG_PATH = "/mnt/models/config/config.properties"
18+
DEFAULT_CONFIG_PATH = "/mnt/models/config/config.properties"
1919

2020

2121
def parse_config():
@@ -29,8 +29,11 @@ def parse_config():
2929
"""
3030
separator = "="
3131
keys = {}
32+
config_path = os.environ.get("CONFIG_PATH", DEFAULT_CONFIG_PATH)
3233

33-
with open(CONFIG_PATH) as f:
34+
logging.info(f"Wrapper: loading configuration from {config_path}")
35+
36+
with open(config_path) as f:
3437
for line in f:
3538
if separator in line:
3639
# Find the name and value by splitting the string
@@ -99,7 +102,7 @@ def parse_config():
99102
model_dir,
100103
) = parse_config()
101104

102-
protocol = os.environ.get("PROTOCOL_VERSION")
105+
protocol = os.environ.get("PROTOCOL_VERSION", PredictorProtocol.REST_V1.value)
103106

104107
models = []
105108

0 commit comments

Comments
 (0)