Skip to content

Commit 7c4559c

Browse files
authored
Fix set model state if runtime is null (#2928)
* Fix empty runtime issue + add test
* Remove creation of mnist.mar cpp test file
* Fix format error
1 parent cbd7d77 commit 7c4559c

File tree

3 files changed

+147
-73
lines changed

3 files changed

+147
-73
lines changed

frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.pytorch.serve.wlm;
22

3+
import com.google.gson.JsonElement;
34
import com.google.gson.JsonObject;
45
import java.io.File;
56
import java.util.Collections;
@@ -182,7 +183,12 @@ public void setModelState(JsonObject modelInfo) {
182183
maxBatchDelay = modelInfo.get(MAX_BATCH_DELAY).getAsInt();
183184
responseTimeout = modelInfo.get(RESPONSE_TIMEOUT).getAsInt();
184185
batchSize = modelInfo.get(BATCH_SIZE).getAsInt();
185-
runtimeType = Manifest.RuntimeType.fromValue(modelInfo.get(RUNTIME_TYPE).getAsString());
186+
187+
JsonElement runtime = modelInfo.get(RUNTIME_TYPE);
188+
String runtime_str = Manifest.RuntimeType.PYTHON.getValue();
189+
if (runtime != null) runtime_str = runtime.getAsString();
190+
191+
runtimeType = Manifest.RuntimeType.fromValue(runtime_str);
186192
if (modelInfo.get(PARALLEL_LEVEL) != null) {
187193
parallelLevel = modelInfo.get(PARALLEL_LEVEL).getAsInt();
188194
}

test/pytest/test_snapshot.py

+140-65
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import glob
2+
import json
3+
import os
14
import platform
25
import time
3-
import os
4-
import glob
6+
from pathlib import Path
7+
58
import requests
6-
import json
79
import test_utils
810

911

@@ -16,7 +18,11 @@ def teardown_module(module):
1618

1719

1820
def replace_mar_file_with_dummy_mar_in_model_store(model_store=None, model_mar=None):
19-
model_store = model_store if (model_store != None) else os.path.join(test_utils.ROOT_DIR, "model_store")
21+
model_store = (
22+
model_store
23+
if (model_store != None)
24+
else os.path.join(test_utils.ROOT_DIR, "model_store")
25+
)
2026
if model_mar != None:
2127
myfilepath = os.path.join(model_store, model_mar)
2228
if os.path.exists(myfilepath):
@@ -32,20 +38,22 @@ def test_snapshot_created_on_start_and_stop():
3238
test_utils.delete_all_snapshots()
3339
test_utils.start_torchserve()
3440
test_utils.stop_torchserve()
35-
assert len(glob.glob('logs/config/*startup.cfg')) == 1
41+
assert len(glob.glob("logs/config/*startup.cfg")) == 1
3642
if platform.system() != "Windows":
37-
assert len(glob.glob('logs/config/*shutdown.cfg')) == 1
43+
assert len(glob.glob("logs/config/*shutdown.cfg")) == 1
3844

3945

4046
def snapshot_created_on_management_api_invoke(model_mar="densenet161.mar"):
4147
test_utils.delete_all_snapshots()
4248
test_utils.start_torchserve()
4349
mar_path = "mar_path_{}".format(model_mar[0:-4])
4450
if mar_path in test_utils.mar_file_table:
45-
requests.post('http://127.0.0.1:8081/models?url=' + model_mar)
51+
requests.post("http://127.0.0.1:8081/models?url=" + model_mar)
4652
else:
47-
requests.post('http://127.0.0.1:8081/models?url=https://torchserve.pytorch.org/mar_files/'
48-
+ model_mar)
53+
requests.post(
54+
"http://127.0.0.1:8081/models?url=https://torchserve.pytorch.org/mar_files/"
55+
+ model_mar
56+
)
4957
time.sleep(10)
5058
test_utils.stop_torchserve()
5159

@@ -55,17 +63,17 @@ def test_snapshot_created_on_management_api_invoke():
5563
Validates that snapshot.cfg is created when management apis are invoked.
5664
"""
5765
snapshot_created_on_management_api_invoke()
58-
assert len(glob.glob('logs/config/*snap*.cfg')) == 1
66+
assert len(glob.glob("logs/config/*snap*.cfg")) == 1
5967

6068

6169
def test_start_from_snapshot():
6270
"""
6371
Validates if we can restore state from snapshot.
6472
"""
65-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[0]
73+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[0]
6674
test_utils.start_torchserve(snapshot_file=snapshot_cfg)
67-
response = requests.get('http://127.0.0.1:8081/models/')
68-
assert json.loads(response.content)['models'][0]['modelName'] == "densenet161"
75+
response = requests.get("http://127.0.0.1:8081/models/")
76+
assert json.loads(response.content)["models"][0]["modelName"] == "densenet161"
6977
test_utils.stop_torchserve()
7078

7179

@@ -74,26 +82,30 @@ def test_start_from_latest():
7482
Validates if latest snapshot file is picked if we dont pass snapshot arg explicitly.
7583
"""
7684
test_utils.start_torchserve()
77-
response = requests.get('http://127.0.0.1:8081/models/')
78-
assert json.loads(response.content)['models'][0]['modelName'] == "densenet161"
85+
response = requests.get("http://127.0.0.1:8081/models/")
86+
assert json.loads(response.content)["models"][0]["modelName"] == "densenet161"
7987
test_utils.stop_torchserve()
8088

8189

8290
def test_start_from_read_only_snapshot():
8391
"""
8492
Validates if we can start and restore Torchserve state using a read-only snapshot.
8593
"""
86-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[0]
94+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[0]
8795
file_status = os.stat(snapshot_cfg)
8896
os.chmod(snapshot_cfg, 0o444)
8997
test_utils.start_torchserve(snapshot_file=snapshot_cfg)
9098
os.chmod(snapshot_cfg, (file_status.st_mode & 0o777))
9199
try:
92-
response = requests.get('http://127.0.0.1:8081/models/')
100+
response = requests.get("http://127.0.0.1:8081/models/")
93101
except:
94-
assert False, "Something is not right!! Failed to start Torchserve using Read Only Snapshot!!"
102+
assert (
103+
False
104+
), "Something is not right!! Failed to start Torchserve using Read Only Snapshot!!"
95105
else:
96-
assert True, "Successfully started and restored Torchserve state using a Read Only Snapshot"
106+
assert (
107+
True
108+
), "Successfully started and restored Torchserve state using a Read Only Snapshot"
97109

98110

99111
def test_no_config_snapshots_cli_option():
@@ -105,7 +117,7 @@ def test_no_config_snapshots_cli_option():
105117
test_utils.delete_all_snapshots()
106118
test_utils.start_torchserve(no_config_snapshots=True)
107119
test_utils.stop_torchserve()
108-
assert len(glob.glob('logs/config/*.cfg')) == 0
120+
assert len(glob.glob("logs/config/*.cfg")) == 0
109121

110122

111123
def test_start_from_default():
@@ -114,8 +126,8 @@ def test_start_from_default():
114126
"""
115127
test_utils.delete_all_snapshots()
116128
test_utils.start_torchserve()
117-
response = requests.get('http://127.0.0.1:8081/models/')
118-
assert len(json.loads(response.content)['models']) == 0
129+
response = requests.get("http://127.0.0.1:8081/models/")
130+
assert len(json.loads(response.content)["models"]) == 0
119131

120132

121133
def test_start_from_non_existing_snapshot():
@@ -126,82 +138,101 @@ def test_start_from_non_existing_snapshot():
126138
test_utils.stop_torchserve()
127139
test_utils.start_torchserve(snapshot_file="logs/config/junk-snapshot.cfg")
128140
try:
129-
response = requests.get('http://127.0.0.1:8081/models/')
141+
response = requests.get("http://127.0.0.1:8081/models/")
130142
except:
131143
assert True, "Failed to start Torchserve using a Non Existing Snapshot"
132144
else:
133-
assert False, "Something is not right!! Successfully started Torchserve " \
134-
"using Non Existing Snapshot File!!"
145+
assert False, (
146+
"Something is not right!! Successfully started Torchserve "
147+
"using Non Existing Snapshot File!!"
148+
)
135149

136150

137151
def test_torchserve_init_with_non_existent_model_store():
138-
"""Validates that Torchserve fails to start if the model store directory is non existent """
152+
"""Validates that Torchserve fails to start if the model store directory is non existent"""
139153

140-
test_utils.start_torchserve(model_store="/invalid_model_store", snapshot_file=None, no_config_snapshots=True)
154+
test_utils.start_torchserve(
155+
model_store="/invalid_model_store", snapshot_file=None, no_config_snapshots=True
156+
)
141157
try:
142-
response = requests.get('http://127.0.0.1:8081/models/')
158+
response = requests.get("http://127.0.0.1:8081/models/")
143159
except:
144-
assert True, "Failed to start Torchserve using non existent model-store directory"
160+
assert (
161+
True
162+
), "Failed to start Torchserve using non existent model-store directory"
145163
else:
146-
assert False, "Something is not right!! Successfully started Torchserve " \
147-
"using non existent directory!!"
164+
assert False, (
165+
"Something is not right!! Successfully started Torchserve "
166+
"using non existent directory!!"
167+
)
148168
finally:
149169
test_utils.delete_model_store()
150170
test_utils.delete_all_snapshots()
151171

152172

153173
def test_restart_torchserve_with_last_snapshot_with_model_mar_removed():
154174
"""Validates that torchserve will fail to start in the following scenario:
155-
1) We use a snapshot file to start torchserve. The snapshot contains reference to "A" model file
156-
2) The "A" model mar file is accidentally deleted from the model store"""
175+
1) We use a snapshot file to start torchserve. The snapshot contains reference to "A" model file
176+
2) The "A" model mar file is accidentally deleted from the model store"""
157177

158178
# Register model using mgmt api
159179
snapshot_created_on_management_api_invoke()
160180

161181
# Now remove the registered model mar file (delete_mar_ fn)
162-
test_utils.delete_mar_file_from_model_store(model_store=os.path.join(test_utils.ROOT_DIR, "model_store"),
163-
model_mar="densenet")
182+
test_utils.delete_mar_file_from_model_store(
183+
model_store=os.path.join(test_utils.ROOT_DIR, "model_store"),
184+
model_mar="densenet",
185+
)
164186

165187
# Start Torchserve with last generated snapshot file
166-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[0]
188+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[0]
167189
test_utils.start_torchserve(snapshot_file=snapshot_cfg, gen_mar=False)
168190
try:
169-
response = requests.get('http://127.0.0.1:8081/models/')
191+
response = requests.get("http://127.0.0.1:8081/models/")
170192
except:
171-
assert True, "Failed to start Torchserve properly as reqd model mar file is missing!!"
193+
assert (
194+
True
195+
), "Failed to start Torchserve properly as reqd model mar file is missing!!"
172196
else:
173-
assert False, "Something is not right!! Successfully started Torchserve without reqd mar file"
197+
assert (
198+
False
199+
), "Something is not right!! Successfully started Torchserve without reqd mar file"
174200
finally:
175201
test_utils.delete_model_store()
176202
test_utils.delete_all_snapshots()
177203

178204

179205
def test_replace_mar_file_with_dummy():
180206
"""Validates that torchserve will fail to start in the following scenario:
181-
1) We use a snapshot file to start torchserve. The snapshot contains reference to "A" model file
182-
2) "A" model file gets corrupted or is replaced by some dummy mar file with same name"""
207+
1) We use a snapshot file to start torchserve. The snapshot contains reference to "A" model file
208+
2) "A" model file gets corrupted or is replaced by some dummy mar file with same name
209+
"""
183210

184211
snapshot_created_on_management_api_invoke()
185212

186213
# Start Torchserve using last snapshot state
187-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[0]
214+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[0]
188215
test_utils.start_torchserve(snapshot_file=snapshot_cfg)
189-
response = requests.get('http://127.0.0.1:8081/models/')
190-
assert json.loads(response.content)['models'][0]['modelName'] == "densenet161"
216+
response = requests.get("http://127.0.0.1:8081/models/")
217+
assert json.loads(response.content)["models"][0]["modelName"] == "densenet161"
191218
test_utils.stop_torchserve()
192219

193220
# Now replace the registered model mar with dummy file
194221
replace_mar_file_with_dummy_mar_in_model_store(
195-
model_store=os.path.join(test_utils.ROOT_DIR, "model_store"), model_mar="densenet161.mar")
196-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[0]
222+
model_store=os.path.join(test_utils.ROOT_DIR, "model_store"),
223+
model_mar="densenet161.mar",
224+
)
225+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[0]
197226
test_utils.start_torchserve(snapshot_file=snapshot_cfg, gen_mar=False)
198227
try:
199-
response = requests.get('http://127.0.0.1:8081/models/')
200-
assert json.loads(response.content)['models'][0]['modelName'] == "densenet161"
228+
response = requests.get("http://127.0.0.1:8081/models/")
229+
assert json.loads(response.content)["models"][0]["modelName"] == "densenet161"
201230
except:
202231
assert False, "Default manifest does not work"
203232
else:
204-
assert True, "Successfully started Torchserve with a dummy mar file (ie. default manifest)"
233+
assert (
234+
True
235+
), "Successfully started Torchserve with a dummy mar file (ie. default manifest)"
205236
finally:
206237
test_utils.unregister_model("densenet161")
207238
test_utils.delete_all_snapshots()
@@ -211,41 +242,85 @@ def test_replace_mar_file_with_dummy():
211242

212243
def test_restart_torchserve_with_one_of_model_mar_removed():
213244
"""Validates that torchserve will fail to start in the following scenario:
214-
1) We use a snapshot file to start torchserve. The snapshot contains reference to few model files
215-
2) One of these model mar files are accidentally deleted from the model store"""
245+
1) We use a snapshot file to start torchserve. The snapshot contains reference to few model files
246+
2) One of these model mar files are accidentally deleted from the model store"""
216247
# Register multiple models
217248
# 1st model
218249
test_utils.delete_model_store()
219250
test_utils.start_torchserve()
220-
requests.post(
221-
'http://127.0.0.1:8081/models?url=densenet161.mar')
251+
requests.post("http://127.0.0.1:8081/models?url=densenet161.mar")
222252
time.sleep(15)
223253
# 2nd model
224-
requests.post(
225-
'http://127.0.0.1:8081/models?url=mnist.mar')
254+
requests.post("http://127.0.0.1:8081/models?url=mnist.mar")
226255
time.sleep(15)
227256
test_utils.stop_torchserve()
228257

229258
# Start Torchserve
230259
test_utils.start_torchserve()
231-
response = requests.get('http://127.0.0.1:8081/models/')
232-
num_of_regd_models = len(json.loads(response.content)['models'])
260+
response = requests.get("http://127.0.0.1:8081/models/")
261+
num_of_regd_models = len(json.loads(response.content)["models"])
233262
test_utils.stop_torchserve()
234263

235264
# Now remove the registered model mar file (delete_mar_ fn)
236-
test_utils.delete_mar_file_from_model_store(model_store=os.path.join(test_utils.ROOT_DIR, "model_store"),
237-
model_mar="densenet")
265+
test_utils.delete_mar_file_from_model_store(
266+
model_store=os.path.join(test_utils.ROOT_DIR, "model_store"),
267+
model_mar="densenet",
268+
)
238269

239270
# Start Torchserve with existing snapshot file containing reference to one of the model mar file
240271
# which is now missing from the model store
241-
snapshot_cfg = glob.glob('logs/config/*snap*.cfg')[1]
272+
snapshot_cfg = glob.glob("logs/config/*snap*.cfg")[1]
242273
test_utils.start_torchserve(snapshot_file=snapshot_cfg, gen_mar=False)
243274
try:
244-
response = requests.get('http://127.0.0.1:8081/models/')
275+
response = requests.get("http://127.0.0.1:8081/models/")
245276
except:
246-
assert True, "Failed to start Torchserve as one of reqd model mar file is missing"
277+
assert (
278+
True
279+
), "Failed to start Torchserve as one of reqd model mar file is missing"
247280
else:
248-
assert False, "Something is not right!! Started Torchserve successfully with a " \
249-
"reqd model mar file missing from the model store!!"
281+
assert False, (
282+
"Something is not right!! Started Torchserve successfully with a "
283+
"reqd model mar file missing from the model store!!"
284+
)
250285
finally:
251-
test_utils.torchserve_cleanup()
286+
test_utils.torchserve_cleanup()
287+
288+
289+
def test_empty_runtime():
    """
    Validates that Torchserve can be started from a snapshot whose model
    entries carry no "runtimeType" field.

    Steps:
    1) Register mnist.mar and stop Torchserve so a shutdown snapshot is written.
    2) Rewrite the snapshot's "model_snapshot" entry with every "runtimeType"
       key removed.
    3) Start Torchserve again and check that the management API is reachable.
    """
    test_utils.delete_all_snapshots()
    test_utils.stop_torchserve()
    test_utils.start_torchserve()
    requests.post("http://127.0.0.1:8081/models?url=mnist.mar")
    test_utils.stop_torchserve()

    cfgs = glob.glob("logs/config/*shutdown.cfg")
    assert len(cfgs) == 1

    prefix = "model_snapshot="

    def remove_runtime_type(json_str):
        """Return the "model_snapshot=..." line with all "runtimeType" keys dropped."""
        # The value is stored with ":" escaped as "\:" and line breaks written
        # as a literal "\n"; undo both so the payload parses as plain JSON.
        payload = json_str[len(prefix) :].replace("\\:", ":").replace("\\n", "")
        snapshot = json.loads(payload)

        # Layout walked here: snapshot["models"][<model>][<version>] -> config dict
        for versions in snapshot["models"].values():
            for config in versions.values():
                # pop() rather than del: an entry that already lacks the key
                # is exactly the state this test wants, not an error.
                config.pop("runtimeType", None)

        # Re-apply the escaping that was stripped above and restore the prefix.
        escaped = json.dumps(snapshot, indent=2).replace("\n", "\\n").replace(":", "\\:")
        return prefix + escaped

    snapshot_path = Path(cfgs[0])
    cfg_lines = snapshot_path.read_text().split("\n")
    snapshot_lines = [line for line in cfg_lines if line.startswith("model_snapshot")]
    assert snapshot_lines, "No model_snapshot entry found in " + cfgs[0]

    # Keep every other setting untouched and append the rewritten entry last.
    other_lines = [line for line in cfg_lines if not line.startswith("model_snapshot")]
    other_lines.append(remove_runtime_type(snapshot_lines[0]))
    snapshot_path.write_text("\n".join(other_lines))

    # No snapshot_file is passed; presumably the latest snapshot (the file
    # edited above) is picked up on start -- TODO confirm against test_utils.
    test_utils.start_torchserve()

    try:
        requests.get("http://127.0.0.1:8081/models/")
    except requests.exceptions.RequestException as err:
        # Only a failed HTTP request means the server did not come up; do not
        # swallow KeyboardInterrupt/SystemExit the way a bare except would.
        raise AssertionError("Could not start TorchServe.") from err

0 commit comments

Comments
 (0)