Commit cbb2363

added batch inferencing example

1 parent dddf4d4 commit cbb2363
File tree

8 files changed: +313 -2 lines changed
docs/batch_inference_with_ts.md

+120 -2
@@ -2,6 +2,7 @@
## Contents of this Document
* [Introduction](#introduction)
* [Conclusion](#conclusion)

## Introduction
@@ -16,14 +17,131 @@ Before jumping into this document, please go over the following docs
1. [What is TorchServe?](../README.md)
1. [What is custom service code?](custom_service.md)

## Batch Inference with TorchServe using ResNet-152 model
To support batching of inference requests, TorchServe needs the following:
1. TorchServe Model Configuration: TorchServe provides a way to configure "Max Batch Size" and "Max Batch Delay" through the "POST /models" API.
TorchServe needs to know the maximum batch size that the model can handle and the maximum delay that it should wait to form this request batch.
2. Model Handler code: TorchServe requires the Model Handler to handle batches of inference requests.

For fully working code of a custom model handler with batch processing, refer to [resnet152_handler.py](../examples/image_classifier/resnet_152_batch/resnet152_handler.py).

### TorchServe Model Configuration
To configure TorchServe to use the batching feature, provide the batch configuration information through the [**POST /models** API](management_api.md#register-a-model).
The configuration options we are interested in are the following:
1. `batch_size`: The maximum batch size that a model is expected to handle.
2. `max_batch_delay`: The maximum delay, in milliseconds, that TorchServe waits to receive `batch_size` requests. If TorchServe doesn't receive `batch_size` requests
before this timer times out, it sends whatever requests it has received to the model `handler`.

Let's look at an example using this configuration:
```bash
# The following command registers the model "resnet-152.mar" and configures TorchServe
# to use a batch_size of 8 and a max batch delay of 50 milliseconds.
curl -X POST "localhost:8081/models?url=resnet-152.mar&batch_size=8&max_batch_delay=50"
```

These configurations are used both in TorchServe and in the model's custom service code (a.k.a. the handler code). TorchServe associates the batch-related configuration with each model. The frontend then tries to aggregate `batch_size` requests and sends the batch to the backend.
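
To make this contract concrete, here is a minimal sketch (not the full ResNet-152 example handler linked above) of what a handler entry point receives when batching is enabled. The body of the function is a stand-in, not a TorchServe API:

```python
# Minimal sketch of a batch-aware handler entry point (illustrative only).
# TorchServe passes `data` as a list with up to batch_size items; each item
# carries the request payload under "data" or "body".

def handle(data, context):
    if data is None:
        return None

    # Collect the raw payload of every request in the batch.
    payloads = [item.get("data") or item.get("body") for item in data]

    # Stand-in for the real preprocess -> inference -> postprocess pipeline:
    # here we just report the payload size of each request.
    results = [f"received {len(p)} bytes" for p in payloads]

    # TorchServe expects exactly one response entry per request, in the same order.
    assert len(results) == len(data)
    return results
```
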
## Demo to configure TorchServe with a batch-supported model
In this section let's bring up the model server and launch the Resnet-152 model, which has been built to handle a batch of requests.

### Pre-requisites
Follow the main [Readme](../README.md) and install all the required packages, including `torchserve`.

### Loading Resnet-152 which handles batch inferences
* Start the model server. In this example, we are starting the model server on inference port 8080 and management port 8081.
```text
$ cat config.properties
...
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
...
$ torchserve --start --model-store model_store
```

Note: This example assumes that the resnet-152.mar file is available in the TorchServe model_store. For more details on creating the resnet-152 mar file and serving it on TorchServe, refer to the [resnet152 image classification example](../examples/image_classifier/resnet_152_batch/README.md).

* Verify that TorchServe is up and running
```text
$ curl localhost:8080/ping
{
  "status": "Healthy"
}
```

* Now let's launch the resnet-152 model, which we have built to handle batch inference. Since this is an example, we are going to launch 1 worker which handles a batch size of 8
with a max-batch-delay of 10 ms.
```text
$ curl -X POST "localhost:8081/models?url=resnet-152.mar&batch_size=8&max_batch_delay=10&initial_workers=1"
{
  "status": "Processing worker updates..."
}
```

* Verify that the workers were started properly
```text
$ curl localhost:8081/models/resnet-152
{
  "modelName": "resnet-152",
  "modelUrl": "https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar",
  "runtime": "python",
  "minWorkers": 1,
  "maxWorkers": 1,
  "batchSize": 8,
  "maxBatchDelay": 10,
  "workers": [
    {
      "id": "9008",
      "startTime": "2019-02-19T23:56:33.907Z",
      "status": "READY",
      "gpu": false,
      "memoryUsage": 607715328
    }
  ]
}
```

* Now let's test this service.
* Get an image to test this service
```text
$ curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg
```
* Run inference to test the model
```text
$ curl -X POST localhost:8080/predictions/resnet-152 -T kitten.jpg
{
  "probability": 0.7148938179016113,
  "class": "n02123045 tabby, tabby cat"
},
{
  "probability": 0.22877725958824158,
  "class": "n02123159 tiger cat"
},
{
  "probability": 0.04032370448112488,
  "class": "n02124075 Egyptian cat"
},
{
  "probability": 0.00837081391364336,
  "class": "n02127052 lynx, catamount"
},
{
  "probability": 0.0006728120497427881,
  "class": "n02129604 tiger, Panthera tigris"
}
```

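A single request like the one above simply waits out `max_batch_delay` and is processed as a batch of one. To actually exercise batching, you could fire several requests concurrently; here is a minimal Python sketch under the assumption that the `requests` package is installed and the same `kitten.jpg` is in the working directory:

```python
# Hypothetical concurrency test: send 8 requests at once so the frontend can
# aggregate them into one batch (we registered with batch_size=8, max_batch_delay=10 ms).
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://localhost:8080/predictions/resnet-152"

with open("kitten.jpg", "rb") as f:
    payload = f.read()

def predict(_):
    # Each call is one inference request; TorchServe groups concurrent requests into a batch.
    response = requests.post(URL, data=payload, headers={"Content-Type": "image/jpeg"})
    return response.text

with ThreadPoolExecutor(max_workers=8) as pool:
    for result in pool.map(predict, range(8)):
        print(result)
```
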
* Now that we have the service up and running, we can run performance tests with the same kitten image. There are multiple tools to measure the performance of web servers; we will use
[apache-bench](https://httpd.apache.org/docs/2.4/programs/ab.html), chosen because it is easy to install and easy to run.
Before running this test, we need to install `apache-bench` on our system. Since we were running this on an Ubuntu host, we installed it as follows:
```bash
$ sudo apt-get update && sudo apt-get install apache2-utils
```
Now that installation is done, we can run the performance benchmark as follows:
```text
$ ab -k -l -n 10000 -c 1000 -T "image/jpeg" -p kitten.jpg localhost:8080/predictions/resnet-152
```
The above test simulates TorchServe receiving 1000 concurrent requests at a time and a total of 10,000 requests. All of these requests are directed to the endpoint "localhost:8080/predictions/resnet-152", which assumes
that resnet-152 is already registered and scaled up on TorchServe. We did this registration and scaling up in the steps above.

## Conclusion
The takeaway from the experiments is that batching is a very useful feature. In cases where the services receive a heavy load of requests or each request has high I/O, it's advantageous
to batch the requests. This allows for maximally utilizing the compute resources, especially GPU compute, which is more often than not the more expensive resource. But customers should

examples/image_classifier/resnet_152_batch/README.md
@@ -0,0 +1,60 @@
#### Sample commands to create a resnet-152 eager mode model archive for batch inputs, register it on TorchServe and run image prediction

```bash
wget https://download.pytorch.org/models/resnet152-b121ed2d.pth
torch-model-archiver --model-name resnet-152-batch --version 1.0 --model-file serve/examples/image_classifier/resnet_152_batch/model.py --serialized-file resnet152-b121ed2d.pth --handler serve/examples/image_classifier/resnet_152_batch/resnet152_handler.py --extra-files serve/examples/image_classifier/index_to_name.json
mkdir model-store
mv resnet-152-batch.mar model-store/
torchserve --start --model-store model-store
curl -X POST "localhost:8081/models?model_name=resnet152&url=resnet-152-batch.mar&batch_size=4&max_batch_delay=5000&initial_workers=3&synchronous=true"
```

The above commands create the mar file and register the resnet152 model with TorchServe with the following configuration:

- model_name : resnet152
- batch_size : 4
- max_batch_delay : 5000 ms
- workers : 3

To test batch inference, execute the following commands within the specified max_batch_delay time:

```bash
curl -X POST http://127.0.0.1:8080/predictions/resnet152 -T serve/examples/image_classifier/resnet_152_batch/images/croco.jpg &
curl -X POST http://127.0.0.1:8080/predictions/resnet152 -T serve/examples/image_classifier/resnet_152_batch/images/dog.jpg &
curl -X POST http://127.0.0.1:8080/predictions/resnet152 -T serve/examples/image_classifier/resnet_152_batch/images/kitten.jpg &
```

#### TorchScript example using Resnet152 image classifier:

* Save the Resnet152-batch model as an executable script module or a traced script:

1. Save model using scripting
```python
# scripted mode
from torchvision import models
import torch

model = models.resnet152(pretrained=True)
sm = torch.jit.script(model)
sm.save("resnet-152-batch.pt")
```

2. Save model using tracing
```python
# traced mode
from torchvision import models
import torch

model = models.resnet152(pretrained=True)
model.eval()  # switch to eval mode so tracing records inference behavior (e.g. batch norm)
example_input = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet-152-batch.pt")
```
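
As a quick sanity check (a hedged sketch, not part of the original example), you could load the saved TorchScript file back and run a dummy input through it before archiving:

```python
# Sanity-check the exported TorchScript module (assumes resnet-152-batch.pt
# was produced by one of the two snippets above).
import torch

loaded = torch.jit.load("resnet-152-batch.pt")
loaded.eval()

with torch.no_grad():
    out = loaded(torch.rand(4, 3, 224, 224))  # a fake "batch" of 4 images

print(out.shape)  # expected: torch.Size([4, 1000]) for the 1000 ImageNet classes
```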

* Use the following commands to register the Resnet152-batch TorchScript model on TorchServe and run image prediction

```bash
torch-model-archiver --model-name resnet-152-batch --version 1.0 --serialized-file resnet-152-batch.pt --extra-files serve/examples/image_classifier/index_to_name.json --handler image_classifier
mkdir model-store
mv resnet-152-batch.mar model-store/
torchserve --start --model-store model-store --models resnet-152-batch=resnet-152-batch.mar
curl -X POST http://127.0.0.1:8080/predictions/resnet-152-batch -T serve/examples/image_classifier/kitten.jpg
```

examples/image_classifier/resnet_152_batch/index_to_name.json

+1
Large diffs are not rendered by default.

examples/image_classifier/resnet_152_batch/model.py
@@ -0,0 +1,6 @@
from torchvision.models.resnet import ResNet, Bottleneck


class RestNet152ImageClassifier(ResNet):
    def __init__(self):
        # ResNet-152 = Bottleneck blocks with the layer configuration [3, 8, 36, 3]
        super(RestNet152ImageClassifier, self).__init__(Bottleneck, [3, 8, 36, 3])
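
As an illustration (a hedged sketch, assuming the resnet152-b121ed2d.pth checkpoint downloaded in the archiving step and this model.py are in the working directory), the class can be loaded directly with the torchvision ResNet-152 weights:

```python
# Quick check that the class definition matches the torchvision ResNet-152 weights.
import torch
from model import RestNet152ImageClassifier

model = RestNet152ImageClassifier()
state_dict = torch.load("resnet152-b121ed2d.pth", map_location="cpu")
model.load_state_dict(state_dict)  # keys line up because the architecture is identical
model.eval()
```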

examples/image_classifier/resnet_152_batch/resnet152_handler.py
@@ -0,0 +1,126 @@
import io
import logging
import numpy as np
import os
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

logger = logging.getLogger(__name__)


class BatchImageClassifier(object):
    """
    BatchImageClassifier handler class. This handler takes a list of images
    and returns a corresponding list of classes.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.initialized = False

    def initialize(self, context):
        """First try to load torchscript, else load an eager mode state_dict based model."""

        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Read model serialized/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        try:
            logger.debug('Loading torchscript model')
            self.model = torch.jit.load(model_pt_path)
        except Exception as e:
            # Read model definition file
            model_file = self.manifest['model']['modelFile']
            model_def_path = os.path.join(model_dir, model_file)
            if not os.path.isfile(model_def_path):
                raise RuntimeError("Missing the model.py file")

            state_dict = torch.load(model_pt_path, map_location=self.device)
            from model import RestNet152ImageClassifier
            self.model = RestNet152ImageClassifier()
            self.model.load_state_dict(state_dict)

        self.model.eval()
        logger.debug('Model file {0} loaded successfully'.format(model_pt_path))

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        import json
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, request):
        """
        Scales, crops, and normalizes a PIL image for a PyTorch model,
        returns a batched image tensor (one row per request in the batch).
        """

        image_tensor = None

        for idx, data in enumerate(request):
            # The image bytes arrive either under "data" or "body" depending on the client.
            image = data.get("data")
            if image is None:
                image = data.get("body")

            my_preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            input_image = Image.open(io.BytesIO(image))
            input_image = my_preprocess(input_image).unsqueeze(0)

            # Stack each preprocessed image into a single (N, 3, 224, 224) batch tensor.
            if input_image.shape is not None:
                if image_tensor is None:
                    image_tensor = input_image
                else:
                    image_tensor = torch.cat((image_tensor, input_image), 0)

        return image_tensor

    def inference(self, img):
        # Run a single forward pass over the whole batch tensor.
        return self.model.forward(img)

    def postprocess(self, inference_output):
        # inference_output has one row of class scores per image in the batch.
        num_rows, num_cols = inference_output.shape
        output_classes = []
        for i in range(num_rows):
            out = inference_output[i].unsqueeze(0)
            _, y_hat = out.max(1)
            predicted_idx = str(y_hat.item())
            output_classes.append(self.mapping[predicted_idx])
        return output_classes


_service = BatchImageClassifier()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)

    return data
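
For illustration, here is a hedged sketch of how this handler could be smoke-tested locally, outside TorchServe. The fake context, file names, and batch construction are assumptions (TorchServe normally builds the context and the batch itself), and it assumes the snippet runs in the handler's module with resnet-152-batch.pt, model.py, index_to_name.json, and kitten.jpg in the working directory:

```python
# Hypothetical local smoke test for the handler above.
from types import SimpleNamespace

fake_context = SimpleNamespace(
    manifest={"model": {"serializedFile": "resnet-152-batch.pt", "modelFile": "model.py"}},
    system_properties={"model_dir": "."},  # directory holding the .pt, model.py, index_to_name.json
)

with open("kitten.jpg", "rb") as f:
    batch = [{"body": f.read()}] * 4  # simulate a batch of four identical requests

print(handle(batch, fake_context))  # expect a list with four class labels
```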
