# Stateful Inference

A stateful model keeps persistent state across successive inference requests, so the outcome of each request depends on the requests that came before it. Online speech recognition systems, such as Long Short-Term Memory (LSTM) models, are typical examples. Serving a stateful model therefore requires the model server to preserve the order of inference requests within a sequence, so that each prediction builds on the previous results.

Within this context, TorchServe offers a mechanism called sequence continuous batching: it retrieves one inference request at a time from each sequence and combines requests from different sequences into a single batch. Each request carries a unique sequence ID, which can be extracted with the `get_sequence_id` function of `context.py`. Custom handlers use this `sequence ID` as the key to store and fetch values in the backend cache store, which keeps the management of stateful inference efficient. A client can also reuse the `sequence ID` when a connection resumes, as long as the sequence has not expired on the TorchServe side. In addition, continuous batching allows a new inference request of a sequence to be served while the previous one is still in response streaming mode.
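
As a minimal illustration of that idea (Step 1 below shows the full example handler), a custom handler can use the sequence ID of each batched request as its cache key. This is only a sketch; `self.cache` is assumed to be any dict-like store created in `initialize`, and `prev_state`/`new_state` are placeholders:

```python
# Minimal sketch only, not the full example handler: look up and update the
# per-sequence state using the sequence ID of each request in the batch.
def preprocess(self, data):
    results = []
    for idx, row in enumerate(data):
        sequence_id = self.context.get_sequence_id(idx)  # unique per client sequence
        prev_state = self.cache[sequence_id] if sequence_id in self.cache else None
        # ... derive new_state from prev_state and the request payload ...
        # self.cache[sequence_id] = new_state
        results.append(prev_state)
    return results
```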

The workflow of stateful inference is as follows. Each job group has a job queue that stores incoming inference requests from a stream; the maximum capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls one inference request from each job group and sends the resulting batch of requests to the backend.
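
Purely as an illustration of this scheduling, here is a rough sketch (pseudocode, not the actual TorchServe frontend implementation; the names `poll_batch`, `job_groups`, and `group.queue` are made up for this sketch):

```python
# Illustrative sketch of sequence batching, NOT the real frontend code.
# Each job group buffers the requests of one sequence (bounded by
# maxSequenceJobQueueSize); the aggregator takes at most one request
# per sequence when it assembles a batch of up to batchSize requests.
def poll_batch(job_groups, batch_size):
    batch = []
    for group in job_groups:                 # one job group per active sequence
        if len(batch) == batch_size:
            break
        if not group.queue.empty():
            batch.append(group.queue.get())  # at most one request per sequence
    return batch                             # handed to a backend worker
```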
This example is a practical showcase of stateful inference via sequence batching and continuous batching. Under the hood, the backend uses an [LRU dictionary](https://github.com/amitdev/lru-dict) as a caching layer. Users can choose a different caching library in the handler implementation based on their own use cases.
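
For instance, a handler that does not want to depend on lru-dict could keep a small LRU cache built on `collections.OrderedDict`. The class below is only a sketch of a possible stand-in, not part of this example; the handler's `initialize` would then construct `SimpleLRU(capacity)` instead of `LRU(capacity)`:

```python
from collections import OrderedDict


class SimpleLRU(OrderedDict):
    """A minimal dict-like LRU cache; a possible stand-in for lru-dict's LRU."""

    def __init__(self, capacity: int):
        super().__init__()
        self.capacity = capacity

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)              # mark the entry as most recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.move_to_end(key)
        if len(self) > self.capacity:
            self.popitem(last=False)       # evict the least recently used entry
```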

### Step 1: Implement handler

stateful_handler.py is an example of a stateful handler. It creates the cache `self.cache` by calling [LRU](https://github.com/amitdev/lru-dict).

```python
    def initialize(self, ctx: Context):
        """
        Loads the model and Initializes the necessary artifacts
        """

        ctx.cache = {}
        if ctx.model_yaml_config["handler"] is not None:
            self.cache = LRU(
                int(
                    ctx.model_yaml_config["handler"]
                    .get("cache", {})
                    .get("capacity", StatefulHandler.DEFAULT_CAPACITY)
                )
            )
        self.initialized = True
```
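
For reference, with the model-config.yaml shown in Step 2 below, `ctx.model_yaml_config["handler"]` would resolve to roughly the dictionary sketched here (assuming the frontend passes the YAML through unchanged), so the LRU is created with a capacity of 4:

```python
# Approximate contents of ctx.model_yaml_config["handler"] for the
# model-config.yaml in Step 2 (assumption: the YAML is passed through as-is).
{
    "cache": {
        "capacity": 4,
    },
}
```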

The handler uses the sequence ID (i.e., `sequence_id = self.context.get_sequence_id(idx)`) as the key to store and fetch values from `self.cache`.

```python
    def preprocess(self, data):
        """
        Preprocess function to convert the request input to a tensor(Torchserve supported format).
        The user needs to override to customize the pre-processing

        Args :
            data (list): List of the data from the request input.

        Returns:
            tensor: Returns the tensor data of the input
        """

        results = []
        for idx, row in enumerate(data):
            sequence_id = self.context.get_sequence_id(idx)
            # SageMaker sticky router relies on response header to identify the sessions
            # The sequence_id from request headers must be set in response headers
            self.context.set_response_header(
                idx, self.context.header_key_sequence_id, sequence_id
            )

            # check if sequence_id exists
            if self.context.get_request_header(
                idx, self.context.header_key_sequence_start
            ):
                prev = int(0)
                self.context.cache[sequence_id] = {
                    "start": True,
                    "cancel": False,
                    "end": False,
                    "num_requests": 0,
                }
            elif self.cache.has_key(sequence_id):
                prev = int(self.cache[sequence_id])
            else:
                prev = None
                logger.error(
                    f"Not received sequence_start request for sequence_id:{sequence_id} before"
                )

            req_id = self.context.get_request_id(idx)
            # process a new request
            if req_id not in self.context.cache:
                logger.info(
                    f"received a new request sequence_id={sequence_id}, request_id={req_id}"
                )
                request = row.get("data") or row.get("body")
                if isinstance(request, (bytes, bytearray)):
                    request = request.decode("utf-8")

                self.context.cache[req_id] = {
                    "stopping_criteria": self._create_stopping_criteria(
                        req_id=req_id, seq_id=sequence_id
                    ),
                    "stream": True,
                }
                self.context.cache[sequence_id]["num_requests"] += 1

                if type(request) is dict and "input" in request:
                    request = request.get("input")

                # -1: cancel
                if int(request) == -1:
                    self.context.cache[sequence_id]["cancel"] = True
                    self.context.cache[req_id]["stream"] = False
                    results.append(int(request))
                elif prev is None:
                    logger.info(
                        f"Close the sequence:{sequence_id} without open session request"
                    )
                    self.context.cache[sequence_id]["end"] = True
                    self.context.cache[req_id]["stream"] = False
                    self.context.set_response_header(
                        idx, self.context.header_key_sequence_end, sequence_id
                    )
                    results.append(int(request))
                else:
                    val = prev + int(request)
                    self.cache[sequence_id] = val
                    # 0: end
                    if int(request) == 0:
                        self.context.cache[sequence_id]["end"] = True
                        self.context.cache[req_id]["stream"] = False
                        self.context.set_response_header(
                            idx, self.context.header_key_sequence_end, sequence_id
                        )
                    # non stream input:
                    elif int(request) % 2 == 0:
                        self.context.cache[req_id]["stream"] = False

                    results.append(val)
            else:
                # continue processing stream
                logger.info(
                    f"received continuous request sequence_id={sequence_id}, request_id={req_id}"
                )
                time.sleep(1)
                results.append(prev)

        return results
```
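
`_create_stopping_criteria`, referenced above, is part of the full stateful_handler.py and is not reproduced here. Purely as a hedged sketch of what such a helper could look like, the closure below only consults the flags that `preprocess` stores in `ctx.cache`; it is illustrative, not the example's actual code:

```python
# Illustrative sketch only; see stateful_handler.py for the real helper.
def _create_stopping_criteria(self, req_id, seq_id):
    ctx = self.context

    def should_stop() -> bool:
        # Stop streaming for this request once its sequence was cancelled or
        # closed, or the request itself was marked as non-streaming.
        seq_state = ctx.cache.get(seq_id, {})
        req_state = ctx.cache.get(req_id, {})
        return (
            seq_state.get("cancel", False)
            or seq_state.get("end", False)
            or not req_state.get("stream", True)
        )

    return should_stop
```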

### Step 2: Model configuration

Stateful inference adds two parameters. TorchServe can process up to (maxWorkers * batchSize) inference sequences of a model in parallel; with the example configuration below, that is 2 * 4 = 8 concurrent sequences.
* sequenceMaxIdleMSec: the maximum idle time in milliseconds of an inference sequence of this stateful model. The default value is 0 (i.e., the model is not treated as stateful). TorchServe stops processing new inference requests of a sequence once it has been idle longer than this timeout.
* maxSequenceJobQueueSize: the size of the job queue of an inference sequence of this stateful model. The default value is 1.

```yaml
#cat model-config.yaml

minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 60000
maxSequenceJobQueueSize: 10
sequenceBatching: true
continuousBatching: true

handler:
  cache:
    capacity: 4
```

### Step 3: Generate mar or tgz file

```bash
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
```

### Step 4: Build GRPC Client
The details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
* Install gRPC python dependencies
```bash
git submodule init
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
```

* Generate python gRPC client stub using the proto files
```bash
cd ../../..
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```
| 182 | + |
| 183 | +### Step 5: Run inference |
| 184 | +* Start TorchServe |
| 185 | + |
| 186 | +```bash |
| 187 | +torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties |
| 188 | +``` |
| 189 | + |
| 190 | +* Run sequence inference via GRPC client |
| 191 | +```bash |
| 192 | +python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt |
| 193 | +``` |
| 194 | + |
| 195 | +* Run sequence inference via HTTP |
| 196 | +```bash |
| 197 | +curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt |
| 198 | +``` |
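
The same HTTP call can also be made from Python. The sketch below mirrors the curl command above (`curl -T` issues a PUT), reusing the `ts_request_sequence_id` header and sample file from this example; the response-header lookup assumes the handler echoes the sequence ID back under the same header name, as shown in Step 1:

```python
import requests

# Mirrors the curl command above: one request of sequence "seq_0" via HTTP PUT.
with open("examples/stateful/sample/sample1.txt", "rb") as f:
    resp = requests.put(
        "http://localhost:8080/predictions/stateful",
        data=f,
        headers={"ts_request_sequence_id": "seq_0"},
    )

print(resp.status_code, resp.text)
# The handler sets the sequence ID in the response headers (see Step 1);
# assuming the same header name, it can be reused for the next request.
print(resp.headers.get("ts_request_sequence_id"))
```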