Skip to content

Commit cfb4285

Browse files
agunapal and mreso authored
Update torch.export load with new api (#2906)
* changed to new api
* Updated to use new api torch._export.aot_load
* Updated to use new api torch._export.aot_load
* update the install script
* tested with CPU & batch size = 32
* updated based on review comments

---------

Co-authored-by: Matthias Reso <[email protected]>
1 parent 1a567db commit cfb4285

File tree

6 files changed

+39
-45
lines changed

6 files changed

+39
-45
lines changed

examples/pt2/torch_export_aot_compile/install_pytorch_nightlies.sh

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
pip uninstall torchtext torchdata torch torchvision torchaudio -y
55

66
# Install nightly PyTorch and torchvision from the specified index URL
7-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
7+
if nvidia-smi > /dev/null 2>&1; then
8+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
9+
else
10+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu --ignore-installed
11+
fi
812

913
# Optional: Display the installed PyTorch and torchvision versions
1014
python -c "import torch; print('PyTorch version:', torch.__version__)"

examples/pt2/torch_export_aot_compile/resnet18_torch_export.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,26 @@
55

66
torch.set_float32_matmul_precision("high")
77

8+
MAX_BATCH_SIZE = 32
9+
810
model = resnet18(weights=ResNet18_Weights.DEFAULT)
911
model.eval()
1012

1113
with torch.no_grad():
12-
device = "cuda" if torch.cuda.is_available() else "cpu"
14+
if torch.cuda.is_available():
15+
device = "cuda"
16+
else:
17+
device = "cpu"
18+
# We need to turn off the below optimizations to support batch_size = 16,
19+
# which is treated like a special case
20+
# https://github.com/pytorch/pytorch/pull/116152
21+
torch.backends.mkldnn.set_flags(False)
22+
torch.backends.nnpack.set_flags(False)
23+
1324
model = model.to(device=device)
1425
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
1526

16-
# Max value is 15 because of https://github.com/pytorch/pytorch/pull/116152
17-
# On a CUDA enabled device, we tested batch_size of 32.
18-
batch_dim = torch.export.Dim("batch", min=2, max=15)
27+
batch_dim = torch.export.Dim("batch", min=2, max=MAX_BATCH_SIZE)
1928
so_path = torch._export.aot_compile(
2029
model,
2130
example_inputs,

test/pytest/test_torch_export.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
2020

2121

22-
PT_220_AVAILABLE = (
22+
PT_230_AVAILABLE = (
2323
True
24-
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
24+
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
2525
else False
2626
)
2727

@@ -30,6 +30,8 @@
3030
("kitten.jpg", EXPECTED_RESULTS[0]),
3131
]
3232

33+
BATCH_SIZE = 32
34+
3335

3436
import os
3537

@@ -47,7 +49,7 @@ def custom_working_directory(tmp_path):
4749
os.chdir(tmp_path)
4850

4951

50-
@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
52+
@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
5153
def test_torch_export_aot_compile(custom_working_directory):
5254
# Get the path to the custom working directory
5355
model_dir = custom_working_directory
@@ -88,7 +90,7 @@ def test_torch_export_aot_compile(custom_working_directory):
8890
assert labels == EXPECTED_RESULTS
8991

9092

91-
@pytest.mark.skipif(PT_220_AVAILABLE == False, reason="torch version is < 2.2.0")
93+
@pytest.mark.skipif(PT_230_AVAILABLE == False, reason="torch version is < 2.3.0")
9294
def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
9395
# Get the path to the custom working directory
9496
model_dir = custom_working_directory
@@ -122,7 +124,7 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
122124
byte_array_type = bytearray(image_file)
123125
data["body"] = byte_array_type
124126

125-
# Send a batch of 16 elements
126-
result = handler.handle([data for i in range(15)], ctx)
127+
# Send a batch of BATCH_SIZE elements
128+
result = handler.handle([data for i in range(BATCH_SIZE)], ctx)
127129

128-
assert len(result) == 15
130+
assert len(result) == BATCH_SIZE

ts/handler_utils/torch_export/__init__.py

Whitespace-only changes.

ts/handler_utils/torch_export/load_model.py

-27
This file was deleted.

ts/torch_handler/base_handler.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@
5353
)
5454
PT2_AVAILABLE = False
5555

56-
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
57-
PT220_AVAILABLE = True
56+
if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
57+
PT230_AVAILABLE = True
5858
else:
59-
PT220_AVAILABLE = False
59+
PT230_AVAILABLE = False
6060

6161
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
6262
try:
@@ -187,7 +187,7 @@ def initialize(self, context):
187187
elif (
188188
self.model_pt_path.endswith(".so")
189189
and self._use_torch_export_aot_compile()
190-
and PT220_AVAILABLE
190+
and PT230_AVAILABLE
191191
):
192192
# Set cuda device to the gpu_id of the backend worker
193193
# This is needed as the API for loading the exported model doesn't yet have a device id
@@ -256,9 +256,15 @@ def initialize(self, context):
256256
self.initialized = True
257257

258258
def _load_torch_export_aot_compile(self, model_so_path):
259-
from ts.handler_utils.torch_export.load_model import load_exported_model
259+
"""Loads the AOT-compiled PyTorch model shared object (.so) and returns a Callable object.
260260
261-
return load_exported_model(model_so_path, self.map_location)
261+
Args:
262+
model_so_path (str): denotes the path of the compiled model shared object (.so) file.
263+
264+
Returns:
265+
(Callable Object) : The loaded model object.
266+
"""
267+
return torch._export.aot_load(model_so_path, self.map_location)
262268

263269
def _load_torchscript_model(self, model_pt_path):
264270
"""Loads the PyTorch model and returns the NN model object.

0 commit comments

Comments
 (0)