
Commit 6407c93

suryasidd, ravi9, ynimmaga, and agunapal authored
Openvino integration into torchserve (#3116)
* Initial commit for OpenVINO integration
* Added new examples folder for torch_compile_openvino
* Update with readme
* Update README.md
* Added stable diffusion example
* Added README for stable diffusion
* Generalized use of pt2 in model_config.yaml
* Changed DiffusionPipeline to StableDiffusionXL
* Updated README for StableDiffusion
* Added is_xl parameter in model_config
* Changed the hello world example to vgg16
* Updated README file with more description of options
* Changed model to resnet50
* Added pytorch 2.1.0 requirement
* Removed openvino from setup.py and added in requirements
* Added requirements.txt for resnet50 example
* Fixed linter issues
* Fixed linter issues in handler
* Update Readme
* Update Readmes
* Update readme

---------

Co-authored-by: Ravi Panchumarthy <[email protected]>
Co-authored-by: ynimmaga <[email protected]>
Co-authored-by: Ankith Gunapal <[email protected]>
1 parent ffe1ed2 commit 6407c93

15 files changed, +642 -3 lines changed
@@ -0,0 +1,176 @@
# TorchServe Inference with torch.compile using the OpenVINO backend on a ResNet50 model

This guide shows how to optimize a ResNet50 model using `torch.compile` with the [OpenVINO backend](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html), aiming to enhance inference performance when deployed through TorchServe. `torch.compile` performs just-in-time compilation of PyTorch models, and when combined with OpenVINO it leverages hardware optimizations that are particularly beneficial for deployment in production environments.

### Prerequisites
- `PyTorch >= 2.1.0`
- `OpenVINO >= 2024.1.0`. Install the latest version as shown below:

```bash
# Install OpenVINO
cd examples/pt2/torch_compile_openvino
pip install -r requirements.txt
```

## Workflow
1. Configure torch.compile.
1. Create Model Archive.
1. Start TorchServe.
1. Run Inference.
1. Stop TorchServe.
1. Measure and Compare Performance with different backends.

First, navigate to `examples/pt2/torch_compile_openvino`:
```bash
cd examples/pt2/torch_compile_openvino
```

### 1. Configure torch.compile

`torch.compile` allows various configurations that can influence performance outcomes. Explore different options in the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.compile.html) and the [OpenVINO backend documentation](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html).

In this example, we use the following config:

```bash
echo "minWorkers: 1
maxWorkers: 2
pt2: {backend: openvino}" > model-config.yaml
```

If you want to measure the handler `preprocess`, `inference`, and `postprocess` times, use the following config instead:

```bash
echo "minWorkers: 1
maxWorkers: 2
pt2: {backend: openvino}
handler:
  profile: true" > model-config.yaml
```
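
For reference, outside of TorchServe the `pt2: {backend: openvino}` option corresponds roughly to calling `torch.compile` with the OpenVINO backend yourself. Below is a minimal standalone sketch, assuming `torch >= 2.1`, `torchvision`, and the `openvino` package from `requirements.txt` are installed; per the OpenVINO documentation, importing `openvino.torch` registers the backend:

```python
import torch
import openvino.torch  # noqa: F401  -- registers the "openvino" torch.compile backend
from torchvision.models import ResNet50_Weights, resnet50

model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()
compiled_model = torch.compile(model, backend="openvino")

with torch.no_grad():
    # The first call triggers compilation (warm-up); subsequent calls reuse the compiled graph.
    output = compiled_model(torch.randn(1, 3, 224, 224))
print(output.shape)  # torch.Size([1, 1000])
```

TorchServe performs an equivalent `torch.compile` call on the model loaded by the handler, driven by the `pt2` section of `model-config.yaml`.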

### 2. Create model archive

Download the pre-trained model and prepare the model archive:
```bash
wget https://download.pytorch.org/models/resnet50-11ad3fa6.pth
mkdir model_store
torch-model-archiver --model-name resnet-50 --version 1.0 --model-file model.py \
  --serialized-file resnet50-11ad3fa6.pth --export-path model_store \
  --extra-files ../../image_classifier/index_to_name.json --handler image_classifier \
  --config-file model-config.yaml
```

### 3. Start TorchServe

Start the TorchServe server using the following command:
```bash
torchserve --start --ncs --model-store model_store --models resnet-50.mar
```

### 4. Run Inference

**Note:** `torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance.

```bash
# Open a new terminal
cd examples/pt2/torch_compile_openvino
curl http://127.0.0.1:8080/predictions/resnet-50 -T ../../image_classifier/kitten.jpg
```
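
If you prefer to drive the endpoint from Python, the sketch below does the same thing with the `requests` package (assumed installed) and sends a few warm-up requests before timing one; the URL and image path mirror the curl command above, so adjust them to your setup:

```python
import time

import requests

URL = "http://127.0.0.1:8080/predictions/resnet-50"
IMAGE = "../../image_classifier/kitten.jpg"

with open(IMAGE, "rb") as f:
    payload = f.read()

# Warm-up: torch.compile compiles on the first request(s) each worker handles.
for _ in range(4):
    requests.post(URL, data=payload).raise_for_status()

start = time.perf_counter()
response = requests.post(URL, data=payload)
elapsed_ms = (time.perf_counter() - start) * 1000
response.raise_for_status()

print(f"end-to-end latency: {elapsed_ms:.1f} ms")
print(response.json())
```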

The expected output will be JSON-formatted classification probabilities, such as:

```bash
{
  "tabby": 0.27249985933303833,
  "tiger_cat": 0.13740447163581848,
  "Egyptian_cat": 0.04627467691898346,
  "lynx": 0.0032067003194242716,
  "lens_cap": 0.002257897751405835
}
```

### 5. Stop the server
Stop TorchServe with the following command:

```bash
torchserve --stop
```

### 6. Measure and Compare Performance with different backends

Following the steps outlined in the previous section, you can compare the inference times for eager mode, the Inductor backend, and the OpenVINO backend:

1. Update model-config.yaml by adding `profile: true` under the `handler` section.
1. Create a new model archive using torch-model-archiver with the updated configuration.
1. Start TorchServe and run inference.
1. Analyze the TorchServe logs for metrics like `ts_handler_preprocess.Milliseconds`, `ts_handler_inference.Milliseconds`, and `ts_handler_postprocess.Milliseconds`. These metrics represent the time taken for pre-processing, inference, and post-processing steps, respectively, for each inference request. A small parsing sketch follows this list.
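
To pull these numbers out of the logs programmatically, a small parsing sketch such as the following can help. It reads log lines from stdin, so you can pipe in whichever log file your TorchServe setup writes the `[METRICS]` lines to (commonly under `./logs/`):

```python
import re
import sys
from collections import defaultdict

# Matches e.g. "[METRICS]ts_handler_inference.Milliseconds:22.12|#ModelName:resnet-50,..."
METRIC_RE = re.compile(r"\[METRICS\](ts_handler_\w+)\.Milliseconds:([\d.]+)")

samples = defaultdict(list)
for line in sys.stdin:
    match = METRIC_RE.search(line)
    if match:
        samples[match.group(1)].append(float(match.group(2)))

for name, values in samples.items():
    print(f"{name}: mean={sum(values) / len(values):.2f} ms over {len(values)} samples")
```

For example: `grep ts_handler logs/model_log.log | python parse_metrics.py`, where `parse_metrics.py` is whatever you name the sketch and the log path depends on your TorchServe logging configuration.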

#### 6.1. Measure inference time with PyTorch eager mode

Update the `model-config.yaml` file to use PyTorch eager mode:

```bash
echo "minWorkers: 1
maxWorkers: 2
handler:
  profile: true" > model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above. The TorchServe logs then show handler metrics similar to the following:

```bash
2024-05-01T10:29:29,586 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.254030227661133|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559369,fd3743e0-9c89-41b2-9972-c1f403872113, pattern=[METRICS]
2024-05-01T10:29:29,609 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:22.122859954833984|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559369,fd3743e0-9c89-41b2-9972-c1f403872113, pattern=[METRICS]
2024-05-01T10:29:29,609 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.057220458984375|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559369,fd3743e0-9c89-41b2-9972-c1f403872113, pattern=[METRICS]
```

#### 6.2. Measure inference time using `torch.compile` with the Inductor backend

Update the `model-config.yaml` file to specify the Inductor backend:

```bash
echo "minWorkers: 1
maxWorkers: 2
pt2: {backend: inductor, mode: reduce-overhead}
handler:
  profile: true" > model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance. After warm-up, the logs show metrics similar to:

```bash
2024-05-01T10:32:05,808 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.209445953369141|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559525,9f84ea11-7b77-40e3-bf2c-926746db9c6f, pattern=[METRICS]
2024-05-01T10:32:05,821 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:12.910842895507812|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559525,9f84ea11-7b77-40e3-bf2c-926746db9c6f, pattern=[METRICS]
2024-05-01T10:32:05,822 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.06079673767089844|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714559525,9f84ea11-7b77-40e3-bf2c-926746db9c6f, pattern=[METRICS]
```

#### 6.3. Measure inference time using `torch.compile` with the OpenVINO backend

Update the `model-config.yaml` file to specify the OpenVINO backend:

```bash
echo "minWorkers: 1
maxWorkers: 2
pt2: {backend: openvino}
handler:
  profile: true" > model-config.yaml
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance. After warm-up, the logs show metrics similar to:

```bash
2024-05-01T10:40:45,031 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.637407302856445|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714560045,7fffdb96-7022-495d-95bb-8dd0b17bf30a, pattern=[METRICS]
2024-05-01T10:40:45,036 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.518198013305664|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714560045,7fffdb96-7022-495d-95bb-8dd0b17bf30a, pattern=[METRICS]
2024-05-01T10:40:45,037 [INFO ] W-9000-resnet-50_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.06508827209472656|#ModelName:resnet-50,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714560045,7fffdb96-7022-495d-95bb-8dd0b17bf30a, pattern=[METRICS]
```

### Conclusion

- Using `torch.compile` with the OpenVINO backend, inference time drops to approximately 5.5 ms, roughly a 4x improvement over the 22 ms measured in eager mode and more than 2x faster than the 13 ms measured with the Inductor backend. These numbers were measured on an Intel Xeon Platinum 8469 CPU.

- The actual performance gains may vary depending on your hardware, model complexity, and workload. Consider exploring more advanced `torch.compile` [configurations](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html) for further optimization based on your specific use case.

- Try out the [Stable Diffusion](./stable_diffusion/) example!
@@ -0,0 +1,3 @@
minWorkers: 1
maxWorkers: 2
pt2: {backend: "openvino"}
@@ -0,0 +1,6 @@
from torchvision.models.resnet import Bottleneck, ResNet


# ResNet-50: Bottleneck blocks stacked in a [3, 4, 6, 3] layout, matching the
# serialized resnet50 weights used in this example.
class ImageClassifier(ResNet):
    def __init__(self):
        super(ImageClassifier, self).__init__(Bottleneck, [3, 4, 6, 3])
@@ -0,0 +1 @@
openvino>=2024.1.0
@@ -0,0 +1,9 @@
import torch
from diffusers import DiffusionPipeline

# Download the Stable Diffusion XL base model and save it locally so that
# torch-model-archiver can package it from ./Base_Diffusion_model.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
)
pipeline.save_pretrained("./Base_Diffusion_model")
@@ -0,0 +1,170 @@
# Accelerating the Stable Diffusion XL model with the torch.compile OpenVINO backend

[Stable Diffusion XL](https://huggingface.co/docs/diffusers/en/using-diffusers/sdxl) is an image generation model geared towards producing more photorealistic images than its predecessors. This guide details the process of enhancing model performance using `torch.compile` with the OpenVINO backend, specifically tested on an Intel Xeon Platinum 8469 CPU and an Intel GPU Flex 170.

### Prerequisites
- `PyTorch >= 2.1.0`
- `OpenVINO >= 2024.1.0`. Install the latest version as shown below:

```bash
cd examples/pt2/torch_compile_openvino/stable_diffusion
pip install -r requirements.txt
```

## Workflow
1. Configure torch.compile.
1. Create Model Archive.
1. Start TorchServe.
1. Run Inference.
1. Stop TorchServe.
1. Measure and Compare Performance with different backends and devices.

### 1. Configure torch.compile

`torch.compile` allows various configurations that can influence performance outcomes. Explore different options in the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.compile.html) and the [OpenVINO backend documentation](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html).

In this example, we utilize the configuration defined in [model-config.yaml](./model-config.yaml). Specifically, the OpenVINO backend is enabled to optimize performance, as shown in the configuration snippet:
```yaml
pt2: {backend: 'openvino'}
```

#### Additional Configuration Options:
- If you want to measure the handler `preprocess`, `inference`, and `postprocess` times, include `profile: true` in the handler section of the config:

```bash
echo " profile: true" >> model-config.yaml
```

- The `torch.compile` OpenVINO backend supports additional configurations for model caching, device selection, and other OpenVINO-specific options. Refer to the [torch.compile OpenVINO options documentation](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html#options). For example, see the following configuration sample (a standalone Python sketch of these options follows this list):

```yaml
pt2: {backend: 'openvino', options: {'model_caching' : True, 'device': 'GPU'}}
```

- `model_caching`: Enables caching the model after the initial run, reducing the first-inference latency for subsequent runs of the same model.
- `device`: Specifies the hardware device for running the application.
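
Outside of TorchServe, these options map onto the `options` argument of `torch.compile` with the OpenVINO backend. The sketch below is illustrative only and not part of this example's files; it assumes `openvino` and `diffusers` are installed, that `Download_model.py` has already saved the pipeline to `./Base_Diffusion_model`, and that an Intel GPU is visible to OpenVINO (drop `'device': 'GPU'` to stay on CPU):

```python
import torch
import openvino.torch  # noqa: F401  -- registers the "openvino" torch.compile backend
from diffusers import DiffusionPipeline

# Load the locally saved Stable Diffusion XL pipeline.
pipeline = DiffusionPipeline.from_pretrained("./Base_Diffusion_model", use_safetensors=True)

# Compile the UNet (the heaviest component) with OpenVINO-specific options.
pipeline.unet = torch.compile(
    pipeline.unet,
    backend="openvino",
    options={"model_caching": True, "device": "GPU"},
)

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.jpg")
```

Which parts of the pipeline this example actually compiles is defined in `stable_diffusion_handler.py`; compiling only the UNet is shown here as a common, conservative choice.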

### 2. Create Model Archive

Download the Stable Diffusion model and prepare the model archive using the [model-config.yaml](./model-config.yaml) configuration.

```bash
# Download the Stable Diffusion model. Saves it to the Base_Diffusion_model directory.
python Download_model.py

# Create model archive
torch-model-archiver --model-name diffusion_fast --version 1.0 --handler stable_diffusion_handler.py \
  --config-file model-config.yaml --extra-files "./pipeline_utils.py" --archive-format no-archive

mv Base_Diffusion_model diffusion_fast/

# Add the model archive to the model store
mkdir model_store
mv diffusion_fast model_store
```

### 3. Start TorchServe

Start the TorchServe server using the following command:

```bash
torchserve --start --ts-config config.properties --model-store model_store --models diffusion_fast
```

### 4. Run inference

Execute the model using the following command to generate an image based on your specified prompt:

```bash
python query.py --url "http://localhost:8080/predictions/diffusion_fast" --prompt "a photo of an astronaut riding a horse on mars"
```

By default, the generated image is saved to a file named `output-<timestamp>.jpg`. You can customize the output filename by using the `--filename` parameter in the `query.py` script.
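
For programmatic access without `query.py`, a rough client sketch follows. The exact request and response format is defined by `query.py` and `stable_diffusion_handler.py` in this example, so treat this as a hypothetical outline and check those files for the real protocol:

```python
import requests

URL = "http://localhost:8080/predictions/diffusion_fast"
PROMPT = "a photo of an astronaut riding a horse on mars"

# Assumption: the handler accepts the raw prompt text as the request body
# and returns the generated image as encoded bytes.
response = requests.post(URL, data=PROMPT)
response.raise_for_status()

with open("output.jpg", "wb") as f:
    f.write(response.content)
```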

### 5. Stop the server
Stop TorchServe with the following command:

```bash
torchserve --stop
```

### 6. Measure and Compare Performance with different backends

Following the steps outlined in the previous section, you can compare the inference times for the Inductor backend and the OpenVINO backend:

1. Update model-config.yaml by adding `profile: true` under the `handler` section.
1. Create a new model archive using torch-model-archiver with the updated configuration.
1. Start TorchServe and run inference.
1. Analyze the TorchServe logs for metrics like `ts_handler_preprocess.Milliseconds`, `ts_handler_inference.Milliseconds`, and `ts_handler_postprocess.Milliseconds`. These metrics represent the time taken for pre-processing, inference, and post-processing steps, respectively, for each inference request.

#### 6.1. Measure inference time with `torch.compile` Inductor backend

Update the `model-config.yaml` file to specify the Inductor backend:

```yaml
pt2: {backend: inductor, mode: reduce-overhead}
```
Make sure that profiling is enabled (`profile: true` under the `handler` section).

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance.

After a few iterations of warm-up, we see the following:

```bash
2024-04-25T07:21:31,722 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:0.0054836273193359375|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029691,10ca6d02-5895-4af3-a052-6b5d409ca676, pattern=[METRICS]
2024-04-25T07:22:31,405 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:59682.70015716553|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029751,10ca6d02-5895-4af3-a052-6b5d409ca676, pattern=[METRICS]
2024-04-25T07:22:31,947 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:542.2341823577881|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029751,10ca6d02-5895-4af3-a052-6b5d409ca676, pattern=[METRICS]
```

#### 6.2. Measure inference time with `torch.compile` OpenVINO backend

Update the `model-config.yaml` file to specify the OpenVINO backend:

```yaml
pt2: {backend: openvino}
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance.

After a few iterations of warm-up, we see the following:

```bash
2024-04-25T07:12:36,276 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:0.0045299530029296875|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029156,2d8c54ac-1c6f-43d7-93b0-bb205a9a06ee, pattern=[METRICS]
2024-04-25T07:12:51,667 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:15391.06822013855|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029171,2d8c54ac-1c6f-43d7-93b0-bb205a9a06ee, pattern=[METRICS]
2024-04-25T07:12:51,955 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:287.31536865234375|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714029171,2d8c54ac-1c6f-43d7-93b0-bb205a9a06ee, pattern=[METRICS]
```

#### 6.3. Measure inference time with `torch.compile` OpenVINO backend on an Intel discrete GPU

Update the `model-config.yaml` file to specify the OpenVINO backend with the Intel GPU device:

```yaml
pt2: {backend: 'openvino', options: {'model_caching' : True, 'device': 'GPU'}}
```

Once the `yaml` file is updated, create the model archive, start TorchServe, and run inference using the steps shown above.
`torch.compile` requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the `maxWorkers` specified before measuring performance.

After a few iterations of warm-up, we see the following:

```bash
2024-04-25T07:28:32,662 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:0.0050067901611328125|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714030112,579edbf3-5d78-40aa-b49c-480796b4d3b1, pattern=[METRICS]
2024-04-25T07:28:39,887 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:7225.085020065308|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714030119,579edbf3-5d78-40aa-b49c-480796b4d3b1, pattern=[METRICS]
2024-04-25T07:28:40,174 [INFO ] W-9000-diffusion_fast_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:286.96274757385254|#ModelName:diffusion_fast,Level:Model|#type:GAUGE|#hostname:MDSATSM002ARC,1714030120,579edbf3-5d78-40aa-b49c-480796b4d3b1, pattern=[METRICS]
```

### Conclusion

Using `torch.compile` with the OpenVINO backend significantly enhances the performance of the Stable Diffusion XL model. Comparing backends:

- With the **Inductor backend**, the inference time on a CPU (Intel Xeon Platinum 8469) is around 59 seconds.
- Switching to the **OpenVINO backend** on the same CPU reduces the inference time to approximately 15 seconds, roughly a 4x improvement.
- Furthermore, running on an Intel discrete GPU (Intel GPU Flex 170) with the OpenVINO backend reduces the inference time even further, to about 7 seconds.

The actual performance gains may vary depending on your hardware, model complexity, and workload. Consider exploring more advanced `torch.compile` [configurations](https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html) for further optimization based on your specific use case.
@@ -0,0 +1,4 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
max_response_size=655350000
