diff --git a/examples/stateful/Readme.md b/examples/stateful/sequence_batching/Readme.md similarity index 90% rename from examples/stateful/Readme.md rename to examples/stateful/sequence_batching/Readme.md index 66ec77f72a..d1cae6c257 100644 --- a/examples/stateful/Readme.md +++ b/examples/stateful/sequence_batching/Readme.md @@ -6,9 +6,9 @@ Within this context, TorchServe offers a mechanism known as sequence batching. T The following picture show the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a streaming. The max capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls an inference request from each job group. A batch of requests is sent to backend. -![sequence batch](../../docs/images/stateful_batch.jpg) +![sequence batch](../../../docs/images/stateful_batch.jpg) -This example serves as a practical showcase of employing stateful inference. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose different caching library in the handler implementation based on their own use cases. +This example serves as a practical showcase of employing stateful inference via sequence batching. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose different caching library in the handler implementation based on their own use cases. ### Step 1: Implement handler @@ -92,16 +92,10 @@ handler: ### Step 3: Generate mar or tgz file ```bash -torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r requirements.txt --config-file model-config.yaml +torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml ``` -### Step 4: Start torchserve - -```bash -torchserve --start --ncs --model-store model_store --models stateful.mar -``` - -### Step 6: Build GRPC Client +### Step 4: Build GRPC Client The details can be found at [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md). * Install gRPC python dependencies ```bash @@ -111,26 +105,23 @@ pip install -U grpcio protobuf grpcio-tools googleapis-common-protos * Generate python gRPC client stub using the proto files ```bash -cd ../.. +cd ../../.. 
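+# Illustrative note (not part of the original example): the protoc command below is expected to
+# write the generated stubs (inference_pb2*.py and management_pb2*.py) into ts_scripts/, which
+# ts_scripts/torchserve_grpc_client.py imports.
+# A quick sanity check after generation (hypothetical): ls ts_scripts/*_pb2*.py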
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
-cd -
```

-### Step 7: Run inference
+### Step 5: Run inference
* Start TorchServe

```bash
-torchserve --ncs --start --model-store models --model stateful.mar --ts-config config.properties
+torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```

* Run sequence inference via GRPC client
```bash
-cd ../../
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
```

* Run sequence inference via HTTP
```bash
-cd ../../
curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
```
diff --git a/examples/stateful/model-config.yaml b/examples/stateful/sequence_batching/model-config.yaml
similarity index 100%
rename from examples/stateful/model-config.yaml
rename to examples/stateful/sequence_batching/model-config.yaml
diff --git a/examples/stateful/stateful_handler.py b/examples/stateful/sequence_batching/stateful_handler.py
similarity index 100%
rename from examples/stateful/stateful_handler.py
rename to examples/stateful/sequence_batching/stateful_handler.py
diff --git a/examples/stateful/sequence_continuous_batching/Readme.md b/examples/stateful/sequence_continuous_batching/Readme.md
new file mode 100644
index 0000000000..7d4e9a9ed9
--- /dev/null
+++ b/examples/stateful/sequence_continuous_batching/Readme.md
@@ -0,0 +1,198 @@
+# Stateful Inference
+
+A stateful model possesses the ability to leverage interdependencies between successive inference requests. This type of model maintains a persistent state across inference requests, thereby establishing a linkage between the outcomes of prior inquiries and those that follow. Notable illustrations of stateful models encompass online speech recognition systems, such as the Long Short-Term Memory (LSTM) model. Employing stateful inference mandates that the model server adheres to the sequential order of inference requests, ensuring predictions build upon the previous outcomes.
+
+Within this context, TorchServe offers a mechanism known as sequence continuous batching. This approach involves the retrieval of an individual inference request from a particular sequence, followed by the combination of multiple requests originating from different sequences into a unified batch. Each request is associated with a unique sequence ID, which can be extracted using the "get_sequence_id" function of context.py. This `sequence ID` serves as a key employed by custom handlers to store and retrieve values within the backend cache store, fostering efficient management of stateful inference processes. A client can also reuse the `sequence ID` when a connection resumes, as long as the sequence has not expired on the TorchServe side. Additionally, continuous batching enables a new inference request of a sequence to be served while the previous one is still in response streaming mode.
+
+The following picture shows the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a stream. The max capacity of a job queue is defined by `maxSequenceJobQueueSize`. 
A sequence batch aggregator polls an inference request from each job group. A batch of requests is sent to backend. + +![sequence batch](../../../docs/images/stateful_batch.jpg) + +This example serves as a practical showcase of employing stateful inference via sequence batching and continuous batching. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer. Users can choose different caching library in the handler implementation based on their own use cases. + +### Step 1: Implement handler + +stateful_handler.py is an example of stateful handler. It creates a cache `self.cache` by calling `[LRU](https://github.com/amitdev/lru-dict)`. + +```python + def initialize(self, ctx: Context): + """ + Loads the model and Initializes the necessary artifacts + """ + + ctx.cache = {} + if ctx.model_yaml_config["handler"] is not None: + self.cache = LRU( + int( + ctx.model_yaml_config["handler"] + .get("cache", {}) + .get("capacity", StatefulHandler.DEFAULT_CAPACITY) + ) + ) + self.initialized = True +``` + +Handler uses sequenceId (ie., `sequence_id = self.context.get_sequence_id(idx)`) as key to store and fetch values from `self.cache`. + +```python + def preprocess(self, data): + """ + Preprocess function to convert the request input to a tensor(Torchserve supported format). + The user needs to override to customize the pre-processing + + Args : + data (list): List of the data from the request input. + + Returns: + tensor: Returns the tensor data of the input + """ + + results = [] + for idx, row in enumerate(data): + sequence_id = self.context.get_sequence_id(idx) + # SageMaker sticky router relies on response header to identify the sessions + # The sequence_id from request headers must be set in response headers + self.context.set_response_header( + idx, self.context.header_key_sequence_id, sequence_id + ) + + # check if sequence_id exists + if self.context.get_request_header( + idx, self.context.header_key_sequence_start + ): + prev = int(0) + self.context.cache[sequence_id] = { + "start": True, + "cancel": False, + "end": False, + "num_requests": 0, + } + elif self.cache.has_key(sequence_id): + prev = int(self.cache[sequence_id]) + else: + prev = None + logger.error( + f"Not received sequence_start request for sequence_id:{sequence_id} before" + ) + + req_id = self.context.get_request_id(idx) + # process a new request + if req_id not in self.context.cache: + logger.info( + f"received a new request sequence_id={sequence_id}, request_id={req_id}" + ) + request = row.get("data") or row.get("body") + if isinstance(request, (bytes, bytearray)): + request = request.decode("utf-8") + + self.context.cache[req_id] = { + "stopping_criteria": self._create_stopping_criteria( + req_id=req_id, seq_id=sequence_id + ), + "stream": True, + } + self.context.cache[sequence_id]["num_requests"] += 1 + + if type(request) is dict and "input" in request: + request = request.get("input") + + # -1: cancel + if int(request) == -1: + self.context.cache[sequence_id]["cancel"] = True + self.context.cache[req_id]["stream"] = False + results.append(int(request)) + elif prev is None: + logger.info( + f"Close the sequence:{sequence_id} without open session request" + ) + self.context.cache[sequence_id]["end"] = True + self.context.cache[req_id]["stream"] = False + self.context.set_response_header( + idx, self.context.header_key_sequence_end, sequence_id + ) + results.append(int(request)) + else: + val = prev + int(request) + self.cache[sequence_id] = 
val
+                    # 0: end
+                    if int(request) == 0:
+                        self.context.cache[sequence_id]["end"] = True
+                        self.context.cache[req_id]["stream"] = False
+                        self.context.set_response_header(
+                            idx, self.context.header_key_sequence_end, sequence_id
+                        )
+                    # non stream input:
+                    elif int(request) % 2 == 0:
+                        self.context.cache[req_id]["stream"] = False
+
+                    results.append(val)
+            else:
+                # continue processing stream
+                logger.info(
+                    f"received continuous request sequence_id={sequence_id}, request_id={req_id}"
+                )
+                time.sleep(1)
+                results.append(prev)
+
+        return results
+```
+
+### Step 2: Model configuration
+
+Stateful inference has two parameters. TorchServe is able to process (maxWorkers * batchSize) sequences of inference requests of a model in parallel.
+* sequenceMaxIdleMSec: the max idle time in milliseconds of an inference sequence of this stateful model. The default value is 0 (i.e., this is not a stateful model). TorchServe stops accepting new inference requests for a sequence once it has been idle longer than this timeout.
+* maxSequenceJobQueueSize: the job queue size of an inference sequence of this stateful model. The default value is 1.
+
+
+```yaml
+#cat model-config.yaml
+
+minWorkers: 2
+maxWorkers: 2
+batchSize: 4
+sequenceMaxIdleMSec: 60000
+maxSequenceJobQueueSize: 10
+sequenceBatching: true
+continuousBatching: true
+
+handler:
+  cache:
+    capacity: 4
+```
+
+### Step 3: Generate mar or tgz file
+
+```bash
+torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
+```
+
+### Step 4: Build GRPC Client
+The details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
+* Install gRPC python dependencies
+```bash
+git submodule init
+pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
+```
+
+* Generate python gRPC client stub using the proto files
+```bash
+cd ../../..
+python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto +``` + +### Step 5: Run inference +* Start TorchServe + +```bash +torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties +``` + +* Run sequence inference via GRPC client +```bash +python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt +``` + +* Run sequence inference via HTTP +```bash +curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt +``` diff --git a/examples/stateful/sequence_continuous_batching/model-config.yaml b/examples/stateful/sequence_continuous_batching/model-config.yaml new file mode 100644 index 0000000000..1597308e9d --- /dev/null +++ b/examples/stateful/sequence_continuous_batching/model-config.yaml @@ -0,0 +1,12 @@ +minWorkers: 2 +maxWorkers: 2 +batchSize: 4 +maxNumSequence: 4 +sequenceMaxIdleMSec: 10 +maxSequenceJobQueueSize: 10 +sequenceBatching: true +continuousBatching: true + +handler: + cache: + capacity: 4 diff --git a/examples/stateful/sequence_continuous_batching/stateful_handler.py b/examples/stateful/sequence_continuous_batching/stateful_handler.py new file mode 100644 index 0000000000..36d58c24c9 --- /dev/null +++ b/examples/stateful/sequence_continuous_batching/stateful_handler.py @@ -0,0 +1,226 @@ +import logging +import time +from abc import ABC + +from lru import LRU + +from ts.context import Context +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) + + +class StatefulHandler(BaseHandler, ABC): + DEFAULT_CAPACITY = 10 + + def __init__(self): + super().__init__() + self.cache: LRU = None + + def initialize(self, ctx: Context): + """ + Loads the model and Initializes the necessary artifacts + """ + + # context cache includes 2 types of keys + # key1: sequence_id + # value is a dict which records the sequence's status: start, end, cancel, number of the requests in this batch. + # + # key2: request_id + # value is a dict which records a request's streaming status: + # None(ie. non response streaming request), True or False (ie. streaming complete or not) + ctx.cache = {} + if ctx.model_yaml_config["handler"] is not None: + self.cache = LRU( + int( + ctx.model_yaml_config["handler"] + .get("cache", {}) + .get("capacity", StatefulHandler.DEFAULT_CAPACITY) + ) + ) + + self.initialized = True + + def preprocess(self, data): + """ + Preprocess function to convert the request input to a tensor(Torchserve supported format). + The user needs to override to customize the pre-processing + + Args : + data (list): List of the data from the request input. 
+ + Returns: + tensor: Returns the tensor data of the input + """ + + results = [] + for idx, row in enumerate(data): + sequence_id = self.context.get_sequence_id(idx) + # SageMaker sticky router relies on response header to identify the sessions + # The sequence_id from request headers must be set in response headers + self.context.set_response_header( + idx, self.context.header_key_sequence_id, sequence_id + ) + + # check if sequence_id exists + if self.context.get_request_header( + idx, self.context.header_key_sequence_start + ): + prev = int(0) + self.context.cache[sequence_id] = { + "start": True, + "cancel": False, + "end": False, + "num_requests": 0, + } + elif self.cache.has_key(sequence_id): + prev = int(self.cache[sequence_id]) + else: + prev = None + logger.error( + f"Not received sequence_start request for sequence_id:{sequence_id} before" + ) + + req_id = self.context.get_request_id(idx) + # process a new request + if req_id not in self.context.cache: + logger.info( + f"received a new request sequence_id={sequence_id}, request_id={req_id}" + ) + request = row.get("data") or row.get("body") + if isinstance(request, (bytes, bytearray)): + request = request.decode("utf-8") + + self.context.cache[req_id] = { + "stopping_criteria": self._create_stopping_criteria( + req_id=req_id, seq_id=sequence_id + ), + "stream": True, + } + self.context.cache[sequence_id]["num_requests"] += 1 + + if type(request) is dict and "input" in request: + request = request.get("input") + + # -1: cancel + if int(request) == -1: + self.context.cache[sequence_id]["cancel"] = True + self.context.cache[req_id]["stream"] = False + results.append(int(request)) + elif prev is None: + logger.info( + f"Close the sequence:{sequence_id} without open session request" + ) + self.context.cache[sequence_id]["end"] = True + self.context.cache[req_id]["stream"] = False + self.context.set_response_header( + idx, self.context.header_key_sequence_end, sequence_id + ) + results.append(int(request)) + else: + val = prev + int(request) + self.cache[sequence_id] = val + # 0: end + if int(request) == 0: + self.context.cache[sequence_id]["end"] = True + self.context.cache[req_id]["stream"] = False + self.context.set_response_header( + idx, self.context.header_key_sequence_end, sequence_id + ) + # non stream input: + elif int(request) % 2 == 0: + self.context.cache[req_id]["stream"] = False + + results.append(val) + else: + # continue processing stream + logger.info( + f"received continuous request sequence_id={sequence_id}, request_id={req_id}" + ) + time.sleep(1) + results.append(prev) + + return results + + def inference(self, data, *args, **kwargs): + return data + + def postprocess(self, data): + """ + The post process function makes use of the output from the inference and converts into a + Torchserve supported response output. + + Returns: + List: The post process function returns a list of the predicted output. 
+ """ + self.context.stopping_criteria = [ + self.context.cache[req_id]["stopping_criteria"] + for req_id in self.context.request_ids.values() + ] + + return data + + def clean_up(self, seq_id, req_id, del_seq): + # clean up + self.context.cache[seq_id]["num_requests"] -= 1 + if self.context.cache[seq_id]["num_requests"] == 0 and del_seq: + del self.context.cache[seq_id] + del self.context.cache[req_id] + + def _create_stopping_criteria(self, req_id, seq_id): + class StoppingCriteria(object): + def __init__(self, outer, req_id, seq_id): + self.req_id = req_id + self.seq_id = seq_id + self.outer = outer + self.counter = 10 + + def __call__(self, res): + # sequence end + if self.outer.context.cache[seq_id]["end"]: + ret = True if self.outer.context.cache[req_id]["stream"] else None + self.outer.clean_up(self.seq_id, self.req_id, True) + logger.info(f"end sequence_id={self.seq_id}, ret={ret}") + return ret + # cancel + elif self.outer.context.cache[seq_id]["cancel"]: + ret = True if self.outer.context.cache[req_id]["stream"] else None + self.outer.clean_up(self.seq_id, self.req_id, False) + logger.info( + f"cancel sequence_id={self.seq_id}, request_id={self.req_id}, ret={ret}" + ) + if self.outer.context.cache[seq_id]["num_requests"] == 0: + self.outer.context.cache[seq_id]["cancel"] = False + return ret + # start + elif self.outer.context.cache[seq_id]["start"]: + self.outer.clean_up(self.seq_id, self.req_id, False) + logger.info( + f"start sequence_id={self.seq_id}, request_id={self.req_id}, ret=None" + ) + self.outer.context.cache[seq_id]["start"] = False + return None + # non stream + elif not self.outer.context.cache[req_id]["stream"]: + self.outer.clean_up(self.seq_id, self.req_id, False) + logger.info( + f"test non stream sequence_id={self.seq_id}, request_id={self.req_id}, ret=None" + ) + return None + # stream complete + elif self.counter == 0: + self.outer.clean_up(self.seq_id, self.req_id, False) + logger.info( + f"finish sequence_id={self.seq_id}, request_id={self.req_id}, ret=True" + ) + return True + # stream running + else: + self.counter -= 1 + logger.info( + f"continue sequence_id={self.seq_id}, request_id={self.req_id}, ret=False" + ) + + return False + + return StoppingCriteria(outer=self, req_id=req_id, seq_id=seq_id) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index 16a1cd7d8d..000cafc8f5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -60,7 +60,7 @@ public class ModelConfig { */ private long sequenceMaxIdleMSec; /** - * the job queue size of an inference sequence of this stateful model. The default value is 1. + * the job queue size of one inference sequence of this stateful model. The default value is 1. */ private int maxSequenceJobQueueSize = 1; /** the max number of sequences can be accepted. The default value is 1. 
*/ diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java index 07c0569ebc..7ba95951c8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java @@ -229,7 +229,7 @@ private void prediction( if (workerCmd == WorkerCommands.STREAMPREDICT2) { String sequenceId = request.getSequenceId(); if ("".equals(sequenceId)) { - sequenceId = String.format("ts-%s", UUID.randomUUID()); + sequenceId = String.format("ts-seq-%s", UUID.randomUUID()); inputData.updateHeaders( ConfigManager.getInstance().getTsHeaderKeySequenceStart(), "true"); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 363717cb7f..2bb9102969 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.UUID; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; @@ -255,6 +256,15 @@ private void predict( throw new ModelNotFoundException("Model not found: " + modelName); } input.setClientExpireTS(model.getClientTimeoutInMills()); + if (model.isSequenceBatching()) { + String sequenceId = input.getSequenceId(); + if ("".equals(sequenceId)) { + sequenceId = String.format("ts-seq-%s", UUID.randomUUID()); + input.updateHeaders( + ConfigManager.getInstance().getTsHeaderKeySequenceStart(), "true"); + } + input.updateHeaders(ConfigManager.getInstance().getTsHeaderKeySequenceId(), sequenceId); + } if (HttpMethod.OPTIONS.equals(req.method())) { String resp = OpenApiUtils.getModelApi(model); diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java b/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java index 4b6685d420..75c2b1b5da 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java +++ b/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java @@ -2,6 +2,7 @@ import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -10,13 +11,13 @@ public class JobGroup { String groupId; LinkedBlockingDeque jobs; int maxJobQueueSize; - boolean finished; + AtomicBoolean finished; public JobGroup(String groupId, int maxJobQueueSize) { this.groupId = groupId; this.maxJobQueueSize = maxJobQueueSize; this.jobs = new LinkedBlockingDeque<>(maxJobQueueSize); - this.finished = false; + this.finished = new AtomicBoolean(false); } public boolean appendJob(Job job) { @@ -24,7 +25,7 @@ public boolean appendJob(Job job) { } public Job pollJob(long timeout) { - if (finished) { + if (finished.get()) { return null; } try { @@ -40,10 +41,10 @@ public String getGroupId() { } public void setFinished(boolean sequenceEnd) { - this.finished = sequenceEnd; + this.finished.set(sequenceEnd); } public boolean isFinished() { - return this.finished; + return this.finished.get(); } } diff --git 
a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java index 8074f37b0b..2471c0c7ba 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java @@ -9,7 +9,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.function.Function; @@ -443,17 +442,6 @@ private static DescribeModelResponse createModelResponse( public static RestJob addRESTInferenceJob( ChannelHandlerContext ctx, String modelName, String version, RequestInput input) throws ModelNotFoundException, ModelVersionNotFoundException { - String sequenceStart; - if ((sequenceStart = - input.getHeaders() - .get(ConfigManager.getInstance().getTsHeaderKeySequenceStart())) - != null) { - if (Boolean.parseBoolean(sequenceStart.toLowerCase())) { - String sequenceId = String.format("ts-%s", UUID.randomUUID()); - input.updateHeaders( - ConfigManager.getInstance().getTsHeaderKeySequenceId(), sequenceId); - } - } RestJob job = new RestJob(ctx, modelName, version, WorkerCommands.PREDICT, input); if (!ModelManager.getInstance().addJob(job)) { String responseMessage = getStreamingInferenceErrorResponseMessage(modelName, version); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java index 1f89f4a48a..dbd987b700 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java @@ -76,18 +76,17 @@ private void encodeRequest(RequestInput req, ByteBuf out) { out.writeInt(buf.length); out.writeBytes(buf); - if (req.isCachedInBackend()) { - out.writeInt(-1); // End of List - out.writeInt(-1); // End of List - return; - } - for (Map.Entry entry : req.getHeaders().entrySet()) { encodeField(entry.getKey(), out); encodeField(entry.getValue(), out); } out.writeInt(-1); // End of List + if (req.isCachedInBackend()) { + out.writeInt(-1); // End of List + return; + } + for (InputParameter input : req.getParameters()) { encodeParameter(input, out); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java index f1f7a4dd77..0c4c1b6f92 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java @@ -21,6 +21,7 @@ public RequestInput(String requestId) { headers = new HashMap<>(); parameters = new ArrayList<>(); clientExpireTS = Long.MAX_VALUE; // default(never expire): Long.MAX_VALUE + sequenceId = ""; } public String getRequestId() { @@ -41,6 +42,9 @@ public void setHeaders(Map headers) { public void updateHeaders(String key, String val) { headers.put(key, val); + if (ConfigManager.getInstance().getTsHeaderKeySequenceId().equals(key)) { + setSequenceId(val); + } } public List getParameters() { @@ -75,10 +79,10 @@ public void setClientExpireTS(long clientTimeoutInMills) { } public String getSequenceId() { - if (sequenceId == null) { + if (sequenceId.isEmpty()) { sequenceId = headers.getOrDefault( - ConfigManager.getInstance().getTsHeaderKeySequenceId(), 
null); + ConfigManager.getInstance().getTsHeaderKeySequenceId(), ""); } return sequenceId; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java index e38e51dcba..a850386a0c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java @@ -190,4 +190,8 @@ public void pollBatch(String threadName, WorkerState state) public void shutdown() { return; } + + public void startEventDispatcher() { + return; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java index a925050b97..51550b0052 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java @@ -19,6 +19,7 @@ public ContinuousBatching(Model model) { super(model); } + @Override public BaseModelRequest getRequest(String threadName, WorkerState state) throws InterruptedException, ExecutionException { int batchQuota = model.getBatchSize() - jobs.size(); @@ -60,6 +61,7 @@ public BaseModelRequest getRequest(String threadName, WorkerState state) * @return - true: either a non-stream response or last stream response is sent - false: a * stream response (not include the last stream) is sent */ + @Override public boolean sendResponse(ModelWorkerResponse message) { // TODO: Handle prediction level code if (message.getCode() == 200) { @@ -98,7 +100,7 @@ public boolean sendResponse(ModelWorkerResponse message) { prediction .getHeaders() .get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT); - if (streamNext != null && streamNext.equals("false")) { + if (streamNext == null || (streamNext != null && streamNext.equals("false"))) { jobs.remove(jobId); } else if (!job.isOpen()) { jobs.remove(job.getJobId()); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 96ee3c8ae3..8a52a3f7d4 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -382,7 +382,8 @@ public void pollInferJob( } long begin = System.currentTimeMillis(); - for (int i = 0; i < batchSize - 1; ++i) { + batchSize = pollNoWait ? 
batchSize : batchSize - 1;
+        for (int i = 0; i < batchSize; ++i) {
            if (pollNoWait) {
                j = jobsQueue.poll();
            } else {
diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java
similarity index 88%
rename from frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java
rename to frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java
index df6b3d1c4a..8bd69378ab 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java
@@ -11,34 +11,33 @@
import java.util.concurrent.atomic.AtomicBoolean;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.job.JobGroup;
-import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

-public class SequenceBatchAggregator extends BatchAggregator {
+public class SequenceBatching extends BatchAggregator {

-    private static final Logger logger = LoggerFactory.getLogger(SequenceBatchAggregator.class);
+    private static final Logger logger = LoggerFactory.getLogger(SequenceBatching.class);

    private ExecutorService pollExecutors;
    /**
-     * eventJobGroupIds is an queue in EventDispatcher. It's item has 2 cases. - empty string:
-     * trigger EventDispatcher to fetch new job groups. - job group id: trigger EventDispatcher to
+     * eventJobGroupIds is a queue in EventDispatcher. Its items have 2 cases. 1) empty string:
+     * trigger EventDispatcher to fetch new job groups. 2) job group id: trigger EventDispatcher to
     * fetch a new job from this jobGroup.
     */
-    private LinkedBlockingDeque eventJobGroupIds;
+    protected LinkedBlockingDeque eventJobGroupIds;

    // A queue holds jobs ready for this aggregator to add into a batch. Each job of this queue is
-    // from distinct jobGroup.
-    private LinkedBlockingDeque jobsQueue;
+    // from a distinct jobGroup.
+    protected LinkedBlockingDeque jobsQueue;

    private Thread eventDispatcher;
    private AtomicBoolean isPollJobGroup;

    // A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added
    // back to eventJobGroupIds once their jobs are processed by a batch. 
- private LinkedList currentJobGroupIds; + protected LinkedList currentJobGroupIds; private int localCapacity; private AtomicBoolean running = new AtomicBoolean(true); - public SequenceBatchAggregator(Model model) { + public SequenceBatching(Model model) { super(model); this.currentJobGroupIds = new LinkedList<>(); this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1); @@ -51,6 +50,7 @@ public SequenceBatchAggregator(Model model) { this.eventDispatcher.start(); } + @Override public void startEventDispatcher() { this.eventDispatcher.start(); } @@ -83,7 +83,7 @@ private void pollJobGroup() throws InterruptedException { isPollJobGroup.set(false); } - private void pollInferJob() throws InterruptedException { + protected void pollInferJob() throws InterruptedException { model.pollInferJob(jobs, model.getBatchSize(), jobsQueue); for (Job job : jobs.values()) { @@ -220,21 +220,16 @@ public void run() { private void pollJobFromJobGroup(String jobGroupId) { // Poll a job from a jobGroup JobGroup jobGroup = model.getJobGroup(jobGroupId); - Job job = jobGroup.pollJob(model.getSequenceMaxIdleMSec()); + Job job = null; + if (!jobGroup.isFinished()) { + job = jobGroup.pollJob(model.getSequenceMaxIdleMSec()); + } if (job == null) { // JobGroup expired, clean it. cleanJobGroup(jobGroupId); // intent to add new job groups. eventJobGroupIds.add(""); } else { - if (Boolean.parseBoolean( - job.getPayload() - .getHeaders() - .getOrDefault( - ConfigManager.getInstance().getTsHeaderKeySequenceEnd(), - "false"))) { - jobGroup.setFinished(true); - } jobsQueue.add(job); } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java new file mode 100644 index 0000000000..4ec5e4747c --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java @@ -0,0 +1,179 @@ +package org.pytorch.serve.wlm; + +import java.util.Map; +import java.util.concurrent.ExecutionException; +import org.pytorch.serve.job.Job; +import org.pytorch.serve.job.JobGroup; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.messages.BaseModelRequest; +import org.pytorch.serve.util.messages.ModelInferenceRequest; +import org.pytorch.serve.util.messages.ModelLoadModelRequest; +import org.pytorch.serve.util.messages.ModelWorkerResponse; +import org.pytorch.serve.util.messages.Predictions; +import org.pytorch.serve.util.messages.RequestInput; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SequenceContinuousBatching extends SequenceBatching { + private static final Logger logger = LoggerFactory.getLogger(SequenceContinuousBatching.class); + + public SequenceContinuousBatching(Model model) { + super(model); + } + + @Override + public BaseModelRequest getRequest(String threadName, WorkerState state) + throws InterruptedException, ExecutionException { + + ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName()); + + pollBatch(threadName, state); + + if (model.isUseJobTicket() && jobs.isEmpty()) { + model.decNumJobTickets(); + return req; + } + + for (Job j : jobs.values()) { + if (j.isControlCmd()) { + if (jobs.size() > 1) { + throw new IllegalStateException( + "Received more than 1 control command. 
" + + "Control messages should be processed/retrieved one at a time."); + } + RequestInput input = j.getPayload(); + int gpuId = -1; + String gpu = input.getStringParameter("gpu"); + if (gpu != null) { + gpuId = Integer.parseInt(gpu); + } + return new ModelLoadModelRequest(model, gpuId); + } else { + req.setCommand(j.getCmd()); + j.setScheduled(); + req.addRequest(j.getPayload()); + } + } + return req; + } + + /** + * @param message: a response of a batch inference requests + * @return - true: either a non-stream response or last stream response is sent - false: a + * stream response (not include the last stream) is sent This is a copy of sendResponse from + * ContinuousBatching + 1. setJobGroupFinished: handle a list of jobGroups end. 2. + * resetCurrentJobGroupIds + */ + @Override + public boolean sendResponse(ModelWorkerResponse message) { + // TODO: Handle prediction level code + if (message.getCode() == 200) { + if (message.getPredictions().isEmpty()) { + // The jobs size is always 1 in the case control command + for (Map.Entry j : jobs.entrySet()) { + Job job = j.getValue(); + if (job.isControlCmd()) { + cleanJobs(); + return true; + } + } + } + for (Predictions prediction : message.getPredictions()) { + String jobId = prediction.getRequestId(); + Job job = jobs.get(jobId); + + if (job == null) { + throw new IllegalStateException( + "Unexpected job in sendResponse() with 200 status code: " + jobId); + } + + if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) { + job.response( + prediction.getResp(), + prediction.getContentType(), + prediction.getStatusCode(), + prediction.getReasonPhrase(), + prediction.getHeaders()); + } else { + logger.warn( + "Drop response for inference request {} due to client timeout", + job.getPayload().getRequestId()); + } + String streamNext = + prediction + .getHeaders() + .get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT); + if (streamNext == null || (streamNext != null && streamNext.equals("false"))) { + jobs.remove(jobId); + } else if (!job.isOpen()) { + jobs.remove(job.getJobId()); + logger.info( + "Connection to client got closed; Removing job: {}", + job.getPayload().getRequestId()); + } else { + job.getPayload().setCachedInBackend(true); + } + setJobGroupFinished(prediction); + } + } else { + for (Map.Entry j : jobs.entrySet()) { + if (j.getValue() == null) { + throw new IllegalStateException( + "Unexpected job in sendResponse() with non 200 status code: " + + j.getKey()); + } + Job job = j.getValue(); + if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) { + job.sendError(message.getCode(), message.getMessage()); + } else { + logger.warn( + "Drop error response for inference request {} due to client timeout", + job.getPayload().getRequestId()); + } + } + cleanJobs(); + } + + resetCurrentJobGroupIds(); + + return true; + } + + private void setJobGroupFinished(Predictions prediction) { + String val = + prediction + .getHeaders() + .getOrDefault( + ConfigManager.getInstance().getTsHeaderKeySequenceEnd(), null); + if (val != null) { + String[] jobGroupIds = val.split(";"); + for (String j : jobGroupIds) { + String jobGroupId = j.trim(); + JobGroup jobGroup = model.getJobGroup(jobGroupId); + if (jobGroup != null) { + jobGroup.setFinished(true); + } + } + } + } + + @Override + protected void pollInferJob() throws InterruptedException { + // TBD: Temporarily hard code the continuous batch size is 2 * batchSize + model.pollInferJob(jobs, model.getBatchSize() * 2 - jobs.size(), jobsQueue); + + for (Job job 
: jobs.values()) { + if (job.getGroupId() != null) { + currentJobGroupIds.add(job.getGroupId()); + } + } + } + + private void resetCurrentJobGroupIds() { + if (!currentJobGroupIds.isEmpty()) { + eventJobGroupIds.addAll(currentJobGroupIds); + currentJobGroupIds.clear(); + } + return; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 604007c3a5..2c2a3b6e12 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -229,8 +229,10 @@ private void addThreads( BatchAggregator aggregator; - if (model.isSequenceBatching()) { - aggregator = new SequenceBatchAggregator(model); + if (model.isSequenceBatching() && model.isContinuousBatching()) { + aggregator = new SequenceContinuousBatching(model); + } else if (model.isSequenceBatching()) { + aggregator = new SequenceBatching(model); } else if (model.isContinuousBatching()) { aggregator = new ContinuousBatching(model); } else { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index bcd6a646b4..dd5b3f1dc9 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -541,9 +541,8 @@ public void retry() { if (backoffIdx < BACK_OFF.length - 1) { ++backoffIdx; } - if (aggregator instanceof SequenceBatchAggregator) { - ((SequenceBatchAggregator) aggregator).startEventDispatcher(); - } + aggregator.startEventDispatcher(); + manager.getScheduler() .schedule(() -> manager.submitTask(this), BACK_OFF[backoffIdx], TimeUnit.SECONDS); logger.info("Retry worker: {} in {} seconds.", workerId, BACK_OFF[backoffIdx]); diff --git a/test/postman/inference_stream2_data.json b/test/postman/inference_stream2_data.json index 31a72664cf..cf5b34ac2b 100644 --- a/test/postman/inference_stream2_data.json +++ b/test/postman/inference_stream2_data.json @@ -3,9 +3,9 @@ "model_name":"stateful", "model_file": "../../examples/stateful/model.py", "serialized_file": "../../examples/stateful/model_cnn.pt", - "handler": "../../examples/stateful/stateful_handler.py", + "handler": "../../examples/stateful/sequence_batching/stateful_handler.py", "requirements_file": "../../examples/stateful/requirements.txt", - "config_file": "../../examples/stateful/model-config.yaml", + "config_file": "../../examples/stateful/sequence_batching/model-config.yaml", "archive_format": "no-archive", "worker":2, "synchronous":"true", diff --git a/test/pytest/test_example_stateful_http.py b/test/pytest/test_example_stateful_sequence_batching_http.py similarity index 90% rename from test/pytest/test_example_stateful_http.py rename to test/pytest/test_example_stateful_sequence_batching_http.py index afbe0f05fe..65b996b2f6 100644 --- a/test/pytest/test_example_stateful_http.py +++ b/test/pytest/test_example_stateful_sequence_batching_http.py @@ -10,6 +10,9 @@ CURR_FILE_PATH = Path(__file__).parent STATEFUL_PATH = CURR_FILE_PATH.parents[1] / "examples" / "stateful" +STATEFUL_SEQUENCE_PATH = ( + CURR_FILE_PATH.parents[1] / "examples" / "stateful" / "sequence_batching" +) CONFIG_PROPERTIES_PATH = CURR_FILE_PATH.parents[1] / "test" / "config_ts.properties" YAML_CONFIG = f""" @@ -27,22 +30,10 @@ capacity: 4 """ -PROMPTS = [ - { - "prompt": "A robot may not injure a human being", - 
"max_new_tokens": 50, - "temperature": 0.8, - "logprobs": 1, - "prompt_logprobs": 1, - "max_tokens": 128, - "adapter": "adapter_1", - }, -] - @pytest.fixture def add_paths(): - sys.path.append(STATEFUL_PATH.as_posix()) + sys.path.append(STATEFUL_SEQUENCE_PATH.as_posix()) yield sys.path.pop() @@ -67,7 +58,7 @@ def create_mar_file(work_dir, model_archiver, model_name, request): config = ModelArchiverConfig( model_name=model_name, version="1.0", - handler=(STATEFUL_PATH / "stateful_handler.py").as_posix(), + handler=(STATEFUL_SEQUENCE_PATH / "stateful_handler.py").as_posix(), serialized_file=(STATEFUL_PATH / "model_cnn.pt").as_posix(), model_file=(STATEFUL_PATH / "model.py").as_posix(), export_path=work_dir, @@ -140,6 +131,7 @@ def test_stateful_mar(mar_file_path, model_store): # Clean up files shutil.rmtree(Path(model_store) / model_name) + test_utils.stop_torchserve() def __infer_stateful(model_name, sequence_id, expected): diff --git a/test/pytest/test_example_stateful_sequence_continuous_batching_http.py b/test/pytest/test_example_stateful_sequence_continuous_batching_http.py new file mode 100644 index 0000000000..2e95735ccf --- /dev/null +++ b/test/pytest/test_example_stateful_sequence_continuous_batching_http.py @@ -0,0 +1,367 @@ +import shutil +import sys +import threading +import time +from pathlib import Path + +import pytest +import requests +import test_utils +from model_archiver.model_archiver_config import ModelArchiverConfig + +CURR_FILE_PATH = Path(__file__).parent +STATEFUL_PATH = CURR_FILE_PATH.parents[1] / "examples" / "stateful" +STATEFUL_SEQUENCE_CONTINUOUS_PATH = ( + CURR_FILE_PATH.parents[1] / "examples" / "stateful" / "sequence_continuous_batching" +) +CONFIG_PROPERTIES_PATH = CURR_FILE_PATH.parents[1] / "test" / "config_ts.properties" + +YAML_CONFIG = f""" +# TorchServe frontend parameters +minWorkers: 2 +maxWorkers: 2 +batchSize: 1 +maxNumSequence: 2 +sequenceMaxIdleMSec: 5000 +maxSequenceJobQueueSize: 10 +sequenceBatching: true +continuousBatching: true + +handler: + cache: + capacity: 4 +""" + +JSON_INPUT = { + "input": 3, +} + + +@pytest.fixture +def add_paths(): + sys.path.append(STATEFUL_SEQUENCE_CONTINUOUS_PATH.as_posix()) + yield + sys.path.pop() + + +@pytest.fixture(scope="module") +def model_name(): + yield "stateful" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return tmp_path_factory.mktemp(model_name) + + +@pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name, request): + mar_file_path = Path(work_dir).joinpath(model_name) + + model_config_yaml = Path(work_dir) / "model-config.yaml" + model_config_yaml.write_text(YAML_CONFIG) + + config = ModelArchiverConfig( + model_name=model_name, + version="1.0", + handler=(STATEFUL_SEQUENCE_CONTINUOUS_PATH / "stateful_handler.py").as_posix(), + serialized_file=(STATEFUL_PATH / "model_cnn.pt").as_posix(), + model_file=(STATEFUL_PATH / "model.py").as_posix(), + export_path=work_dir, + requirements_file=(STATEFUL_PATH / "requirements.txt").as_posix(), + runtime="python", + force=False, + config_file=model_config_yaml.as_posix(), + archive_format="no-archive", + ) + + model_archiver.generate_model_archive(config) + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + shutil.rmtree(mar_file_path) + + +def test_infer_stateful(mar_file_path, model_store): + """ + Register the model in torchserve + """ + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + 
shutil.copytree(mar_file_path, Path(model_store) / model_name) + + params = ( + ("model_name", model_name), + ("url", Path(model_store) / model_name), + ("initial_workers", "2"), + ("synchronous", "true"), + ) + + test_utils.start_torchserve( + model_store=model_store, snapshot_file=CONFIG_PROPERTIES_PATH, gen_mar=False + ) + + try: + test_utils.reg_resp = test_utils.register_model_with_params(params) + + t0 = threading.Thread( + target=__infer_stateful, + args=( + model_name, + "seq_0", + "2 6 12 20 30", + ), + ) + t1 = threading.Thread( + target=__infer_stateful, + args=( + model_name, + "seq_1", + "4 12 24 40 60", + ), + ) + + t0.start() + t1.start() + + t0.join() + t1.join() + finally: + test_utils.unregister_model(model_name) + + # Clean up files + shutil.rmtree(Path(model_store) / model_name) + test_utils.stop_torchserve() + + +def test_infer_stateful_end(mar_file_path, model_store): + """ + Register the model in torchserve + """ + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + shutil.copytree(mar_file_path, Path(model_store) / model_name) + + params = ( + ("model_name", model_name), + ("url", Path(model_store) / model_name), + ("initial_workers", "2"), + ("synchronous", "true"), + ) + + test_utils.start_torchserve( + model_store=model_store, snapshot_file=CONFIG_PROPERTIES_PATH, gen_mar=False + ) + + try: + test_utils.reg_resp = test_utils.register_model_with_params(params) + + t0 = threading.Thread( + target=__infer_stateful_end, + args=( + model_name, + "seq_0", + "2 6 12 20 20", + ), + ) + t1 = threading.Thread( + target=__infer_stateful, + args=( + model_name, + "seq_1", + "4 12 24 40 60", + ), + ) + + t0.start() + t1.start() + + t0.join() + t1.join() + finally: + test_utils.unregister_model(model_name) + + # Clean up files + shutil.rmtree(Path(model_store) / model_name) + test_utils.stop_torchserve() + + +def test_infer_stateful_cancel(mar_file_path, model_store): + """ + Register the model in torchserve + """ + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + shutil.copytree(mar_file_path, Path(model_store) / model_name) + + params = ( + ("model_name", model_name), + ("url", Path(model_store) / model_name), + ("initial_workers", "2"), + ("synchronous", "true"), + ) + + test_utils.start_torchserve( + model_store=model_store, snapshot_file=CONFIG_PROPERTIES_PATH, gen_mar=False + ) + + try: + test_utils.reg_resp = test_utils.register_model_with_params(params) + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + data=str(2).encode(), + ) as response: + s_id = response.headers.get("ts_request_sequence_id") + headers = { + "ts_request_sequence_id": s_id, + } + + t0 = threading.Thread( + target=__infer_stateful_cancel, + args=( + model_name, + False, + headers, + "5", + ), + ) + t1 = threading.Thread( + target=__infer_stateful_cancel, + args=( + model_name, + True, + headers, + "-1", + ), + ) + + t0.start() + t1.start() + + t0.join() + t1.join() + finally: + test_utils.unregister_model(model_name) + + # Clean up files + shutil.rmtree(Path(model_store) / model_name) + test_utils.stop_torchserve() + + +def __infer_stateful(model_name, sequence_id, expected): + start = True + prediction = [] + for idx in range(5): + if sequence_id == "seq_0": + idx = 2 * (idx + 1) + elif sequence_id == "seq_1": + idx = 4 * (idx + 1) + if start is True: + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + data=str(idx).encode(), + ) as response: + s_id = 
response.headers.get("ts_request_sequence_id") + if sequence_id == "seq_0": + headers_seq_0 = { + "ts_request_sequence_id": s_id, + } + elif sequence_id == "seq_1": + headers_seq_1 = { + "ts_request_sequence_id": s_id, + } + start = False + prediction.append(response.text) + else: + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + headers=headers_seq_0 if sequence_id == "seq_0" else headers_seq_1, + data=str(idx).encode(), + ) as response: + prediction.append(response.text) + + print(f"infer_stateful prediction={str(' '.join(prediction))}") + assert str(" ".join(prediction)) == expected + + +def __infer_stateful_end(model_name, sequence_id, expected): + prediction = [] + start = True + end = False + for idx in range(5): + if idx == 4: + end = True + if sequence_id == "seq_0": + idx = 2 * (idx + 1) + elif sequence_id == "seq_1": + idx = 4 * (idx + 1) + if end is True: + idx = 0 + + if start is True: + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + data=str(idx).encode(), + ) as response: + s_id = response.headers.get("ts_request_sequence_id") + if sequence_id == "seq_0": + headers_seq_0 = { + "ts_request_sequence_id": s_id, + } + elif sequence_id == "seq_1": + headers_seq_1 = { + "ts_request_sequence_id": s_id, + } + start = False + prediction.append(response.text) + else: + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + headers=headers_seq_0 if sequence_id == "seq_0" else headers_seq_1, + data=str(idx).encode(), + ) as response: + prediction.append(response.text) + + print(f"infer_stateful_end prediction={str(' '.join(prediction))}") + assert str(" ".join(prediction)) == expected + + +def __infer_stateful_cancel(model_name, is_cancel, headers, expected): + prediction = [] + if is_cancel: + time.sleep(1) + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + headers=headers, + data=str(-1).encode(), + ) as response: + prediction.append(response.text) + print(f"infer_stateful_cancel prediction={str(' '.join(prediction))}") + assert str(" ".join(prediction)) == expected + else: + with requests.post( + url=f"http://localhost:8080/predictions/{model_name}", + headers=headers, + json=JSON_INPUT, + stream=True, + ) as response: + assert response.headers["Transfer-Encoding"] == "chunked" + for chunk in response.iter_content(chunk_size=None): + if chunk: + prediction += [chunk.decode("utf-8")] + + print(f"infer_stateful_cancel prediction={str(' '.join(prediction))}") + assert prediction[0] == expected + assert len(prediction) < 11 diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index 95bcec37d9..29de350e15 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -95,10 +95,10 @@ def create_predict_response( if ts_stream_next is True: context.set_response_header(idx, "ts_stream_next", "true") elif context.stopping_criteria: - ts_stream_next = ( - "false" if context.stopping_criteria[idx](ret[idx]) else "true" - ) - context.set_response_header(idx, "ts_stream_next", ts_stream_next) + is_stop = context.stopping_criteria[idx](ret[idx]) + if is_stop is not None: + ts_stream_next = "false" if is_stop else "true" + context.set_response_header(idx, "ts_stream_next", ts_stream_next) elif "true" == context.get_response_headers(idx).get("ts_stream_next"): context.set_response_header(idx, "ts_stream_next", "false")
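A note on the `ts/protocol/otf_message_handler.py` change above: the stopping criteria are now treated as three-valued. `True` marks the final chunk of a streaming response (`ts_stream_next` is set to `"false"`), `False` keeps the stream open (`ts_stream_next` is set to `"true"`), and `None` leaves the header untouched, which is how non-streaming requests opt out. Below is a minimal, self-contained sketch of just that mapping; `FakeContext` and `apply_stream_header` are hypothetical stand-ins for illustration, not TorchServe classes.

```python
class FakeContext:
    """Simplified stand-in for the backend Context (illustration only)."""

    def __init__(self, num_requests):
        self.stopping_criteria = [None] * num_requests
        self.response_headers = [{} for _ in range(num_requests)]

    def set_response_header(self, idx, key, value):
        self.response_headers[idx][key] = value


def apply_stream_header(context, idx, result):
    # Mirrors the patched branch: None means "do not touch ts_stream_next".
    is_stop = context.stopping_criteria[idx](result)
    if is_stop is not None:
        context.set_response_header(
            idx, "ts_stream_next", "false" if is_stop else "true"
        )
    return context.response_headers[idx].get("ts_stream_next")


ctx = FakeContext(3)
ctx.stopping_criteria = [
    lambda r: True,   # last stream chunk  -> ts_stream_next=false
    lambda r: False,  # stream continues   -> ts_stream_next=true
    lambda r: None,   # non-stream request -> header left unset
]
print([apply_stream_header(ctx, i, "chunk") for i in range(3)])
# expected output: ['false', 'true', None]
```

This mirrors how the example handler's `StoppingCriteria.__call__` uses the three values: True/False drive the streaming header, while None is returned on the paths where no streaming header should be emitted.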