Skip to content

Commit 6715c24

Browse files
committed
added pytest for example
1 parent 7652c1e commit 6715c24

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

examples/pt2/torch_compile/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pt2 : {backend: inductor, mode: reduce-overhead}
2727
```
2828
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
2929
mkdir model_store
30-
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model_config.yaml -f
30+
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
3131
```
3232

3333
#### Start TorchServe
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
from pathlib import Path
3+
4+
import pytest
5+
import torch
6+
from pkg_resources import packaging
7+
8+
from ts.torch_handler.image_classifier import ImageClassifier
9+
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
10+
from ts.utils.util import load_label_mapping
11+
from ts_scripts.utils import try_and_handle
12+
13+
CURR_FILE_PATH = Path(__file__).parent.absolute()
14+
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
15+
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
16+
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
17+
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
18+
"examples", "image_classifier", "index_to_name.json"
19+
)
20+
MODEL_PTH_FILE = "densenet161-8d451a50.pth"
21+
MODEL_FILE = REPO_ROOT_DIR.joinpath(
22+
"examples", "image_classifier", "densenet_161", "model.py"
23+
)
24+
MODEL_FILE = "model.py"
25+
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
26+
27+
28+
PT2_AVAILABLE = (
29+
True
30+
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.0")
31+
else False
32+
)
33+
34+
EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]
35+
36+
37+
@pytest.fixture
38+
def custom_working_directory(tmp_path):
39+
# Set the custom working directory
40+
custom_dir = tmp_path / "model_dir"
41+
custom_dir.mkdir()
42+
os.chdir(custom_dir)
43+
yield custom_dir
44+
# Clean up and return to the original working directory
45+
os.chdir(tmp_path)
46+
47+
48+
@pytest.mark.skipif(PT2_AVAILABLE == False, reason="torch version is < 2.0")
49+
def test_torch_compile_inference(monkeypatch, custom_working_directory):
50+
monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
51+
# Get the path to the custom working directory
52+
model_dir = custom_working_directory
53+
54+
try_and_handle(
55+
f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
56+
)
57+
58+
# Handler for Image classification
59+
handler = ImageClassifier()
60+
61+
# Context definition
62+
ctx = MockContext(
63+
model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
64+
model_dir=EXAMPLE_ROOT_DIR.as_posix(),
65+
model_file=MODEL_FILE,
66+
model_yaml_config_file=MODEL_YAML_CFG_FILE,
67+
)
68+
69+
torch.manual_seed(42 * 42)
70+
handler.initialize(ctx)
71+
handler.context = ctx
72+
handler.mapping = load_label_mapping(MAPPING_DATA)
73+
74+
data = {}
75+
with open(TEST_DATA, "rb") as image:
76+
image_file = image.read()
77+
byte_array_type = bytearray(image_file)
78+
data["body"] = byte_array_type
79+
80+
result = handler.handle([data], ctx)
81+
82+
labels = list(result[0].keys())
83+
84+
assert labels == EXPECTED_RESULTS

0 commit comments

Comments
 (0)