
Commit 4a57c67

lxning and mreso authored May 15, 2024
Sequence batch via http (#3142)
* init http
* http support stateful
* fix typo
* merge master
* add test
* add test
* fix test
* fix test
* address comments
* add more test
* check running
* update test
* cross test

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
1 parent ff3fd0a commit 4a57c67

File tree

17 files changed (+352, −56 lines)


‎examples/stateful/Readme.md

+8-1
@@ -82,6 +82,7 @@ maxWorkers: 2
 batchSize: 4
 sequenceMaxIdleMSec: 60000
 maxSequenceJobQueueSize: 10
+sequenceBatching: true

 handler:
   cache:
@@ -122,8 +123,14 @@ cd -
 torchserve --ncs --start --model-store models --model stateful.mar --ts-config config.properties
 ```

-* Run sequence inference
+* Run sequence inference via GRPC client
 ```bash
 cd ../../
 python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
 ```
+
+* Run sequence inference via HTTP
+```bash
+cd ../../
+curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
+```
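
Note: the new curl line sends a single request for sequence `seq_0` over HTTP. A minimal Python sketch of driving several requests in the same sequence, reusing the sample inputs from this example (the header name and endpoint are the defaults introduced by this commit; everything else is illustrative):

```python
# Sketch: requests sharing one sequence id are routed to the same worker's
# job group. Assumes the "stateful" example model is registered and
# TorchServe is listening on localhost:8080.
import requests

headers = {"ts_request_sequence_id": "seq_0"}  # default header key

for sample in [
    "examples/stateful/sample/sample1.txt",
    "examples/stateful/sample/sample2.txt",
    "examples/stateful/sample/sample3.txt",
]:
    with open(sample, "rb") as f:
        resp = requests.post(
            "http://localhost:8080/predictions/stateful",
            headers=headers,
            data=f.read(),
        )
    print(resp.status_code, resp.text)
```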

‎examples/stateful/model-config.yaml

+1
@@ -4,6 +4,7 @@ batchSize: 4
 maxNumSequence: 4
 sequenceMaxIdleMSec: 10
 maxSequenceJobQueueSize: 10
+sequenceBatching: true

 handler:
   cache:

‎frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java

+19-1
@@ -56,7 +56,7 @@ public class ModelConfig {
     private boolean useJobTicket;
     /**
      * the max idle in milliseconds of a sequence inference request of this stateful model. The
-     * default value is 0 (ie. this is not a stateful model.)
+     * default value is 0.
      */
     private long sequenceMaxIdleMSec;
     /**
@@ -73,6 +73,8 @@ public class ModelConfig {
      * loading and inference.
      */
     private boolean useVenv;
+    /** sequenceBatching is a flag to enable https://github.com/pytorch/serve/issues/2743 */
+    private boolean sequenceBatching;

     public static ModelConfig build(Map<String, Object> yamlMap) {
         ModelConfig modelConfig = new ModelConfig();
@@ -212,6 +214,14 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
                                 v);
                     }
                     break;
+                case "sequenceBatching":
+                    if (v instanceof Boolean) {
+                        modelConfig.setSequenceBatching((boolean) v);
+                    } else {
+                        logger.warn(
+                                "Invalid sequenceBatching: {}, should be true or false", v);
+                    }
+                    break;
                 case "useVenv":
                     if (v instanceof Boolean) {
                         modelConfig.setUseVenv((boolean) v);
@@ -383,6 +393,14 @@ public void setContinuousBatching(boolean continuousBatching) {
         this.continuousBatching = continuousBatching;
     }

+    public boolean isSequenceBatching() {
+        return sequenceBatching;
+    }
+
+    public void setSequenceBatching(boolean sequenceBatching) {
+        this.sequenceBatching = sequenceBatching;
+    }
+
     public int getMaxNumSequence() {
         return maxNumSequence;
     }

‎frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java

+28-4
@@ -39,6 +39,7 @@

 public class InferenceImpl extends InferenceAPIsServiceImplBase {
     private static final Logger logger = LoggerFactory.getLogger(InferenceImpl.class);
+    private static final ByteString strFalse = ByteString.copyFromUtf8("false");

     @Override
     public void ping(Empty request, StreamObserver<TorchServeHealthResponse> responseObserver) {
@@ -102,9 +103,16 @@ public StreamObserver<PredictionsRequest> streamPredictions2(

             @Override
             public void onNext(PredictionsRequest value) {
-                String sequenceId = value.getSequenceId();
-
-                if ("".equals(sequenceId)) {
+                boolean not_has_seq_id = "".equals(value.getSequenceId());
+                boolean has_seq_in_header =
+                        !Boolean.parseBoolean(
+                                value.getInputOrDefault(
+                                                ConfigManager.getInstance()
+                                                        .getTsHeaderKeySequenceStart(),
+                                                strFalse)
+                                        .toString()
+                                        .toLowerCase());
+                if (not_has_seq_id && has_seq_in_header) {
                     BadRequestException e =
                             new BadRequestException("Parameter sequenceId is required.");
                     sendErrorResponse(
@@ -219,7 +227,23 @@ private void prediction(
                     new InputParameter(entry.getKey(), entry.getValue().toByteArray()));
         }
         if (workerCmd == WorkerCommands.STREAMPREDICT2) {
-            inputData.setSequenceId(request.getSequenceId());
+            String sequenceId = request.getSequenceId();
+            if ("".equals(sequenceId)) {
+                sequenceId = String.format("ts-%s", UUID.randomUUID());
+                inputData.updateHeaders(
+                        ConfigManager.getInstance().getTsHeaderKeySequenceStart(), "true");
+            }
+            inputData.updateHeaders(
+                    ConfigManager.getInstance().getTsHeaderKeySequenceId(), sequenceId);
+            if (!Boolean.parseBoolean(
+                    request.getInputOrDefault(
+                                    ConfigManager.getInstance().getTsHeaderKeySequenceEnd(),
+                                    strFalse)
+                            .toString()
+                            .toLowerCase())) {
+                inputData.updateHeaders(
+                        ConfigManager.getInstance().getTsHeaderKeySequenceEnd(), "true");
+            }
         }

         IMetric inferenceRequestsTotalMetric =

‎frontend/server/src/main/java/org/pytorch/serve/job/Job.java

-12
@@ -20,18 +20,6 @@ public Job(String modelName, String version, WorkerCommands cmd, RequestInput in
         this.modelVersion = version;
         begin = System.nanoTime();
         scheduled = begin;
-
-        switch (cmd) {
-            case STREAMPREDICT:
-                input.updateHeaders(RequestInput.TS_STREAM_NEXT, "true");
-                break;
-            case STREAMPREDICT2:
-                input.updateHeaders(RequestInput.TS_STREAM_NEXT, "true");
-                input.updateHeaders(RequestInput.TS_REQUEST_SEQUENCE_ID, input.getSequenceId());
-                break;
-            default:
-                break;
-        }
     }

     public String getJobId() {

‎frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java

+13
@@ -10,18 +10,23 @@ public class JobGroup {
     String groupId;
     LinkedBlockingDeque<Job> jobs;
     int maxJobQueueSize;
+    boolean finished;

     public JobGroup(String groupId, int maxJobQueueSize) {
         this.groupId = groupId;
         this.maxJobQueueSize = maxJobQueueSize;
         this.jobs = new LinkedBlockingDeque<>(maxJobQueueSize);
+        this.finished = false;
     }

     public boolean appendJob(Job job) {
         return jobs.offer(job);
     }

     public Job pollJob(long timeout) {
+        if (finished) {
+            return null;
+        }
         try {
             return jobs.poll(timeout, TimeUnit.MILLISECONDS);
         } catch (InterruptedException e) {
@@ -33,4 +38,12 @@ public Job pollJob(long timeout) {
     public String getGroupId() {
         return groupId;
     }
+
+    public void setFinished(boolean sequenceEnd) {
+        this.finished = sequenceEnd;
+    }
+
+    public boolean isFinished() {
+        return this.finished;
+    }
 }

‎frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java

+13-1
@@ -9,6 +9,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.function.Function;
@@ -412,7 +413,7 @@ private static DescribeModelResponse createModelResponse(
         resp.setContinuousBatching(model.isContinuousBatching());
         resp.setUseJobTicket(model.isUseJobTicket());
         resp.setUseVenv(model.isUseVenv());
-        resp.setStateful(model.isStateful());
+        resp.setStateful(model.isSequenceBatching());
         resp.setSequenceMaxIdleMSec(model.getSequenceMaxIdleMSec());
         resp.setMaxNumSequence(model.getMaxNumSequence());
         resp.setMaxSequenceJobQueueSize(model.getMaxSequenceJobQueueSize());
@@ -442,6 +443,17 @@ private static DescribeModelResponse createModelResponse(
     public static RestJob addRESTInferenceJob(
             ChannelHandlerContext ctx, String modelName, String version, RequestInput input)
             throws ModelNotFoundException, ModelVersionNotFoundException {
+        String sequenceStart;
+        if ((sequenceStart =
+                        input.getHeaders()
+                                .get(ConfigManager.getInstance().getTsHeaderKeySequenceStart()))
+                != null) {
+            if (Boolean.parseBoolean(sequenceStart.toLowerCase())) {
+                String sequenceId = String.format("ts-%s", UUID.randomUUID());
+                input.updateHeaders(
+                        ConfigManager.getInstance().getTsHeaderKeySequenceId(), sequenceId);
+            }
+        }
         RestJob job = new RestJob(ctx, modelName, version, WorkerCommands.PREDICT, input);
         if (!ModelManager.getInstance().addJob(job)) {
             String responseMessage = getStreamingInferenceErrorResponseMessage(modelName, version);
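
Note: with this change the describe-model response reports statefulness from `isSequenceBatching()`, and a REST request whose sequence-start header is "true" gets a server-generated `ts-<uuid>` sequence id. A hedged sketch for checking the reported settings through the management API (port 8081 is the TorchServe default; the JSON field names are assumed from the `DescribeModelResponse` setters shown above):

```python
# Sketch: inspect the describe-model output for the sequence-batching settings.
# Field names ("stateful", "maxNumSequence", ...) are assumed from the setters
# in DescribeModelResponse; verify them against your TorchServe version.
import requests

resp = requests.get("http://localhost:8081/models/stateful")
model_info = resp.json()[0]  # describe returns a JSON list of model versions

for key in ("stateful", "sequenceMaxIdleMSec", "maxNumSequence", "maxSequenceJobQueueSize"):
    print(key, model_info.get(key))
```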

‎frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+41
@@ -120,6 +120,9 @@ public final class ConfigManager {
     private static final String TS_CPP_LOG_CONFIG = "cpp_log_config";
     private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol";
     private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min";
+    private static final String TS_HEADER_KEY_SEQUENCE_ID = "ts_header_key_sequence_id";
+    private static final String TS_HEADER_KEY_SEQUENCE_START = "ts_header_key_sequence_start";
+    private static final String TS_HEADER_KEY_SEQUENCE_END = "ts_header_key_sequence_end";

     // Configuration which are not documented or enabled through environment variables
     private static final String USE_NATIVE_IO = "use_native_io";
@@ -145,6 +148,10 @@ public final class ConfigManager {

     public static final String PYTHON_EXECUTABLE = "python";

+    public static final String DEFAULT_REQUEST_SEQUENCE_ID = "ts_request_sequence_id";
+    public static final String DEFAULT_REQUEST_SEQUENCE_START = "ts_request_sequence_start";
+    public static final String DEFAULT_REQUEST_SEQUENCE_END = "ts_request_sequence_end";
+
     public static final Pattern ADDRESS_PATTERN =
             Pattern.compile(
                     "((https|http)://([^:^/]+)(:([0-9]+))?)|(unix:(/.*))",
@@ -161,6 +168,10 @@ public final class ConfigManager {
     private Map<String, Map<String, JsonObject>> modelConfig = new HashMap<>();
     private String torchrunLogDir;
     private boolean telemetryEnabled;
+    private String headerKeySequenceId;
+    private String headerKeySequenceStart;
+    private String headerKeySequenceEnd;
+
     private Logger logger = LoggerFactory.getLogger(ConfigManager.class);

     private ConfigManager(Arguments args) throws IOException {
@@ -272,6 +283,9 @@ private ConfigManager(Arguments args) throws IOException {
         }

         setModelConfig();
+        setTsHeaderKeySequenceId();
+        setTsHeaderKeySequenceStart();
+        setTsHeaderKeySequenceEnd();

         // Issue warnining about URLs that can be accessed when loading models
         if (prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS) == DEFAULT_TS_ALLOWED_URLS) {
@@ -960,6 +974,33 @@ public Double getTimeToExpiration() {
         return 0.0;
     }

+    public String getTsHeaderKeySequenceId() {
+        return this.headerKeySequenceId;
+    }
+
+    public void setTsHeaderKeySequenceId() {
+        this.headerKeySequenceId =
+                prop.getProperty(TS_HEADER_KEY_SEQUENCE_ID, DEFAULT_REQUEST_SEQUENCE_ID);
+    }
+
+    public String getTsHeaderKeySequenceStart() {
+        return this.headerKeySequenceStart;
+    }
+
+    public void setTsHeaderKeySequenceStart() {
+        this.headerKeySequenceStart =
+                prop.getProperty(TS_HEADER_KEY_SEQUENCE_START, DEFAULT_REQUEST_SEQUENCE_START);
+    }
+
+    public String getTsHeaderKeySequenceEnd() {
+        return this.headerKeySequenceEnd;
+    }
+
+    public void setTsHeaderKeySequenceEnd() {
+        this.headerKeySequenceEnd =
+                prop.getProperty(TS_HEADER_KEY_SEQUENCE_END, DEFAULT_REQUEST_SEQUENCE_END);
+    }
+
     public boolean isSSLEnabled(ConnectorType connectorType) {
         String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
         switch (connectorType) {
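
Note: the new `ts_header_key_sequence_*` properties make the HTTP header names configurable, falling back to the `ts_request_sequence_*` defaults when unset. A sketch of overriding them in `config.properties` and sending the renamed headers from a client; the `x-seq-*` names are made up for illustration:

```python
# Assumed config.properties overrides (illustrative names, not the defaults):
#   ts_header_key_sequence_id=x-seq-id
#   ts_header_key_sequence_start=x-seq-start
#   ts_header_key_sequence_end=x-seq-end
import requests

# With the overrides above, clients would use the renamed headers instead of
# ts_request_sequence_id / _start / _end.
resp = requests.post(
    "http://localhost:8080/predictions/stateful",
    headers={"x-seq-id": "seq_0"},
    data=b"1",
)
print(resp.text)
```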

‎frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java

+6-1
@@ -5,10 +5,10 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.pytorch.serve.util.ConfigManager;

 public class RequestInput {
     public static final String TS_STREAM_NEXT = "ts_stream_next";
-    public static final String TS_REQUEST_SEQUENCE_ID = "ts_request_sequence_id";
     private String requestId;
     private String sequenceId;
     private Map<String, String> headers;
@@ -75,6 +75,11 @@ public void setClientExpireTS(long clientTimeoutInMills) {
     }

     public String getSequenceId() {
+        if (sequenceId == null) {
+            sequenceId =
+                    headers.getOrDefault(
+                            ConfigManager.getInstance().getTsHeaderKeySequenceId(), null);
+        }
         return sequenceId;
     }

‎frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java

+5-5
@@ -10,7 +10,6 @@
 import org.pytorch.serve.util.messages.ModelWorkerResponse;
 import org.pytorch.serve.util.messages.Predictions;
 import org.pytorch.serve.util.messages.RequestInput;
-import org.pytorch.serve.util.messages.WorkerCommands;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -56,10 +55,7 @@ public BaseModelRequest getRequest(String threadName, WorkerState state)
             }
             return new ModelLoadModelRequest(model, gpuId);
         } else {
-            if (j.getCmd() == WorkerCommands.STREAMPREDICT
-                    || j.getCmd() == WorkerCommands.STREAMPREDICT2) {
-                req.setCommand(j.getCmd());
-            }
+            req.setCommand(j.getCmd());
             j.setScheduled();
             req.addRequest(j.getPayload());
         }
@@ -190,4 +186,8 @@ public void pollBatch(String threadName, WorkerState state)
         model.pollBatch(
                 threadName, (state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE, jobs);
     }
+
+    public void shutdown() {
+        return;
+    }
 }

‎frontend/server/src/main/java/org/pytorch/serve/wlm/ContinuousBatching.java

+1-4
@@ -9,7 +9,6 @@
 import org.pytorch.serve.util.messages.ModelWorkerResponse;
 import org.pytorch.serve.util.messages.Predictions;
 import org.pytorch.serve.util.messages.RequestInput;
-import org.pytorch.serve.util.messages.WorkerCommands;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -48,9 +47,7 @@ public BaseModelRequest getRequest(String threadName, WorkerState state)
             }
             return new ModelLoadModelRequest(model, gpuId);
         } else {
-            if (j.getCmd() == WorkerCommands.STREAMPREDICT) {
-                req.setCommand(WorkerCommands.STREAMPREDICT);
-            }
+            req.setCommand(j.getCmd());
             j.setScheduled();
             req.addRequest(j.getPayload());
         }

‎frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java

+16-16
@@ -83,12 +83,14 @@ public class Model {
     private boolean useJobTicket;
     private AtomicInteger numJobTickets;
     private boolean continuousBatching;
+    private boolean sequenceBatch;
     private boolean useVenv;

     public Model(ModelArchive modelArchive, int queueSize) {
         this.modelArchive = modelArchive;
         if (modelArchive != null && modelArchive.getModelConfig() != null) {
             continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
+            sequenceBatch = modelArchive.getModelConfig().isSequenceBatching();
             useVenv = modelArchive.getModelConfig().getUseVenv();
             if (modelArchive.getModelConfig().getParallelLevel() > 0
                     && modelArchive.getModelConfig().getParallelType()
@@ -131,10 +133,11 @@ public Model(ModelArchive modelArchive, int queueSize) {
                         Math.max(
                                 modelArchive.getModelConfig().getMaxNumSequence(),
                                 batchSize * maxWorkers);
-                jobGroups = new ConcurrentHashMap<>(maxNumSequence);
-                pendingJobGroups = new LinkedBlockingDeque<>(maxNumSequence);
-                jobGroupLock = new ReentrantLock();
-                stateful = true;
+                if (sequenceBatch) {
+                    jobGroups = new ConcurrentHashMap<>(maxNumSequence);
+                    pendingJobGroups = new LinkedBlockingDeque<>(maxNumSequence);
+                    jobGroupLock = new ReentrantLock();
+                }
             }
         } else {
             batchSize = 1;
@@ -288,7 +291,7 @@ public boolean addJob(Job job) {
             logger.info("There are no job tickets available");
             return false;
         }
-        if (job.getGroupId() != null) {
+        if (sequenceBatch && job.getGroupId() != null) {
             return addJobInGroup(job);
         }
         return jobsDb.get(DEFAULT_DATA_QUEUE).offer(job);
@@ -460,9 +463,8 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
         logger.trace("get first job: {}", Objects.requireNonNull(j).getJobId());

         jobsRepo.put(j.getJobId(), j);
-        // batch size always is 1 for describe request job and stream prediction request job
-        if (j.getCmd() == WorkerCommands.DESCRIBE
-                || j.getCmd() == WorkerCommands.STREAMPREDICT) {
+        // batch size always is 1 for describe request job
+        if (j.getCmd() == WorkerCommands.DESCRIBE) {
             return;
         }
         long begin = System.currentTimeMillis();
@@ -472,10 +474,8 @@ public void pollBatch(String threadId, long waitTime, Map<String, Job> jobsRepo)
                 break;
             }
             long end = System.currentTimeMillis();
-            // job batch size always is 1 when request is
-            // describe or stream prediction
-            if (j.getCmd() == WorkerCommands.DESCRIBE
-                    || j.getCmd() == WorkerCommands.STREAMPREDICT) {
+            // job batch size always is 1 when request is describe
+            if (j.getCmd() == WorkerCommands.DESCRIBE) {
                 // Add the job back into the jobsQueue
                 jobsQueue.addFirst(j);
                 break;
@@ -610,10 +610,6 @@ public void setSequenceMaxIdleMSec(long sequenceMaxIdleMSec) {
         this.sequenceMaxIdleMSec = sequenceMaxIdleMSec;
     }

-    public boolean isStateful() {
-        return stateful;
-    }
-
     public int getMaxSequenceJobQueueSize() {
         return maxSequenceJobQueueSize;
     }
@@ -638,6 +634,10 @@ public boolean isContinuousBatching() {
         return continuousBatching;
     }

+    public boolean isSequenceBatching() {
+        return sequenceBatch;
+    }
+
     public boolean isUseVenv() {
         if (getRuntimeType() == Manifest.RuntimeType.PYTHON) {
             return useVenv;

‎frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java

+25-2
@@ -11,6 +11,7 @@
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.pytorch.serve.job.Job;
 import org.pytorch.serve.job.JobGroup;
+import org.pytorch.serve.util.ConfigManager;
 import org.pytorch.serve.util.messages.BaseModelRequest;
 import org.pytorch.serve.util.messages.ModelWorkerResponse;
 import org.slf4j.Logger;
@@ -35,6 +36,7 @@ public class SequenceBatchAggregator extends BatchAggregator {
     // back to eventJobGroupIds once their jobs are processed by a batch.
     private LinkedList<String> currentJobGroupIds;
     private int localCapacity;
+    private AtomicBoolean running = new AtomicBoolean(true);

     public SequenceBatchAggregator(Model model) {
         super(model);
@@ -161,6 +163,13 @@ public void cleanJobs() {
         }
     }

+    @Override
+    public void shutdown() {
+        this.setRunning(false);
+        this.shutdownExecutors();
+        this.stopEventDispatcher();
+    }
+
     public void shutdownExecutors() {
         this.pollExecutors.shutdown();
     }
@@ -171,10 +180,14 @@ private void addJobGroup(String jobGroupId) {
         }
     }

+    public void setRunning(boolean running) {
+        this.running.set(running);
+    }
+
     class EventDispatcher implements Runnable {
         @Override
         public void run() {
-            while (true) {
+            while (running.get()) {
                 try {
                     String jobGroupId =
                             eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
@@ -197,7 +210,9 @@ public void run() {
                                 pollExecutors);
                     }
                 } catch (InterruptedException e) {
-                    logger.error("EventDispatcher failed to get jobGroup", e);
+                    if (running.get()) {
+                        logger.error("EventDispatcher failed to get jobGroup", e);
+                    }
                 }
             }
         }
@@ -212,6 +227,14 @@ private void pollJobFromJobGroup(String jobGroupId) {
             // intent to add new job groups.
             eventJobGroupIds.add("");
         } else {
+            if (Boolean.parseBoolean(
+                    job.getPayload()
+                            .getHeaders()
+                            .getOrDefault(
+                                    ConfigManager.getInstance().getTsHeaderKeySequenceEnd(),
+                                    "false"))) {
+                jobGroup.setFinished(true);
+            }
            jobsQueue.add(job);
         }
     }
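
Note: when a polled job carries the configured sequence-end header with value "true", the aggregator marks its JobGroup finished, so later polls on that group return null and the group can be cleaned up. A sketch of a client closing sequence `seq_0` on its last request (default header names assumed):

```python
# Sketch: the final request of a sequence carries the sequence-end header so
# the frontend marks the corresponding JobGroup as finished.
import requests

headers = {
    "ts_request_sequence_id": "seq_0",
    "ts_request_sequence_end": "true",
}
resp = requests.post(
    "http://localhost:8080/predictions/stateful",
    headers=headers,
    data=b"5",
)
print(resp.text)
```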

‎frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java

+1-1
@@ -229,7 +229,7 @@ private void addThreads(

         BatchAggregator aggregator;

-        if (model.isStateful()) {
+        if (model.isSequenceBatching()) {
             aggregator = new SequenceBatchAggregator(model);
         } else if (model.isContinuousBatching()) {
             aggregator = new ContinuousBatching(model);
‎frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java

+2-6
@@ -248,9 +248,8 @@ public void run() {

                     switch (req.getCommand()) {
                         case PREDICT:
-                            model.resetFailedInfReqs();
-                            break;
                         case STREAMPREDICT:
+                        case STREAMPREDICT2:
                             model.resetFailedInfReqs();
                             break;
                         case LOAD:
@@ -471,6 +470,7 @@ public int getPid() {

     public void shutdown() {
         running.set(false);
+        aggregator.shutdown();
         setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK);
         for (int i = 0;
                 backendChannel.size() > 0
@@ -489,10 +489,6 @@ public void shutdown() {

             model.removeJobQueue(workerId);
         }
-        if (aggregator instanceof SequenceBatchAggregator) {
-            ((SequenceBatchAggregator) aggregator).shutdownExecutors();
-            ((SequenceBatchAggregator) aggregator).stopEventDispatcher();
-        }
     }

     private String getWorkerName() {
+162
@@ -0,0 +1,162 @@
+import shutil
+import sys
+import threading
+from pathlib import Path
+
+import pytest
+import requests
+import test_utils
+from model_archiver.model_archiver_config import ModelArchiverConfig
+
+CURR_FILE_PATH = Path(__file__).parent
+STATEFUL_PATH = CURR_FILE_PATH.parents[1] / "examples" / "stateful"
+CONFIG_PROPERTIES_PATH = CURR_FILE_PATH.parents[1] / "test" / "config_ts.properties"
+
+YAML_CONFIG = f"""
+# TorchServe frontend parameters
+minWorkers: 2
+maxWorkers: 2
+batchSize: 4
+maxNumSequence: 4
+sequenceMaxIdleMSec: 5000
+maxSequenceJobQueueSize: 10
+sequenceBatching: true
+
+handler:
+  cache:
+    capacity: 4
+"""
+
+PROMPTS = [
+    {
+        "prompt": "A robot may not injure a human being",
+        "max_new_tokens": 50,
+        "temperature": 0.8,
+        "logprobs": 1,
+        "prompt_logprobs": 1,
+        "max_tokens": 128,
+        "adapter": "adapter_1",
+    },
+]
+
+
+@pytest.fixture
+def add_paths():
+    sys.path.append(STATEFUL_PATH.as_posix())
+    yield
+    sys.path.pop()
+
+
+@pytest.fixture(scope="module")
+def model_name():
+    yield "stateful"
+
+
+@pytest.fixture(scope="module")
+def work_dir(tmp_path_factory, model_name):
+    return tmp_path_factory.mktemp(model_name)
+
+
+@pytest.fixture(scope="module", name="mar_file_path")
+def create_mar_file(work_dir, model_archiver, model_name, request):
+    mar_file_path = Path(work_dir).joinpath(model_name)
+
+    model_config_yaml = Path(work_dir) / "model-config.yaml"
+    model_config_yaml.write_text(YAML_CONFIG)
+
+    config = ModelArchiverConfig(
+        model_name=model_name,
+        version="1.0",
+        handler=(STATEFUL_PATH / "stateful_handler.py").as_posix(),
+        serialized_file=(STATEFUL_PATH / "model_cnn.pt").as_posix(),
+        model_file=(STATEFUL_PATH / "model.py").as_posix(),
+        export_path=work_dir,
+        requirements_file=(STATEFUL_PATH / "requirements.txt").as_posix(),
+        runtime="python",
+        force=False,
+        config_file=model_config_yaml.as_posix(),
+        archive_format="no-archive",
+    )
+
+    model_archiver.generate_model_archive(config)
+
+    assert mar_file_path.exists()
+
+    yield mar_file_path.as_posix()
+
+    # Clean up files
+    shutil.rmtree(mar_file_path)
+
+
+def test_stateful_mar(mar_file_path, model_store):
+    """
+    Register the model in torchserve
+    """
+
+    file_name = Path(mar_file_path).name
+
+    model_name = Path(file_name).stem
+
+    shutil.copytree(mar_file_path, Path(model_store) / model_name)
+
+    params = (
+        ("model_name", model_name),
+        ("url", Path(model_store) / model_name),
+        ("initial_workers", "2"),
+        ("synchronous", "true"),
+    )
+
+    test_utils.start_torchserve(
+        model_store=model_store, snapshot_file=CONFIG_PROPERTIES_PATH, gen_mar=False
+    )
+
+    try:
+        test_utils.reg_resp = test_utils.register_model_with_params(params)
+
+        t0 = threading.Thread(
+            target=__infer_stateful,
+            args=(
+                model_name,
+                "seq_0",
+                "1 4 9 16 25",
+            ),
+        )
+        t1 = threading.Thread(
+            target=__infer_stateful,
+            args=(
+                model_name,
+                "seq_1",
+                "2 6 12 20 30",
+            ),
+        )
+
+        t0.start()
+        t1.start()
+
+        t0.join()
+        t1.join()
+    finally:
+        test_utils.unregister_model(model_name)

+        # Clean up files
+        shutil.rmtree(Path(model_store) / model_name)
+
+
+def __infer_stateful(model_name, sequence_id, expected):
+    headers = {
+        "ts_request_sequence_id": sequence_id,
+    }
+    prediction = []
+    for idx in range(5):
+        if sequence_id == "seq_0":
+            idx = 2 * idx
+        elif sequence_id == "seq_1":
+            idx = 2 * idx + 1
+        response = requests.post(
+            url=f"http://localhost:8080/predictions/{model_name}",
+            headers=headers,
+            data=str(idx + 1).encode(),
+        )
+        prediction.append(response.text)
+
+    assert str(" ".join(prediction)) == expected

‎ts/context.py

+11-2
@@ -1,7 +1,7 @@
 """
 Context object of incoming request
 """
-
+import os
 from typing import Dict, Optional, Tuple


@@ -40,6 +40,15 @@ def __init__(
         self.metrics = metrics
         self.model_yaml_config = model_yaml_config
         self.stopping_criteria = None
+        self.header_key_sequence_id = os.getenv(
+            "TS_REQUEST_SEQUENCE_ID", "ts_request_sequence_id"
+        )
+        self.header_key_sequence_start = os.getenv(
+            "TS_REQUEST_SEQUENCE_START", "ts_request_sequence_start"
+        )
+        self.header_key_sequence_end = os.getenv(
+            "TS_REQUEST_SEQUENCE_END", "ts_request_sequence_end"
+        )

     @property
     def system_properties(self):
@@ -121,7 +130,7 @@ def __eq__(self, other: object) -> bool:

     def get_sequence_id(self, idx: int) -> str:
         return self._request_processor[idx].get_request_property(
-            "ts_request_sequence_id"
+            self.header_key_sequence_id
         )

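
Note: on the backend, handlers resolve the sequence-id header key from the `TS_REQUEST_SEQUENCE_ID` environment variable, defaulting to `ts_request_sequence_id`. A hedged sketch of reading it inside a custom handler; the `open_sessions` dict and the helper itself are illustrative and not part of this commit:

```python
# Illustrative sketch (not from this commit): inside a custom handler's
# handle(data, context), group incoming requests by their sequence id.
def group_by_sequence(data, context, open_sessions):
    """open_sessions is a hypothetical dict keeping per-sequence state."""
    grouped = []
    for idx, row in enumerate(data):
        seq_id = context.get_sequence_id(idx)  # resolves the configured header key
        state = open_sessions.setdefault(seq_id, {})
        grouped.append((seq_id, state, row.get("body")))
    return grouped
```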