
Commit 0c820c7

lxning and mreso authored
Support continuous batching in sequence batch streaming case (#3160)
Support continuous batching in sequence batch streaming case (#3160)

* Support continuous batching in sequence batch streaming case
* add test stateful sequence continuous batching
* fmt
* fix init atomicboolean
* update test and example
* fix open session test
* fix open session test
* set sequence id
* set seq id in response
* update test
* fix wrong expected result
* fixed test expectation
* fmt
* update test path
* simplify
* update for comments
* remove sequence continuous parameter
* update cancel
* update cancel
* update cleanup
* support mix mode stream and non-stream
* clean code
* update test
* fix order
* update log
* update headers
* test mix mode
* update fmt
* increase counter
* increase counter
* add comments
* update readme
* update readme
* Added stop torchserve to unit tests

Co-authored-by: Matthias Reso <[email protected]>
1 parent c74a29e commit 0c820c7

24 files changed: +1,060 −90 lines

examples/stateful/Readme.md → examples/stateful/sequence_batching/Readme.md (+7 −16)

@@ -6,9 +6,9 @@ Within this context, TorchServe offers a mechanism known as sequence batching. T

 The following picture shows the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a stream. The max capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls an inference request from each job group. A batch of requests is sent to the backend.

-![sequence batch](../../docs/images/stateful_batch.jpg)
+![sequence batch](../../../docs/images/stateful_batch.jpg)

-This example serves as a practical showcase of employing stateful inference. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose different caching library in the handler implementation based on their own use cases.
+This example serves as a practical showcase of employing stateful inference via sequence batching. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose a different caching library in the handler implementation based on their own use cases.

 ### Step 1: Implement handler

@@ -92,16 +92,10 @@ handler:
 ### Step 3: Generate mar or tgz file

 ```bash
-torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r requirements.txt --config-file model-config.yaml
+torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
 ```

-### Step 4: Start torchserve
-
-```bash
-torchserve --start --ncs --model-store model_store --models stateful.mar
-```
-
-### Step 6: Build GRPC Client
+### Step 4: Build GRPC Client
 The details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
 * Install gRPC python dependencies
 ```bash
@@ -111,26 +105,23 @@ pip install -U grpcio protobuf grpcio-tools googleapis-common-protos

 * Generate python gRPC client stub using the proto files
 ```bash
-cd ../..
+cd ../../..
 python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
-cd -
 ```

-### Step 7: Run inference
+### Step 5: Run inference
 * Start TorchServe

 ```bash
-torchserve --ncs --start --model-store models --model stateful.mar --ts-config config.properties
+torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
 ```

 * Run sequence inference via GRPC client
 ```bash
-cd ../../
 python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
 ```

 * Run sequence inference via HTTP
 ```bash
-cd ../../
 curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
 ```
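
As an aside (not part of the diff above), the curl command at the end of this README can also be reproduced with Python's requests library. The sketch below only mirrors that command; it assumes TorchServe is already serving the `stateful` model on the default inference port 8080 and reuses the `ts_request_sequence_id` header and sample file shown above.

```python
# Minimal sketch mirroring the curl call above.
# Assumptions: default inference port 8080, model registered as "stateful",
# sample file path as given in this README.
import requests

with open("examples/stateful/sample/sample1.txt", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/predictions/stateful",
        headers={"ts_request_sequence_id": "seq_0"},  # pins the request to sequence seq_0
        data=f.read(),
    )

print(resp.status_code, resp.text)
```
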
New file (+198 −0)
@@ -0,0 +1,198 @@
+# Stateful Inference
+
+A stateful model possesses the ability to leverage interdependencies between successive inference requests. This type of model maintains a persistent state across inference requests, thereby establishing a linkage between the outcomes of prior inquiries and those that follow. Notable illustrations of stateful models encompass online speech recognition systems, such as the Long Short-Term Memory (LSTM) model. Employing stateful inference mandates that the model server adheres to the sequential order of inference requests, ensuring predictions build upon the previous outcomes.
+
+Within this context, TorchServe offers a mechanism known as sequence continuous batching. This approach involves the retrieval of an individual inference request from a particular sequence, followed by the combination of multiple requests originating from different sequences into a unified batch. Each request is associated with a unique sequence ID, which can be extracted using the "get_sequence_id" function of context.py. This `sequence ID` serves as a key employed by custom handlers to store and retrieve values within the backend cache store, fostering efficient management of stateful inference processes. Clients can also reuse the `sequence ID` when a connection resumes, as long as the sequence has not expired on the TorchServe side. Additionally, continuous batching enables a new inference request of a sequence to be served while the previous one is still in response streaming mode.
+
+The following picture shows the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a stream. The max capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls an inference request from each job group. A batch of requests is sent to the backend.
+
+![sequence batch](../../../docs/images/stateful_batch.jpg)
+
+This example serves as a practical showcase of employing stateful inference via sequence batching and continuous batching. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose a different caching library in the handler implementation based on their own use cases.
+
+### Step 1: Implement handler
+
+stateful_handler.py is an example of a stateful handler. It creates a cache `self.cache` by calling [LRU](https://github.com/amitdev/lru-dict).
+
+```python
+    def initialize(self, ctx: Context):
+        """
+        Loads the model and Initializes the necessary artifacts
+        """
+
+        ctx.cache = {}
+        if ctx.model_yaml_config["handler"] is not None:
+            self.cache = LRU(
+                int(
+                    ctx.model_yaml_config["handler"]
+                    .get("cache", {})
+                    .get("capacity", StatefulHandler.DEFAULT_CAPACITY)
+                )
+            )
+        self.initialized = True
+```
+
+The handler uses the sequence ID (i.e., `sequence_id = self.context.get_sequence_id(idx)`) as the key to store and fetch values from `self.cache`.
+
+```python
+    def preprocess(self, data):
+        """
+        Preprocess function to convert the request input to a tensor(Torchserve supported format).
+        The user needs to override to customize the pre-processing
+
+        Args :
+            data (list): List of the data from the request input.
+
+        Returns:
+            tensor: Returns the tensor data of the input
+        """
+
+        results = []
+        for idx, row in enumerate(data):
+            sequence_id = self.context.get_sequence_id(idx)
+            # SageMaker sticky router relies on response header to identify the sessions
+            # The sequence_id from request headers must be set in response headers
+            self.context.set_response_header(
+                idx, self.context.header_key_sequence_id, sequence_id
+            )
+
+            # check if sequence_id exists
+            if self.context.get_request_header(
+                idx, self.context.header_key_sequence_start
+            ):
+                prev = int(0)
+                self.context.cache[sequence_id] = {
+                    "start": True,
+                    "cancel": False,
+                    "end": False,
+                    "num_requests": 0,
+                }
+            elif self.cache.has_key(sequence_id):
+                prev = int(self.cache[sequence_id])
+            else:
+                prev = None
+                logger.error(
+                    f"Not received sequence_start request for sequence_id:{sequence_id} before"
+                )
+
+            req_id = self.context.get_request_id(idx)
+            # process a new request
+            if req_id not in self.context.cache:
+                logger.info(
+                    f"received a new request sequence_id={sequence_id}, request_id={req_id}"
+                )
+                request = row.get("data") or row.get("body")
+                if isinstance(request, (bytes, bytearray)):
+                    request = request.decode("utf-8")
+
+                self.context.cache[req_id] = {
+                    "stopping_criteria": self._create_stopping_criteria(
+                        req_id=req_id, seq_id=sequence_id
+                    ),
+                    "stream": True,
+                }
+                self.context.cache[sequence_id]["num_requests"] += 1
+
+                if type(request) is dict and "input" in request:
+                    request = request.get("input")
+
+                # -1: cancel
+                if int(request) == -1:
+                    self.context.cache[sequence_id]["cancel"] = True
+                    self.context.cache[req_id]["stream"] = False
+                    results.append(int(request))
+                elif prev is None:
+                    logger.info(
+                        f"Close the sequence:{sequence_id} without open session request"
+                    )
+                    self.context.cache[sequence_id]["end"] = True
+                    self.context.cache[req_id]["stream"] = False
+                    self.context.set_response_header(
+                        idx, self.context.header_key_sequence_end, sequence_id
+                    )
+                    results.append(int(request))
+                else:
+                    val = prev + int(request)
+                    self.cache[sequence_id] = val
+                    # 0: end
+                    if int(request) == 0:
+                        self.context.cache[sequence_id]["end"] = True
+                        self.context.cache[req_id]["stream"] = False
+                        self.context.set_response_header(
+                            idx, self.context.header_key_sequence_end, sequence_id
+                        )
+                    # non stream input:
+                    elif int(request) % 2 == 0:
+                        self.context.cache[req_id]["stream"] = False
+
+                    results.append(val)
+            else:
+                # continue processing stream
+                logger.info(
+                    f"received continuous request sequence_id={sequence_id}, request_id={req_id}"
+                )
+                time.sleep(1)
+                results.append(prev)
+
+        return results
+```
+
+### Step 2: Model configuration
+
+Stateful inference adds two model config parameters. TorchServe is able to process (maxWorkers * batchSize) sequences of inference requests of a model in parallel.
+* sequenceMaxIdleMSec: the maximum idle time in milliseconds of a sequence inference request of this stateful model. The default value is 0 (i.e., this is not a stateful model). TorchServe does not process a new inference request for the sequence once the max idle time is exceeded.
+* maxSequenceJobQueueSize: the job queue size of an inference sequence of this stateful model. The default value is 1.
+
+
+```yaml
+#cat model-config.yaml
+
+minWorkers: 2
+maxWorkers: 2
+batchSize: 4
+sequenceMaxIdleMSec: 60000
+maxSequenceJobQueueSize: 10
+sequenceBatching: true
+continuousBatching: true
+
+handler:
+  cache:
+    capacity: 4
+```
+
+### Step 3: Generate mar or tgz file
+
+```bash
+torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
+```
+
+### Step 4: Build GRPC Client
+The details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
+* Install gRPC python dependencies
+```bash
+git submodule init
+pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
+```
+
+* Generate python gRPC client stub using the proto files
+```bash
+cd ../../..
+python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
+```
+
+### Step 5: Run inference
+* Start TorchServe
+
+```bash
+torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
+```
+
+* Run sequence inference via GRPC client
+```bash
+python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
+```
+
+* Run sequence inference via HTTP
+```bash
+curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
+```
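
For reference only (not part of this commit): the example handler above treats the request body as an integer and keeps a running sum per sequence, with 0 closing the sequence and -1 cancelling it. The sketch below exercises that protocol over plain HTTP with Python's requests library. It assumes the model is registered as `stateful` on the default inference port 8080 and that the assigned sequence ID is echoed back under the `ts_request_sequence_id` response header, as the handler's `set_response_header` call suggests; the gRPC streaming client shown above remains the primary way this example is driven.

```python
# Hedged sketch: drive the integer protocol of the example handler over HTTP.
# Assumptions: default port 8080, model name "stateful", and the sequence ID
# header name used in the curl example of this README.
import requests

URL = "http://localhost:8080/predictions/stateful"

# Open a sequence: the first request carries no sequence header; the handler
# copies the assigned sequence ID into the response headers.
first = requests.post(URL, data="1")
seq_id = first.headers.get("ts_request_sequence_id")
print("sequence:", seq_id, "->", first.text)

# Follow-up requests reuse the sequence ID, so 3 is added to the running sum.
headers = {"ts_request_sequence_id": seq_id}
print(requests.post(URL, headers=headers, data="3").text)

# 0 ends the sequence (the handler also sets the sequence-end response header);
# sending "-1" instead would cancel it and stop an in-flight streaming response.
print(requests.post(URL, headers=headers, data="0").text)
```
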
New file (+12 −0)
@@ -0,0 +1,12 @@
+minWorkers: 2
+maxWorkers: 2
+batchSize: 4
+maxNumSequence: 4
+sequenceMaxIdleMSec: 10
+maxSequenceJobQueueSize: 10
+sequenceBatching: true
+continuousBatching: true
+
+handler:
+  cache:
+    capacity: 4
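
As a quick sanity check (not part of this commit), the standard TorchServe management API can be used after the model is registered to confirm that the batching and worker settings above were applied; this assumes the default management port 8081 and the model name `stateful` used throughout this example.

```python
# Hedged sketch: query the TorchServe management API (default port 8081) and
# print the model description, which reports batchSize, minWorkers and maxWorkers.
import requests

print(requests.get("http://localhost:8081/models/stateful").json())
```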
