1
1
import argparse
2
2
import json
3
3
import os
4
- import sys
5
- import urllib .request
6
4
import shutil
7
5
import subprocess
6
+ import sys
7
+ import urllib .request
8
8
9
9
REPO_ROOT = os .path .join (os .path .dirname (os .path .abspath (__file__ )), ".." )
10
10
sys .path .append (REPO_ROOT )
11
11
MODEL_STORE_DIR = os .path .join (REPO_ROOT , "model_store_gen" )
12
12
os .makedirs (MODEL_STORE_DIR , exist_ok = True )
13
13
MAR_CONFIG_FILE_PATH = os .path .join (REPO_ROOT , "ts_scripts" , "mar_config.json" )
14
14
15
+
15
16
def delete_model_store_gen_dir ():
16
17
print (f"## Deleting model_store_gen_dir: { MODEL_STORE_DIR } \n " )
17
18
mar_set .clear ()
@@ -21,7 +22,10 @@ def delete_model_store_gen_dir():
21
22
except OSError as e :
22
23
print ("Error: %s : %s" % (MODEL_STORE_DIR , e .strerror ))
23
24
25
+
24
26
mar_set = set ()
27
+
28
+
25
29
def gen_mar (model_store = None ):
26
30
print (f"## Starting gen_mar: { model_store } \n " )
27
31
if len (mar_set ) == 0 :
@@ -53,7 +57,9 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
53
57
- "extra_files": the paths of extra files
54
58
Note: To generate .pt file, "serialized_file_remote" and "gen_scripted_file_path" must be provided
55
59
"""
56
- print (f"## Starting generate_mars, mar_config:{ mar_config } , model_store_dir:{ model_store_dir } \n " )
60
+ print (
61
+ f"## Starting generate_mars, mar_config:{ mar_config } , model_store_dir:{ model_store_dir } \n "
62
+ )
57
63
mar_set .clear ()
58
64
cwd = os .getcwd ()
59
65
os .chdir (REPO_ROOT )
@@ -63,65 +69,86 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
63
69
for model in models :
64
70
serialized_file_path = None
65
71
if model .get ("serialized_file_remote" ) and model ["serialized_file_remote" ]:
66
- if model .get ("gen_scripted_file_path" ) and model ["gen_scripted_file_path" ]:
72
+ if (
73
+ model .get ("gen_scripted_file_path" )
74
+ and model ["gen_scripted_file_path" ]
75
+ ):
67
76
subprocess .run (["python" , model ["gen_scripted_file_path" ]])
68
77
else :
69
- serialized_model_file_url = \
70
- "https://download.pytorch.org/models/{}" .format (model ["serialized_file_remote" ])
78
+ serialized_model_file_url = (
79
+ "https://download.pytorch.org/models/{}" .format (
80
+ model ["serialized_file_remote" ]
81
+ )
82
+ )
71
83
urllib .request .urlretrieve (
72
84
serialized_model_file_url ,
73
- f'{ model_store_dir } /{ model ["serialized_file_remote" ]} ' )
74
- serialized_file_path = os .path .join (model_store_dir , model ["serialized_file_remote" ])
85
+ f'{ model_store_dir } /{ model ["serialized_file_remote" ]} ' ,
86
+ )
87
+ serialized_file_path = os .path .join (
88
+ model_store_dir , model ["serialized_file_remote" ]
89
+ )
75
90
elif model .get ("serialized_file_local" ) and model ["serialized_file_local" ]:
76
91
serialized_file_path = model ["serialized_file_local" ]
77
92
78
- handler = None
79
- if model .get ("handler" ) and model ["handler" ]:
80
- handler = model ["handler" ]
93
+ handler = model .get ("handler" , None )
81
94
82
- extra_files = None
83
- if model .get ("extra_files" ) and model ["extra_files" ]:
84
- extra_files = model ["extra_files" ]
95
+ extra_files = model .get ("extra_files" , None )
85
96
86
- runtime = None
87
- if model .get ("runtime" ) and model ["runtime" ]:
88
- runtime = model ["runtime" ]
97
+ runtime = model .get ("runtime" , None )
89
98
90
- archive_format = None
91
- if model .get ("archive_format" ) and model ["archive_format" ]:
92
- archive_format = model ["archive_format" ]
99
+ archive_format = model .get ("archive_format" , "zip-store" )
93
100
94
- requirements_file = None
95
- if model .get ("requirements_file" ) and model ["requirements_file" ]:
96
- requirements_file = model ["requirements_file" ]
101
+ requirements_file = model .get ("requirements_file" , None )
97
102
98
- export_path = model_store_dir
99
- if model .get ("export_path" ) and model ["export_path" ]:
100
- export_path = model ["export_path" ]
103
+ export_path = model .get ("export_path" , model_store_dir )
101
104
102
- cmd = model_archiver_command_builder (model ["model_name" ], model ["version" ], model ["model_file" ],
103
- serialized_file_path , handler , extra_files ,
104
- runtime , archive_format , requirements_file , export_path )
105
+ cmd = model_archiver_command_builder (
106
+ model ["model_name" ],
107
+ model ["version" ],
108
+ model ["model_file" ],
109
+ serialized_file_path ,
110
+ handler ,
111
+ extra_files ,
112
+ runtime ,
113
+ archive_format ,
114
+ requirements_file ,
115
+ export_path ,
116
+ )
105
117
print (f"## In directory: { os .getcwd ()} | Executing command: { cmd } \n " )
106
118
try :
107
119
subprocess .check_call (cmd , shell = True )
108
120
marfile = "{}.mar" .format (model ["model_name" ])
109
121
print ("## {} is generated.\n " .format (marfile ))
110
122
mar_set .add (marfile )
111
123
except subprocess .CalledProcessError as exc :
112
- print ("## {} creation failed !, error: {}\n " .format (model ["model_name" ], exc ))
113
-
114
- if model .get ("serialized_file_remote" ) and \
115
- model ["serialized_file_remote" ] and \
116
- os .path .exists (serialized_file_path ):
124
+ print (
125
+ "## {} creation failed !, error: {}\n " .format (
126
+ model ["model_name" ], exc
127
+ )
128
+ )
129
+
130
+ if (
131
+ model .get ("serialized_file_remote" )
132
+ and model ["serialized_file_remote" ]
133
+ and os .path .exists (serialized_file_path )
134
+ ):
117
135
os .remove (serialized_file_path )
118
136
os .chdir (cwd )
119
137
120
138
121
- def model_archiver_command_builder (model_name = None , version = None , model_file = None ,
122
- serialized_file = None , handler = None , extra_files = None ,
123
- runtime = None , archive_format = None , requirements_file = None ,
124
- export_path = None , force = True ):
139
+ def model_archiver_command_builder (
140
+ model_name = None ,
141
+ version = None ,
142
+ model_file = None ,
143
+ serialized_file = None ,
144
+ handler = None ,
145
+ extra_files = None ,
146
+ runtime = None ,
147
+ archive_format = None ,
148
+ requirements_file = None ,
149
+ export_path = None ,
150
+ force = True ,
151
+ ):
125
152
cmd = "torch-model-archiver"
126
153
127
154
if model_name :
@@ -159,14 +186,21 @@ def model_archiver_command_builder(model_name=None, version=None, model_file=Non
159
186
160
187
return cmd
161
188
189
+
162
190
if __name__ == "__main__" :
163
191
# cmd:
164
192
# python ts_scripts/marsgen.py
165
193
# python ts_scripts/marsgen.py --config my_mar_config.json
166
194
167
195
parser = argparse .ArgumentParser (description = "Generate model mar files" )
168
- parser .add_argument ('--config' , default = MAR_CONFIG_FILE_PATH , help = "mar file configuration json file" )
169
- parser .add_argument ('--model-store' , default = MODEL_STORE_DIR , help = "model store dir" )
196
+ parser .add_argument (
197
+ "--config" ,
198
+ default = MAR_CONFIG_FILE_PATH ,
199
+ help = "mar file configuration json file" ,
200
+ )
201
+ parser .add_argument (
202
+ "--model-store" , default = MODEL_STORE_DIR , help = "model store dir"
203
+ )
170
204
171
205
args = parser .parse_args ()
172
206
generate_mars (args .config , MODEL_STORE_DIR )
0 commit comments