@@ -43,6 +43,64 @@ def gen_mar(model_store=None):
43
43
print (f"## Symlink { src } , { dst } successfully." )
44
44
45
45
46
+ def generate_model (model , model_store_dir ):
47
+ serialized_file_path = None
48
+ if model .get ("serialized_file_remote" , None ):
49
+ if model .get ("gen_scripted_file_path" , None ):
50
+ subprocess .run (["python" , model ["gen_scripted_file_path" ]])
51
+ else :
52
+ serialized_model_file_url = (
53
+ f"https://download.pytorch.org/models/{ model ['serialized_file_remote' ]} "
54
+ )
55
+ urllib .request .urlretrieve (
56
+ serialized_model_file_url ,
57
+ f'{ model_store_dir } /{ model ["serialized_file_remote" ]} ' ,
58
+ )
59
+ serialized_file_path = os .path .join (
60
+ model_store_dir , model ["serialized_file_remote" ]
61
+ )
62
+ elif model .get ("serialized_file_local" , None ):
63
+ serialized_file_path = model ["serialized_file_local" ]
64
+
65
+ handler = model .get ("handler" , None )
66
+
67
+ extra_files = model .get ("extra_files" , None )
68
+
69
+ runtime = model .get ("runtime" , None )
70
+
71
+ archive_format = model .get ("archive_format" , "zip-store" )
72
+
73
+ requirements_file = model .get ("requirements_file" , None )
74
+
75
+ export_path = model .get ("export_path" , model_store_dir )
76
+
77
+ cmd = model_archiver_command_builder (
78
+ model ["model_name" ],
79
+ model ["version" ],
80
+ model ["model_file" ],
81
+ serialized_file_path ,
82
+ handler ,
83
+ extra_files ,
84
+ runtime ,
85
+ archive_format ,
86
+ requirements_file ,
87
+ export_path ,
88
+ )
89
+ print (f"## In directory: { os .getcwd ()} | Executing command: { cmd } \n " )
90
+ try :
91
+ subprocess .check_call (cmd , shell = True )
92
+ marfile = "{}.mar" .format (model ["model_name" ])
93
+ print ("## {} is generated.\n " .format (marfile ))
94
+ mar_set .add (marfile )
95
+ except subprocess .CalledProcessError as exc :
96
+ print ("## {} creation failed !, error: {}\n " .format (model ["model_name" ], exc ))
97
+
98
+ if model .get ("serialized_file_remote" , None ) and os .path .exists (
99
+ serialized_file_path
100
+ ):
101
+ os .remove (serialized_file_path )
102
+
103
+
46
104
def generate_mars (mar_config = MAR_CONFIG_FILE_PATH , model_store_dir = MODEL_STORE_DIR ):
47
105
"""
48
106
By default generate_mars reads ts_scripts/mar_config.json and outputs mar files in dir model_store_gen
@@ -67,72 +125,7 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
67
125
models = json .loads (f .read ())
68
126
69
127
for model in models :
70
- serialized_file_path = None
71
- if model .get ("serialized_file_remote" ) and model ["serialized_file_remote" ]:
72
- if (
73
- model .get ("gen_scripted_file_path" )
74
- and model ["gen_scripted_file_path" ]
75
- ):
76
- subprocess .run (["python" , model ["gen_scripted_file_path" ]])
77
- else :
78
- serialized_model_file_url = (
79
- "https://download.pytorch.org/models/{}" .format (
80
- model ["serialized_file_remote" ]
81
- )
82
- )
83
- urllib .request .urlretrieve (
84
- serialized_model_file_url ,
85
- f'{ model_store_dir } /{ model ["serialized_file_remote" ]} ' ,
86
- )
87
- serialized_file_path = os .path .join (
88
- model_store_dir , model ["serialized_file_remote" ]
89
- )
90
- elif model .get ("serialized_file_local" ) and model ["serialized_file_local" ]:
91
- serialized_file_path = model ["serialized_file_local" ]
92
-
93
- handler = model .get ("handler" , None )
94
-
95
- extra_files = model .get ("extra_files" , None )
96
-
97
- runtime = model .get ("runtime" , None )
98
-
99
- archive_format = model .get ("archive_format" , "zip-store" )
100
-
101
- requirements_file = model .get ("requirements_file" , None )
102
-
103
- export_path = model .get ("export_path" , model_store_dir )
104
-
105
- cmd = model_archiver_command_builder (
106
- model ["model_name" ],
107
- model ["version" ],
108
- model ["model_file" ],
109
- serialized_file_path ,
110
- handler ,
111
- extra_files ,
112
- runtime ,
113
- archive_format ,
114
- requirements_file ,
115
- export_path ,
116
- )
117
- print (f"## In directory: { os .getcwd ()} | Executing command: { cmd } \n " )
118
- try :
119
- subprocess .check_call (cmd , shell = True )
120
- marfile = "{}.mar" .format (model ["model_name" ])
121
- print ("## {} is generated.\n " .format (marfile ))
122
- mar_set .add (marfile )
123
- except subprocess .CalledProcessError as exc :
124
- print (
125
- "## {} creation failed !, error: {}\n " .format (
126
- model ["model_name" ], exc
127
- )
128
- )
129
-
130
- if (
131
- model .get ("serialized_file_remote" )
132
- and model ["serialized_file_remote" ]
133
- and os .path .exists (serialized_file_path )
134
- ):
135
- os .remove (serialized_file_path )
128
+ generate_model (model , model_store_dir )
136
129
os .chdir (cwd )
137
130
138
131
0 commit comments