19
19
mnist_scriptes_py = os .path .join (REPO_ROOT , "examples/image_classifier/mnist/mnist.py" )
20
20
21
21
HANDLER_PY = """
22
+ import torch
23
+ from ts.torch_handler.base_handler import BaseHandler
24
+
25
+ class deviceHandler(BaseHandler):
26
+
27
+ def initialize(self, context):
28
+ super().initialize(context)
29
+ if torch.backends.mps.is_available() and context.system_properties.get("gpu_id") is not None:
30
+ assert self.get_device().type == "mps"
31
+ else:
32
+ assert self.get_device().type == "cpu"
33
+ """
34
+
35
+ HANDLER_PY_GPU = """
22
36
from ts.torch_handler.base_handler import BaseHandler
23
37
24
38
class deviceHandler(BaseHandler):
@@ -28,6 +42,16 @@ def initialize(self, context):
28
42
assert self.get_device().type == "mps"
29
43
"""
30
44
45
+ HANDLER_PY_CPU = """
46
+ from ts.torch_handler.base_handler import BaseHandler
47
+
48
+ class deviceHandler(BaseHandler):
49
+
50
+ def initialize(self, context):
51
+ super().initialize(context)
52
+ assert self.get_device().type == "cpu"
53
+ """
54
+
31
55
MODEL_CONFIG_YAML = """
32
56
#frontend settings
33
57
# TorchServe frontend parameters
@@ -78,8 +102,23 @@ def get_config(param):
78
102
return get_config (request .param )
79
103
80
104
105
+ @pytest .fixture (scope = "module" )
106
+ def handler_py (request ):
107
+ def get_handler (param ):
108
+ if param == "cpu" :
109
+ return HANDLER_PY_CPU
110
+ elif param == "gpu" :
111
+ return HANDLER_PY_GPU
112
+ else :
113
+ return HANDLER_PY
114
+
115
+ return get_handler (request .param )
116
+
117
+
81
118
@pytest .fixture (scope = "module" , name = "mar_file_path" )
82
- def create_mar_file (work_dir , model_archiver , model_name , model_config_name ):
119
+ def create_mar_file (
120
+ work_dir , model_archiver , model_name , model_config_name , handler_py
121
+ ):
83
122
mar_file_path = work_dir .joinpath (model_name + ".mar" )
84
123
85
124
model_config_yaml_file = work_dir / "model_config.yaml"
@@ -90,7 +129,7 @@ def create_mar_file(work_dir, model_archiver, model_name, model_config_name):
90
129
model_py_file .write_text (mnist_scriptes_py )
91
130
92
131
handler_py_file = work_dir / "handler.py"
93
- handler_py_file .write_text (HANDLER_PY )
132
+ handler_py_file .write_text (handler_py )
94
133
95
134
config = ModelArchiverConfig (
96
135
model_name = model_name ,
@@ -147,22 +186,29 @@ def register_model(mar_file_path, model_store, torchserve):
147
186
test_utils .unregister_model (model_name )
148
187
149
188
150
- @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on Mac M1" )
189
+ @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on non Mac M1" )
190
+ @pytest .mark .skipif (
191
+ os .environ .get ("TS_MAC_ARM64_CPU_ONLY" , "False" ) == "True" ,
192
+ reason = "Skip if running only on MAC CPU" ,
193
+ )
151
194
@pytest .mark .parametrize ("model_config_name" , ["gpu" ], indirect = True )
195
+ @pytest .mark .parametrize ("handler_py" , ["gpu" ], indirect = True )
152
196
def test_m1_device (model_name , model_config_name ):
153
197
response = requests .get (f"http://localhost:8081/models/{ model_name } " )
154
198
assert response .status_code == 200 , "Describe Failed"
155
199
156
200
157
- @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on Mac M1" )
201
+ @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on non Mac M1" )
158
202
@pytest .mark .parametrize ("model_config_name" , ["cpu" ], indirect = True )
203
+ @pytest .mark .parametrize ("handler_py" , ["cpu" ], indirect = True )
159
204
def test_m1_device_cpu (model_name , model_config_name ):
160
205
response = requests .get (f"http://localhost:8081/models/{ model_name } " )
161
- assert response .status_code == 404 , "Describe Worked "
206
+ assert response .status_code == 200 , "Describe Failed "
162
207
163
208
164
- @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on Mac M1" )
209
+ @pytest .mark .skipif (platform .machine () != "arm64" , reason = "Skip on non Mac M1" )
165
210
@pytest .mark .parametrize ("model_config_name" , ["default" ], indirect = True )
211
+ @pytest .mark .parametrize ("handler_py" , ["default" ], indirect = True )
166
212
def test_m1_device_default (model_name , model_config_name ):
167
213
response = requests .get (f"http://localhost:8081/models/{ model_name } " )
168
214
assert response .status_code == 200 , "Describe Failed"
0 commit comments