Skip to content

Commit 88eca54

Browse files
authored
torch.compile ImageClassifier example (#2915)
* updated torch.compile example * added pytest for example * lint failure * show 3x speedup with torch.compile * show 3x speedup with torch.compile * show 3x speedup with torch.compile * added missing file * added missing file * added sbin to path * skipping test * Skipping pytest for now as its causing other tests to fail
1 parent fa0f1e3 commit 88eca54

File tree

6 files changed

+210
-10
lines changed

6 files changed

+210
-10
lines changed

examples/pt2/README.md

+5-10
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
## PyTorch 2.x integration
22

3-
PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental given that most public benchmarks have focused on training instead of inference.
3+
PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption.
44

55
We strongly recommend you leverage newer hardware so for GPUs that would be an Ampere architecture. You'll get even more benefits from using server GPU deployments like A10G and A100 vs consumer cards. But you should expect to see some speedups for any Volta or Ampere architecture.
66

77
## Get started
88

99
Install torchserve and ensure that you're using at least `torch>=2.0.0`
1010

11+
To use the latest nightlies, you can run the following commands
1112
```sh
12-
python ts_scripts/install_dependencies.py --cuda=cu118
13-
pip install torchserve torch-model-archiver
13+
python ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
14+
pip install torchserve-nightly torch-model-archiver-nightly
1415
```
1516

1617
## torch.compile
@@ -27,13 +28,7 @@ You can also pass a dictionary with compile options if you need more control ove
2728
pt2 : {backend: inductor, mode: reduce-overhead}
2829
```
2930
30-
As an example let's expand our getting started guide with the only difference being passing in the extra `model_config.yaml` file
31-
32-
```
33-
mkdir model_store
34-
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier --config-file model_config.yaml
35-
torchserve --start --ncs --model-store model_store --models densenet161.mar
36-
```
31+
An example of using `torch.compile` can be found [here](./torch_compile/README.md)
3732

3833
The exact same approach works with any other model, what's going on is the below
3934

examples/pt2/torch_compile/README.md

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
2+
# TorchServe inference with torch.compile of densenet161 model
3+
4+
This example shows how to take eager model of `densenet161`, configure TorchServe to use `torch.compile` and run inference using `torch.compile`
5+
6+
7+
### Pre-requisites
8+
9+
- `PyTorch >= 2.0`
10+
11+
Change directory to the examples directory
12+
Ex: `cd examples/pt2/torch_compile`
13+
14+
15+
### torch.compile config
16+
17+
`torch.compile` supports a variety of config and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)
18+
19+
In this example , we use the following config
20+
21+
```
22+
echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml
23+
```
24+
25+
### Create model archive
26+
27+
```
28+
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
29+
mkdir model_store
30+
torch-model-archiver --model-name densenet161 --version 1.0 --model-file model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
31+
```
32+
33+
#### Start TorchServe
34+
```
35+
torchserve --start --ncs --model-store model_store --models densenet161.mar
36+
```
37+
38+
#### Run Inference
39+
40+
```
41+
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
42+
```
43+
44+
produces the output
45+
46+
```
47+
{
48+
"tabby": 0.4664836823940277,
49+
"tiger_cat": 0.4645617604255676,
50+
"Egyptian_cat": 0.06619937717914581,
51+
"lynx": 0.0012969186063855886,
52+
"plastic_bag": 0.00022856894065625966
53+
}
54+
```
55+
56+
### Performance improvement from using `torch.compile`
57+
58+
To measure the handler `preprocess`, `inference`, `postprocess` times, run the following
59+
60+
#### Measure inference time with PyTorch eager
61+
62+
```
63+
echo "handler:" > model-config.yaml && \
64+
echo " profile: true" >> model-config.yaml
65+
```
66+
67+
Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
68+
After a few iterations of warmup, we see the following
69+
70+
```
71+
2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
72+
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
73+
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
74+
```
75+
76+
#### Measure inference time with `torch.compile`
77+
78+
```
79+
echo "pt2: {backend: inductor, mode: reduce-overhead}" > model-config.yaml && \
80+
echo "handler:" >> model-config.yaml && \
81+
echo " profile: true" >> model-config.yaml
82+
```
83+
84+
Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
85+
`torch.compile` needs a few inferences to warmup. Once warmed up, we see the following
86+
```
87+
2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
88+
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
89+
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
90+
```
91+
92+
### Conclusion
93+
94+
`torch.compile` reduces the inference time from 18ms to 5ms
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pt2 : {backend: inductor, mode: reduce-overhead}

examples/pt2/torch_compile/model.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from torchvision.models.densenet import DenseNet
2+
3+
4+
class ImageClassifier(DenseNet):
5+
def __init__(self):
6+
super(ImageClassifier, self).__init__(48, (6, 12, 36, 24), 96)
7+
8+
def load_state_dict(self, state_dict, strict=True):
9+
# '.'s are no longer allowed in module names, but previous _DenseLayer
10+
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
11+
# They are also in the checkpoints in model_urls. This pattern is used
12+
# to find such keys.
13+
# Credit - https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#def _load_state_dict()
14+
import re
15+
16+
pattern = re.compile(
17+
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
18+
)
19+
20+
for key in list(state_dict.keys()):
21+
res = pattern.match(key)
22+
if res:
23+
new_key = res.group(1) + res.group(2)
24+
state_dict[new_key] = state_dict[key]
25+
del state_dict[key]
26+
27+
return super(ImageClassifier, self).load_state_dict(state_dict, strict)
+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
from pathlib import Path
3+
4+
import pytest
5+
import torch
6+
from pkg_resources import packaging
7+
8+
from ts.torch_handler.image_classifier import ImageClassifier
9+
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
10+
from ts.utils.util import load_label_mapping
11+
from ts_scripts.utils import try_and_handle
12+
13+
CURR_FILE_PATH = Path(__file__).parent.absolute()
14+
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
15+
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
16+
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
17+
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
18+
"examples", "image_classifier", "index_to_name.json"
19+
)
20+
MODEL_PTH_FILE = "densenet161-8d451a50.pth"
21+
MODEL_FILE = "model.py"
22+
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
23+
24+
25+
PT2_AVAILABLE = (
26+
True
27+
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.0")
28+
else False
29+
)
30+
31+
EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]
32+
33+
34+
@pytest.fixture
35+
def custom_working_directory(tmp_path):
36+
# Set the custom working directory
37+
custom_dir = tmp_path / "model_dir"
38+
custom_dir.mkdir()
39+
os.chdir(custom_dir)
40+
yield custom_dir
41+
# Clean up and return to the original working directory
42+
os.chdir(tmp_path)
43+
44+
45+
@pytest.mark.skipif(PT2_AVAILABLE == False, reason="torch version is < 2.0")
46+
@pytest.mark.skip(reason="Skipping as its causing other testcases to fail")
47+
def test_torch_compile_inference(monkeypatch, custom_working_directory):
48+
monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
49+
# Get the path to the custom working directory
50+
model_dir = custom_working_directory
51+
52+
try_and_handle(
53+
f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
54+
)
55+
56+
# Handler for Image classification
57+
handler = ImageClassifier()
58+
59+
# Context definition
60+
ctx = MockContext(
61+
model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
62+
model_dir=EXAMPLE_ROOT_DIR.as_posix(),
63+
model_file=MODEL_FILE,
64+
model_yaml_config_file=MODEL_YAML_CFG_FILE,
65+
)
66+
67+
torch.manual_seed(42 * 42)
68+
handler.initialize(ctx)
69+
handler.context = ctx
70+
handler.mapping = load_label_mapping(MAPPING_DATA)
71+
72+
data = {}
73+
with open(TEST_DATA, "rb") as image:
74+
image_file = image.read()
75+
byte_array_type = bytearray(image_file)
76+
data["body"] = byte_array_type
77+
78+
result = handler.handle([data], ctx)
79+
80+
labels = list(result[0].keys())
81+
82+
assert labels == EXPECTED_RESULTS

ts_scripts/spellcheck_conf/wordlist.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,4 @@ BabyLlamaHandler
11751175
CMakeLists
11761176
TorchScriptHandler
11771177
libllamacpp
1178+
warmup

0 commit comments

Comments
 (0)