Skip to content

Commit 4dc1983

Browse files
committed
Refactored sanity tests
1 parent 70c5712 commit 4dc1983

File tree

3 files changed

+209
-157
lines changed

3 files changed

+209
-157
lines changed

torchserve_sanity.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from ts_scripts.modelarchiver_utils import test_modelarchiver
2-
from ts_scripts.workflow_archiver_utils import test_workflow_archiver
1+
import ts_scripts.tsutils as ts
2+
from ts_scripts import marsgen as mg
33
from ts_scripts.backend_utils import test_torchserve
4+
from ts_scripts.frontend_utils import test_frontend
45
from ts_scripts.install_from_src import install_from_src
5-
from ts_scripts.sanity_utils import test_sanity
6-
from ts_scripts.sanity_utils import test_workflow_sanity
6+
from ts_scripts.modelarchiver_utils import test_modelarchiver
7+
from ts_scripts.sanity_utils import (
8+
test_markdown_files,
9+
test_sanity,
10+
test_workflow_sanity,
11+
)
712
from ts_scripts.shell_utils import rm_dir, rm_file
8-
from ts_scripts.frontend_utils import test_frontend
9-
import ts_scripts.tsutils as ts
10-
import ts_scripts.print_env_info as build_hdr_printer
11-
from ts_scripts import marsgen as mg
13+
from ts_scripts.workflow_archiver_utils import test_workflow_archiver
1214

1315

1416
def torchserve_sanity():
@@ -37,22 +39,27 @@ def torchserve_sanity():
3739
# Run workflow sanity
3840
test_workflow_sanity()
3941

42+
# Check for broken links
43+
test_markdown_files()
44+
4045
finally:
4146
cleanup()
4247

4348

4449
def cleanup():
4550
ts.stop_torchserve()
46-
rm_dir('model_store')
47-
rm_dir('logs')
51+
rm_dir("model_store")
52+
rm_dir("logs")
4853

4954
# clean up residual from model-archiver IT suite.
50-
rm_dir('model-archiver/model_archiver/htmlcov_ut model_archiver/model-archiver/htmlcov_it')
51-
rm_file('ts_scripts/*_pb2*.py', True)
55+
rm_dir(
56+
"model-archiver/model_archiver/htmlcov_ut model_archiver/model-archiver/htmlcov_it"
57+
)
58+
rm_file("ts_scripts/*_pb2*.py", True)
5259

5360
# delete mar_gen_dir
5461
mg.delete_model_store_gen_dir()
5562

5663

57-
if __name__ == '__main__':
64+
if __name__ == "__main__":
5865
torchserve_sanity()

ts_scripts/marsgen.py

+74-40
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import argparse
22
import json
33
import os
4-
import sys
5-
import urllib.request
64
import shutil
75
import subprocess
6+
import sys
7+
import urllib.request
88

99
REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
1010
sys.path.append(REPO_ROOT)
1111
MODEL_STORE_DIR = os.path.join(REPO_ROOT, "model_store_gen")
1212
os.makedirs(MODEL_STORE_DIR, exist_ok=True)
1313
MAR_CONFIG_FILE_PATH = os.path.join(REPO_ROOT, "ts_scripts", "mar_config.json")
1414

15+
1516
def delete_model_store_gen_dir():
1617
print(f"## Deleting model_store_gen_dir: {MODEL_STORE_DIR}\n")
1718
mar_set.clear()
@@ -21,7 +22,10 @@ def delete_model_store_gen_dir():
2122
except OSError as e:
2223
print("Error: %s : %s" % (MODEL_STORE_DIR, e.strerror))
2324

25+
2426
mar_set = set()
27+
28+
2529
def gen_mar(model_store=None):
2630
print(f"## Starting gen_mar: {model_store}\n")
2731
if len(mar_set) == 0:
@@ -53,7 +57,9 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
5357
- "extra_files": the paths of extra files
5458
Note: To generate .pt file, "serialized_file_remote" and "gen_scripted_file_path" must be provided
5559
"""
56-
print(f"## Starting generate_mars, mar_config:{mar_config}, model_store_dir:{model_store_dir}\n")
60+
print(
61+
f"## Starting generate_mars, mar_config:{mar_config}, model_store_dir:{model_store_dir}\n"
62+
)
5763
mar_set.clear()
5864
cwd = os.getcwd()
5965
os.chdir(REPO_ROOT)
@@ -63,65 +69,86 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
6369
for model in models:
6470
serialized_file_path = None
6571
if model.get("serialized_file_remote") and model["serialized_file_remote"]:
66-
if model.get("gen_scripted_file_path") and model["gen_scripted_file_path"]:
72+
if (
73+
model.get("gen_scripted_file_path")
74+
and model["gen_scripted_file_path"]
75+
):
6776
subprocess.run(["python", model["gen_scripted_file_path"]])
6877
else:
69-
serialized_model_file_url = \
70-
"https://download.pytorch.org/models/{}".format(model["serialized_file_remote"])
78+
serialized_model_file_url = (
79+
"https://download.pytorch.org/models/{}".format(
80+
model["serialized_file_remote"]
81+
)
82+
)
7183
urllib.request.urlretrieve(
7284
serialized_model_file_url,
73-
f'{model_store_dir}/{model["serialized_file_remote"]}')
74-
serialized_file_path = os.path.join(model_store_dir, model["serialized_file_remote"])
85+
f'{model_store_dir}/{model["serialized_file_remote"]}',
86+
)
87+
serialized_file_path = os.path.join(
88+
model_store_dir, model["serialized_file_remote"]
89+
)
7590
elif model.get("serialized_file_local") and model["serialized_file_local"]:
7691
serialized_file_path = model["serialized_file_local"]
7792

78-
handler = None
79-
if model.get("handler") and model["handler"]:
80-
handler = model["handler"]
93+
handler = model.get("handler", None)
8194

82-
extra_files = None
83-
if model.get("extra_files") and model["extra_files"]:
84-
extra_files = model["extra_files"]
95+
extra_files = model.get("extra_files", None)
8596

86-
runtime = None
87-
if model.get("runtime") and model["runtime"]:
88-
runtime = model["runtime"]
97+
runtime = model.get("runtime", None)
8998

90-
archive_format = None
91-
if model.get("archive_format") and model["archive_format"]:
92-
archive_format = model["archive_format"]
99+
archive_format = model.get("archive_format", "zip-store")
93100

94-
requirements_file = None
95-
if model.get("requirements_file") and model["requirements_file"]:
96-
requirements_file = model["requirements_file"]
101+
requirements_file = model.get("requirements_file", None)
97102

98-
export_path = model_store_dir
99-
if model.get("export_path") and model["export_path"]:
100-
export_path = model["export_path"]
103+
export_path = model.get("export_path", model_store_dir)
101104

102-
cmd = model_archiver_command_builder(model["model_name"], model["version"], model["model_file"],
103-
serialized_file_path, handler, extra_files,
104-
runtime, archive_format, requirements_file, export_path)
105+
cmd = model_archiver_command_builder(
106+
model["model_name"],
107+
model["version"],
108+
model["model_file"],
109+
serialized_file_path,
110+
handler,
111+
extra_files,
112+
runtime,
113+
archive_format,
114+
requirements_file,
115+
export_path,
116+
)
105117
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
106118
try:
107119
subprocess.check_call(cmd, shell=True)
108120
marfile = "{}.mar".format(model["model_name"])
109121
print("## {} is generated.\n".format(marfile))
110122
mar_set.add(marfile)
111123
except subprocess.CalledProcessError as exc:
112-
print("## {} creation failed !, error: {}\n".format(model["model_name"], exc))
113-
114-
if model.get("serialized_file_remote") and \
115-
model["serialized_file_remote"] and \
116-
os.path.exists(serialized_file_path):
124+
print(
125+
"## {} creation failed !, error: {}\n".format(
126+
model["model_name"], exc
127+
)
128+
)
129+
130+
if (
131+
model.get("serialized_file_remote")
132+
and model["serialized_file_remote"]
133+
and os.path.exists(serialized_file_path)
134+
):
117135
os.remove(serialized_file_path)
118136
os.chdir(cwd)
119137

120138

121-
def model_archiver_command_builder(model_name=None, version=None, model_file=None,
122-
serialized_file=None, handler=None, extra_files=None,
123-
runtime=None, archive_format=None, requirements_file=None,
124-
export_path=None, force=True):
139+
def model_archiver_command_builder(
140+
model_name=None,
141+
version=None,
142+
model_file=None,
143+
serialized_file=None,
144+
handler=None,
145+
extra_files=None,
146+
runtime=None,
147+
archive_format=None,
148+
requirements_file=None,
149+
export_path=None,
150+
force=True,
151+
):
125152
cmd = "torch-model-archiver"
126153

127154
if model_name:
@@ -159,14 +186,21 @@ def model_archiver_command_builder(model_name=None, version=None, model_file=Non
159186

160187
return cmd
161188

189+
162190
if __name__ == "__main__":
163191
# cmd:
164192
# python ts_scripts/marsgen.py
165193
# python ts_scripts/marsgen.py --config my_mar_config.json
166194

167195
parser = argparse.ArgumentParser(description="Generate model mar files")
168-
parser.add_argument('--config', default=MAR_CONFIG_FILE_PATH, help="mar file configuration json file")
169-
parser.add_argument('--model-store', default=MODEL_STORE_DIR, help="model store dir")
196+
parser.add_argument(
197+
"--config",
198+
default=MAR_CONFIG_FILE_PATH,
199+
help="mar file configuration json file",
200+
)
201+
parser.add_argument(
202+
"--model-store", default=MODEL_STORE_DIR, help="model store dir"
203+
)
170204

171205
args = parser.parse_args()
172206
generate_mars(args.config, MODEL_STORE_DIR)

0 commit comments

Comments
 (0)