
Commit ead57e6

Remove wait time when stopping and starting torchserve in tests
- Make sanity test run with pytest (90%)
- Added missing test for snapshotting
- Use pytest tests in torchserve sanity checks
1 parent d47b14d commit ead57e6

File tree

6 files changed, +212 -114 lines


test/pytest/sanity/conftest.py

+27
@@ -0,0 +1,27 @@
+import json
+import sys
+from pathlib import Path
+
+import pytest
+
+REPO_ROOT = Path(__file__).parents[3]
+
+
+MAR_CONFIG = REPO_ROOT.joinpath("ts_scripts", "mar_config.json")
+
+
+@pytest.fixture(name="gen_models", scope="module")
+def load_gen_models() -> dict:
+    with open(MAR_CONFIG) as f:
+        models = json.load(f)
+    models = {m["model_name"]: m for m in models}
+    return models
+
+
+@pytest.fixture(scope="module")
+def ts_scripts_path():
+    sys.path.append(REPO_ROOT.as_posix())
+
+    yield
+
+    sys.path.pop()
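
The two fixtures above give the sanity tests the mar_config.json entries (keyed by model_name) and temporary import access to the repo root. As a rough illustration only, a test in this directory could consume them like the sketch below; the test name and the "mnist" key are hypothetical, not part of this commit.

# Hypothetical usage sketch, not part of the commit: pulls one entry from the
# gen_models fixture and relies on ts_scripts_path putting REPO_ROOT on sys.path.
import pytest


def test_example_entry(gen_models, ts_scripts_path):
    from ts_scripts import marsgen  # importable thanks to ts_scripts_path

    model = gen_models.get("mnist")  # "mnist" is an assumed key
    if model is None:
        pytest.skip("entry not present in ts_scripts/mar_config.json")
    assert "model_name" in model
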
+54
@@ -0,0 +1,54 @@
+import json
+from pathlib import Path
+
+import pytest
+import test_utils
+
+REPO_ROOT = Path(__file__).parents[3]
+SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")
+
+
+def load_resnet18() -> dict:
+    with open(SANITY_MODELS_CONFIG) as f:
+        models = json.load(f)
+    return list(filter(lambda x: x["name"] == "resnet-18", models))[0]
+
+
+@pytest.fixture(name="resnet18")
+def generate_resnet18(model_store, gen_models, ts_scripts_path):
+    model = load_resnet18()
+
+    from ts_scripts.marsgen import generate_model
+
+    generate_model(gen_models[model["name"]], model_store)
+
+    yield model
+
+
+@pytest.fixture(scope="module")
+def torchserve_with_snapshot(model_store):
+    test_utils.torchserve_cleanup()
+
+    test_utils.start_torchserve(
+        model_store=model_store, no_config_snapshots=False, gen_mar=False
+    )
+
+    yield
+
+    test_utils.torchserve_cleanup()
+
+
+def test_config_snapshotting(
+    resnet18, model_store, torchserve_with_snapshot, ts_scripts_path
+):
+    from ts_scripts.sanity_utils import run_rest_test
+
+    run_rest_test(resnet18, unregister_model=False)
+
+    test_utils.stop_torchserve()
+
+    test_utils.start_torchserve(
+        model_store=model_store, no_config_snapshots=False, gen_mar=False
+    )
+
+    run_rest_test(resnet18, register_model=False)
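
The test above only passes if the config snapshot taken on shutdown re-registers resnet-18 on restart, since the second run_rest_test call skips registration. A hedged sketch of invoking just this test through pytest's Python API, assuming pytest is installed and the command is run from the repo root:

# Sketch: select the new snapshot test by name; everything else in
# test/pytest/sanity is filtered out by the -k expression.
import sys

import pytest

sys.exit(pytest.main(["test/pytest/sanity", "-k", "test_config_snapshotting"]))
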
@@ -0,0 +1,55 @@
+import json
+from pathlib import Path
+
+import pytest
+
+REPO_ROOT = Path(__file__).parents[3]
+SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")
+
+
+@pytest.fixture(scope="module")
+def grpc_client_stubs(ts_scripts_path):
+    from ts_scripts.shell_utils import rm_file
+    from ts_scripts.tsutils import generate_grpc_client_stubs
+
+    generate_grpc_client_stubs()
+
+    yield
+
+    rm_file(REPO_ROOT.joinpath("ts_scripts", "*_pb2*.py").as_posix(), True)
+
+
+def load_models() -> dict:
+    with open(SANITY_MODELS_CONFIG) as f:
+        models = json.load(f)
+    return models
+
+
+@pytest.fixture(name="model", params=load_models(), scope="module")
+def models_to_validate(request, model_store, gen_models, ts_scripts_path):
+    model = request.param
+
+    if model["name"] in gen_models:
+        from ts_scripts.marsgen import generate_model
+
+        generate_model(gen_models[model["name"]], model_store)
+
+    yield model
+
+
+def test_models_with_grpc(model, torchserve, ts_scripts_path, grpc_client_stubs):
+    from ts_scripts.sanity_utils import run_grpc_test
+
+    run_grpc_test(model)
+
+
+def test_models_with_rest(model, torchserve, ts_scripts_path):
+    from ts_scripts.sanity_utils import run_rest_test
+
+    run_rest_test(model)
+
+
+def test_gpu_setup(ts_scripts_path):
+    from ts_scripts.sanity_utils import test_gpu_setup
+
+    test_gpu_setup()
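
Because the model fixture is parametrized over every entry in sanity_models.json, each gRPC and REST test above runs once per model. A self-contained sketch of that parametrization pattern; the two entries below are made up and only stand in for the real config:

# Illustrative only: params= on a fixture fans each dependent test out
# over the given entries, one test instance per entry.
import pytest

FAKE_MODELS = [{"name": "model-a"}, {"name": "model-b"}]  # hypothetical entries


@pytest.fixture(name="model", params=FAKE_MODELS, scope="module")
def models_to_validate(request):
    yield request.param  # one dict per generated test instance


def test_model_has_name(model):
    assert "name" in model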

test/pytest/test_utils.py

+5-2
@@ -135,6 +135,7 @@ def model_archiver_command_builder(
     extra_files=None,
     force=False,
     config_file=None,
+    archive_format=None,
 ):
     # Initialize a list to store the command-line arguments
     cmd_parts = ["torch-model-archiver"]
@@ -164,9 +165,11 @@ def model_archiver_command_builder(
     if force:
         cmd_parts.append("--force")
 
-    # Append the export-path argument to the list
-    cmd_parts.append(f"--export-path {MODEL_STORE}")
+    if archive_format:
+        cmd_parts.append(f"--archive-format {archive_format}")
 
+    cmd_parts.append(f"--export-path {MODEL_STORE}")
+
     # Convert the list into a string to represent the complete command
     cmd = " ".join(cmd_parts)
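
The builder collects individual flags in cmd_parts and joins them into the final torch-model-archiver command at the end, so the optional --archive-format flag is only emitted when a format is passed. A standalone sketch of that pattern; the function name and example values are illustrative, not the repo's actual helper:

# Minimal sketch of the list-based command builder; values are made up.
def build_archiver_cmd(model_name, model_store, archive_format=None):
    cmd_parts = ["torch-model-archiver", f"--model-name {model_name}"]
    if archive_format:
        # emitted only when the caller requests a specific archive format
        cmd_parts.append(f"--archive-format {archive_format}")
    cmd_parts.append(f"--export-path {model_store}")
    return " ".join(cmd_parts)


print(build_archiver_cmd("resnet-18", "/tmp/model_store", "no-archive"))
# -> torch-model-archiver --model-name resnet-18 --archive-format no-archive --export-path /tmp/model_store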

ts_scripts/marsgen.py

+59-66
@@ -43,6 +43,64 @@ def gen_mar(model_store=None):
            print(f"## Symlink {src}, {dst} successfully.")
 
 
+def generate_model(model, model_store_dir):
+    serialized_file_path = None
+    if model.get("serialized_file_remote", None):
+        if model.get("gen_scripted_file_path", None):
+            subprocess.run(["python", model["gen_scripted_file_path"]])
+        else:
+            serialized_model_file_url = (
+                f"https://download.pytorch.org/models/{model['serialized_file_remote']}"
+            )
+            urllib.request.urlretrieve(
+                serialized_model_file_url,
+                f'{model_store_dir}/{model["serialized_file_remote"]}',
+            )
+        serialized_file_path = os.path.join(
+            model_store_dir, model["serialized_file_remote"]
+        )
+    elif model.get("serialized_file_local", None):
+        serialized_file_path = model["serialized_file_local"]
+
+    handler = model.get("handler", None)
+
+    extra_files = model.get("extra_files", None)
+
+    runtime = model.get("runtime", None)
+
+    archive_format = model.get("archive_format", "zip-store")
+
+    requirements_file = model.get("requirements_file", None)
+
+    export_path = model.get("export_path", model_store_dir)
+
+    cmd = model_archiver_command_builder(
+        model["model_name"],
+        model["version"],
+        model["model_file"],
+        serialized_file_path,
+        handler,
+        extra_files,
+        runtime,
+        archive_format,
+        requirements_file,
+        export_path,
+    )
+    print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
+    try:
+        subprocess.check_call(cmd, shell=True)
+        marfile = "{}.mar".format(model["model_name"])
+        print("## {} is generated.\n".format(marfile))
+        mar_set.add(marfile)
+    except subprocess.CalledProcessError as exc:
+        print("## {} creation failed !, error: {}\n".format(model["model_name"], exc))
+
+    if model.get("serialized_file_remote", None) and os.path.exists(
+        serialized_file_path
+    ):
+        os.remove(serialized_file_path)
+
+
 def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
     """
     By default generate_mars reads ts_scripts/mar_config.json and outputs mar files in dir model_store_gen
@@ -67,72 +125,7 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
         models = json.loads(f.read())
 
     for model in models:
-        serialized_file_path = None
-        if model.get("serialized_file_remote") and model["serialized_file_remote"]:
-            if (
-                model.get("gen_scripted_file_path")
-                and model["gen_scripted_file_path"]
-            ):
-                subprocess.run(["python", model["gen_scripted_file_path"]])
-            else:
-                serialized_model_file_url = (
-                    "https://download.pytorch.org/models/{}".format(
-                        model["serialized_file_remote"]
-                    )
-                )
-                urllib.request.urlretrieve(
-                    serialized_model_file_url,
-                    f'{model_store_dir}/{model["serialized_file_remote"]}',
-                )
-            serialized_file_path = os.path.join(
-                model_store_dir, model["serialized_file_remote"]
-            )
-        elif model.get("serialized_file_local") and model["serialized_file_local"]:
-            serialized_file_path = model["serialized_file_local"]
-
-        handler = model.get("handler", None)
-
-        extra_files = model.get("extra_files", None)
-
-        runtime = model.get("runtime", None)
-
-        archive_format = model.get("archive_format", "zip-store")
-
-        requirements_file = model.get("requirements_file", None)
-
-        export_path = model.get("export_path", model_store_dir)
-
-        cmd = model_archiver_command_builder(
-            model["model_name"],
-            model["version"],
-            model["model_file"],
-            serialized_file_path,
-            handler,
-            extra_files,
-            runtime,
-            archive_format,
-            requirements_file,
-            export_path,
-        )
-        print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
-        try:
-            subprocess.check_call(cmd, shell=True)
-            marfile = "{}.mar".format(model["model_name"])
-            print("## {} is generated.\n".format(marfile))
-            mar_set.add(marfile)
-        except subprocess.CalledProcessError as exc:
-            print(
-                "## {} creation failed !, error: {}\n".format(
                    model["model_name"], exc
-                )
-            )
-
-        if (
-            model.get("serialized_file_remote")
-            and model["serialized_file_remote"]
-            and os.path.exists(serialized_file_path)
-        ):
-            os.remove(serialized_file_path)
+        generate_model(model, model_store_dir)
     os.chdir(cwd)
 
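
generate_model now packages a single config entry, which is what the new pytest fixtures call directly instead of regenerating every model. As a rough sketch, it could be driven by hand with an entry like the one below; the entry is hypothetical and only mirrors the keys the function reads, with ts_scripts/mar_config.json remaining the real source:

# Sketch only: hand-built entry with placeholder paths, not a real model.
from ts_scripts.marsgen import generate_model

example_entry = {
    "model_name": "my_model",                          # hypothetical
    "version": "1.0",
    "model_file": "examples/my_model/model.py",        # placeholder path
    "serialized_file_local": "examples/my_model/weights.pt",
    "handler": "image_classifier",
    "archive_format": "zip-store",
}

generate_model(example_entry, "model_store_gen")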

ts_scripts/sanity_utils.py

+12-46
@@ -10,7 +10,6 @@
 from ts_scripts import marsgen as mg
 from ts_scripts import tsutils as ts
 from ts_scripts import utils
-from ts_scripts.tsutils import generate_grpc_client_stubs
 
 REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
 sys.path.append(REPO_ROOT)
@@ -163,51 +162,18 @@ def run_rest_test(model, register_model=True, unregister_model=True):
 
 
 def test_sanity():
-    generate_grpc_client_stubs()
-
-    print("## Started sanity tests")
-
-    models_to_validate = load_model_to_validate()
-
-    test_gpu_setup()
-
-    ts_log_file = os.path.join("logs", "ts_console.log")
-
-    os.makedirs("model_store", exist_ok=True)
-    os.makedirs("logs", exist_ok=True)
-
-    mg.mar_set = set(os.listdir("model_store"))
-    started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
-    if not started:
-        sys.exit(1)
-
-    resnet18_model = models_to_validate["resnet-18"]
-
-    models_to_validate = {
-        k: v for k, v in models_to_validate.items() if k != "resnet-18"
-    }
-
-    for _, model in models_to_validate.items():
-        run_grpc_test(model)
-        run_rest_test(model)
-
-    run_rest_test(resnet18_model, unregister_model=False)
-
-    stopped = ts.stop_torchserve()
-    if not stopped:
-        sys.exit(1)
-
-    # Restarting torchserve
-    # This should restart with the generated snapshot and resnet-18 model should be automatically registered
-    started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
-    if not started:
-        sys.exit(1)
-
-    run_rest_test(resnet18_model, register_model=False)
-
-    stopped = ts.stop_torchserve()
-    if not stopped:
-        sys.exit(1)
+    # Execute python tests
+    print("## Started TorchServe sanity pytests")
+    test_dir = os.path.join("test", "pytest", "sanity")
+    coverage_dir = os.path.join("ts")
+    report_output_dir = os.path.join(test_dir, "coverage.xml")
+
+    ts_test_cmd = f"python -m pytest --cov-report xml:{report_output_dir} --cov={coverage_dir} {test_dir}"
+    print(f"## In directory: {os.getcwd()} | Executing command: {ts_test_cmd}")
+    ts_test_error_code = os.system(ts_test_cmd)
+
+    if ts_test_error_code != 0:
+        sys.exit("## TorchServe sanity test failed !")
 
 
 def test_workflow_sanity():
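
test_sanity now delegates to pytest with coverage enabled and fails on a non-zero exit code rather than orchestrating TorchServe itself. For comparison, a hedged sketch of the same invocation through pytest's in-process API; this is an alternative, not what the commit ships, and it assumes pytest and pytest-cov are installed:

# Sketch of an equivalent in-process call; flags mirror ts_test_cmd above.
import os
import sys

import pytest

test_dir = os.path.join("test", "pytest", "sanity")
report_output_dir = os.path.join(test_dir, "coverage.xml")

exit_code = pytest.main([f"--cov-report=xml:{report_output_dir}", "--cov=ts", test_dir])
if exit_code != 0:
    sys.exit("## TorchServe sanity test failed !")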
