torch.compile ImageClassifier example #2915
Merged
Changes from all commits (13 commits):
- `7652c1e` updated torch.compile example (agunapal)
- `6715c24` added pytest for example (agunapal)
- `029f050` lint failure (agunapal)
- `7f3ed78` show 3x speedup with torch.compile (agunapal)
- `445e103` show 3x speedup with torch.compile (agunapal)
- `8dd5e0e` show 3x speedup with torch.compile (agunapal)
- `8e06ee0` added missing file (agunapal)
- `4a45f7c` added missing file (agunapal)
- `a441b79` Merge branch 'master' into examples/torch_compile (agunapal)
- `55b22d4` added sbin to path (agunapal)
- `840038b` Merge branch 'examples/torch_compile' of https://github.com/pytorch/s… (agunapal)
- `8a2699b` skipping test (agunapal)
- `a897763` Skipping pytest for now as its causing other tests to fail (agunapal)
New file (94 lines added):
# TorchServe inference with torch.compile of densenet161 model

This example shows how to take an eager-mode `densenet161` model, configure TorchServe to use `torch.compile`, and run inference with `torch.compile`.

### Pre-requisites

- `PyTorch >= 2.0`

Change directory to the examples directory, e.g. `cd examples/pt2/torch_compile`.

### torch.compile config

`torch.compile` supports a variety of configs, and the performance you get can vary depending on the config used. You can find the available options [here](https://pytorch.org/docs/stable/generated/torch.compile.html).

In this example, we use the following config:

```
echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml
```
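To make the shape of this config concrete, the one-line flow mapping above can be parsed with plain Python. The `parse_flow_mapping` helper below is hypothetical and for illustration only (TorchServe itself handles the YAML parsing); it just shows the dictionary that the `pt2` line corresponds to:

```python
def parse_flow_mapping(line: str) -> dict:
    """Parse a single-line YAML flow mapping such as
    'pt2 : {backend: inductor, mode: reduce-overhead}' without PyYAML.
    Illustrative only: handles exactly this shape, not general YAML."""
    key, _, rest = line.partition(":")
    body = rest.strip().lstrip("{").rstrip("}")
    opts = {}
    for item in body.split(","):
        k, _, v = item.partition(":")
        opts[k.strip()] = v.strip()
    return {key.strip(): opts}


cfg = parse_flow_mapping("pt2 : {backend: inductor, mode: reduce-overhead}")
# cfg == {"pt2": {"backend": "inductor", "mode": "reduce-overhead"}}
# Conceptually, these options end up as keyword arguments to a call like:
#   model = torch.compile(model, backend="inductor", mode="reduce-overhead")
print(cfg)
```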

### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

which produces the following output:

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```
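For reference, the top prediction can be extracted from such a JSON response with a few lines of stdlib Python. The response string below is copied from the output shown above:

```python
import json

# Raw response body as returned by TorchServe for the kitten image above.
response = """{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}"""

scores = json.loads(response)
# Pick the label with the highest probability.
top_label, top_score = max(scores.items(), key=lambda kv: kv[1])
print(top_label, round(top_score, 3))  # tabby 0.466
```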

### Performance improvement from using `torch.compile`

To measure the handler's `preprocess`, `inference`, and `postprocess` times, run the following:

#### Measure inference time with PyTorch eager

```
echo "handler:" > model-config.yaml && \
echo "  profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above. After a few iterations of warmup, we see the following:

```
2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
```
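If you want to pull the millisecond value out of such a worker log line programmatically, a small regex works. The snippet below is illustrative only (it is not part of TorchServe), and the log fragment is taken from the eager-mode output above:

```python
import re

# Fragment of one [METRICS] line from the eager run shown above.
line = (
    "result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543"
    "|#ModelName:densenet161,Level:Model|#type:GAUGE"
)

# Capture the metric name and its value in milliseconds.
match = re.search(r"\[METRICS\](\w+)\.Milliseconds:([\d.]+)", line)
name, value_ms = match.group(1), float(match.group(2))
print(name, round(value_ms, 2))  # ts_handler_inference 18.78
```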

#### Measure inference time with `torch.compile`

```
echo "pt2: {backend: inductor, mode: reduce-overhead}" > model-config.yaml && \
echo "handler:" >> model-config.yaml && \
echo "  profile: true" >> model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above. `torch.compile` needs a few inferences to warm up. Once warmed up, we see the following:

```
2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
```

### Conclusion

`torch.compile` reduces the handler inference time from roughly 18 ms to roughly 6 ms, about a 3x speedup on this model.
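As a quick sanity check of the claimed speedup, the ratio of the two `ts_handler_inference` readings above can be computed directly (these are single samples, so treat the ratio as approximate):

```python
# Inference times taken verbatim from the two metric logs above.
eager_ms = 18.77564811706543      # PyTorch eager
compiled_ms = 5.8818559646606445  # torch.compile (inductor, reduce-overhead)

# Ratio of eager to compiled inference time.
speedup = eager_ms / compiled_ms
print(round(speedup, 1))  # 3.2
```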

New file (1 line added):

```
pt2 : {backend: inductor, mode: reduce-overhead}
```

New file (27 lines added):

```python
from torchvision.models.densenet import DenseNet


class ImageClassifier(DenseNet):
    def __init__(self):
        super(ImageClassifier, self).__init__(48, (6, 12, 36, 24), 96)

    def load_state_dict(self, state_dict, strict=True):
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        # Credit - https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#def _load_state_dict()
        import re

        pattern = re.compile(
            r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
        )

        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        return super(ImageClassifier, self).load_state_dict(state_dict, strict)
```
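To see what the key renaming in `load_state_dict` does, the same regex can be applied to a sample legacy checkpoint key. The key string below is illustrative, following the `norm.1`-style naming mentioned in the code comment:

```python
import re

# The same pattern used in load_state_dict above.
pattern = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\."
    r"((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def remap(key: str) -> str:
    """Drop the extra '.' from legacy DenseNet checkpoint keys."""
    m = pattern.match(key)
    return m.group(1) + m.group(2) if m else key


# Example legacy key (illustrative): 'norm.1' becomes 'norm1'.
old_key = "features.denseblock1.denselayer1.norm.1.running_mean"
print(remap(old_key))  # features.denseblock1.denselayer1.norm1.running_mean
```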

New file (82 lines added):

```python
import os
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
    "examples", "image_classifier", "index_to_name.json"
)
MODEL_PTH_FILE = "densenet161-8d451a50.pth"
MODEL_FILE = "model.py"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")


PT2_AVAILABLE = packaging.version.parse(torch.__version__) > packaging.version.parse(
    "2.0"
)

EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]


@pytest.fixture
def custom_working_directory(tmp_path):
    # Set the custom working directory
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    yield custom_dir
    # Clean up and return to the original working directory
    os.chdir(tmp_path)


@pytest.mark.skipif(not PT2_AVAILABLE, reason="torch version is < 2.0")
@pytest.mark.skip(reason="Skipping as it's causing other testcases to fail")
def test_torch_compile_inference(monkeypatch, custom_working_directory):
    monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
    # Get the path to the custom working directory
    model_dir = custom_working_directory

    try_and_handle(
        f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
    )

    # Handler for image classification
    handler = ImageClassifier()

    # Context definition
    ctx = MockContext(
        model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
        model_dir=EXAMPLE_ROOT_DIR.as_posix(),
        model_file=MODEL_FILE,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )

    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    data = {}
    with open(TEST_DATA, "rb") as image:
        image_file = image.read()
        byte_array_type = bytearray(image_file)
        data["body"] = byte_array_type

    result = handler.handle([data], ctx)

    labels = list(result[0].keys())

    assert labels == EXPECTED_RESULTS
```

Modified file (1 line added):

```
@@ -1175,3 +1175,4 @@ BabyLlamaHandler
 CMakeLists
 TorchScriptHandler
 libllamacpp
+warmup
```
Review comment:

> People would be integrating torch.compile to see speedups. I also believe the label values change a tiny bit, so you should show the expected speedups. I mention this because I'm not sure if densenet speedups were there.
Reply (agunapal):

> Great feedback. I added a section on perf measurement: 3x speedup with torch.compile. Makes it very compelling now!