|
| 1 | +import os |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +from pkg_resources import packaging |
| 7 | + |
| 8 | +from ts.torch_handler.image_classifier import ImageClassifier |
| 9 | +from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext |
| 10 | +from ts.utils.util import load_label_mapping |
| 11 | +from ts_scripts.utils import try_and_handle |
| 12 | + |
| 13 | +CURR_FILE_PATH = Path(__file__).parent.absolute() |
| 14 | +REPO_ROOT_DIR = CURR_FILE_PATH.parents[1] |
| 15 | +EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile") |
| 16 | +TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg") |
| 17 | +MAPPING_DATA = REPO_ROOT_DIR.joinpath( |
| 18 | + "examples", "image_classifier", "index_to_name.json" |
| 19 | +) |
| 20 | +MODEL_PTH_FILE = "densenet161-8d451a50.pth" |
| 21 | +MODEL_FILE = REPO_ROOT_DIR.joinpath( |
| 22 | + "examples", "image_classifier", "densenet_161", "model.py" |
| 23 | +) |
| 24 | +MODEL_FILE = "model.py" |
| 25 | +MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml") |
| 26 | + |
| 27 | + |
| 28 | +PT2_AVAILABLE = ( |
| 29 | + True |
| 30 | + if packaging.version.parse(torch.__version__) > packaging.version.parse("2.0") |
| 31 | + else False |
| 32 | +) |
| 33 | + |
| 34 | +EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"] |
| 35 | + |
| 36 | + |
| 37 | +@pytest.fixture |
| 38 | +def custom_working_directory(tmp_path): |
| 39 | + # Set the custom working directory |
| 40 | + custom_dir = tmp_path / "model_dir" |
| 41 | + custom_dir.mkdir() |
| 42 | + os.chdir(custom_dir) |
| 43 | + yield custom_dir |
| 44 | + # Clean up and return to the original working directory |
| 45 | + os.chdir(tmp_path) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.skipif(PT2_AVAILABLE == False, reason="torch version is < 2.0") |
| 49 | +def test_torch_compile_inference(monkeypatch, custom_working_directory): |
| 50 | + monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR) |
| 51 | + # Get the path to the custom working directory |
| 52 | + model_dir = custom_working_directory |
| 53 | + |
| 54 | + try_and_handle( |
| 55 | + f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}" |
| 56 | + ) |
| 57 | + |
| 58 | + # Handler for Image classification |
| 59 | + handler = ImageClassifier() |
| 60 | + |
| 61 | + # Context definition |
| 62 | + ctx = MockContext( |
| 63 | + model_pt_file=model_dir.joinpath(MODEL_PTH_FILE), |
| 64 | + model_dir=EXAMPLE_ROOT_DIR.as_posix(), |
| 65 | + model_file=MODEL_FILE, |
| 66 | + model_yaml_config_file=MODEL_YAML_CFG_FILE, |
| 67 | + ) |
| 68 | + |
| 69 | + torch.manual_seed(42 * 42) |
| 70 | + handler.initialize(ctx) |
| 71 | + handler.context = ctx |
| 72 | + handler.mapping = load_label_mapping(MAPPING_DATA) |
| 73 | + |
| 74 | + data = {} |
| 75 | + with open(TEST_DATA, "rb") as image: |
| 76 | + image_file = image.read() |
| 77 | + byte_array_type = bytearray(image_file) |
| 78 | + data["body"] = byte_array_type |
| 79 | + |
| 80 | + result = handler.handle([data], ctx) |
| 81 | + |
| 82 | + labels = list(result[0].keys()) |
| 83 | + |
| 84 | + assert labels == EXPECTED_RESULTS |
0 commit comments