
Commit 4c96e6f

Fix sequence continuous batching close session race condition (#3198)
* add logging to trace jobgroup cleanup
* Monitor eventJobGroupIds
* Revert "Monitor eventJobGroupIds"; this reverts commit 70ef9b0.
* Log reset job group Ids
* Test adding job to evenJobGroupIds after completing streaming request
* Revert "Test adding job to evenJobGroupIds after completing streaming request"; this reverts commit ab78a9a.
* Force cleanup job group
* Repeat close session request to follow through with cleanup
* Improve detection of close session
* formatJava
* test not adding dummy job to closed job group
* Revert "test not adding dummy job to closed job group"; this reverts commit 51d706a.
* Remove debug logging
* comments about fix
* Avoid duplicate CompletableFutures
* Check executor task status using CompletableFuture object
* Track available local capacity for a worker
* Fix computation of capacity values
* Update test to check session cleanup
* Update pollQueueTasks key for pollJobGroup task
* formatJava
1 parent 688f09e commit 4c96e6f
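
The fix described above rests on two mechanisms: each worker tracks its remaining session capacity and stops fetching new job groups once it reaches zero, and poll tasks are deduplicated by checking the status of the previously scheduled CompletableFuture before submitting another one. The sketch below illustrates that pattern in isolation; the class name PollScheduler and its fields are simplified stand-ins for this example, not the actual TorchServe types changed in this commit.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

// Minimal sketch (assumed names, not TorchServe code): at most one in-flight
// poll task per key, and no new job groups are fetched once capacity is used up.
public class PollScheduler {
    private final ExecutorService pollExecutors = Executors.newFixedThreadPool(4);
    private final ConcurrentHashMap<String, CompletableFuture<Void>> pollQueueTasks =
            new ConcurrentHashMap<>();
    private final AtomicInteger localCapacity = new AtomicInteger(2);

    void schedulePoll(String key, Runnable pollAction) {
        // Skip fetching new job groups when no capacity is available
        if ("pollJobGroup".equals(key) && localCapacity.get() <= 0) {
            return;
        }
        // Avoid duplicate tasks: keep the previous future while it is still running
        CompletableFuture<Void> previous = pollQueueTasks.get(key);
        if (previous != null && !previous.isDone()) {
            return;
        }
        pollQueueTasks.put(key, CompletableFuture.runAsync(pollAction, pollExecutors));
    }

    void onJobGroupAdded() {
        localCapacity.decrementAndGet(); // a newly admitted session consumes one slot
    }

    void onJobGroupCleaned(String jobGroupId) {
        pollQueueTasks.remove(jobGroupId); // drop the tracked future for the closed session
        localCapacity.incrementAndGet(); // closing or expiring the session frees the slot
    }
}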

File tree

2 files changed, +92 -53 lines changed

frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java
test/pytest/test_example_stateful_sequence_continuous_batching_http.py

frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java

+48-21
@@ -3,12 +3,14 @@
 import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.LinkedBlockingDeque;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.pytorch.serve.job.Job;
 import org.pytorch.serve.job.JobGroup;
 import org.pytorch.serve.util.messages.BaseModelRequest;
@@ -34,16 +36,20 @@ public class SequenceBatching extends BatchAggregator {
     // A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added
     // back to eventJobGroupIds once their jobs are processed by a batch.
     protected LinkedList<String> currentJobGroupIds;
-    private int localCapacity;
+    private AtomicInteger localCapacity;
     private AtomicBoolean running = new AtomicBoolean(true);
+    // HashMap to track poll queue tasks in the executor queue
+    private ConcurrentHashMap<String, CompletableFuture<Void>> pollQueueTasks =
+            new ConcurrentHashMap<String, CompletableFuture<Void>>();
 
     public SequenceBatching(Model model) {
         super(model);
+        this.localCapacity =
+                new AtomicInteger(Math.max(1, model.getMaxNumSequence() / model.getMinWorkers()));
         this.currentJobGroupIds = new LinkedList<>();
-        this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1);
+        this.pollExecutors = Executors.newFixedThreadPool(localCapacity.get() + 1);
         this.jobsQueue = new LinkedBlockingDeque<>();
         this.isPollJobGroup = new AtomicBoolean(false);
-        this.localCapacity = model.getMaxNumSequence() / model.getMinWorkers();
         this.eventJobGroupIds = new LinkedBlockingDeque<>();
         this.eventJobGroupIds.add("");
         this.eventDispatcher = new Thread(new EventDispatcher());
@@ -70,8 +76,9 @@ private void pollJobGroup() throws InterruptedException {
 
         int quota =
                 Math.min(
-                        this.localCapacity - jobsQueue.size(),
-                        model.getPendingJobGroups().size() / model.getMaxWorkers());
+                        this.localCapacity.get(),
+                        Math.max(
+                                1, model.getPendingJobGroups().size() / model.getMaxWorkers()));
         if (quota > 0 && model.getPendingJobGroups().size() > 0) {
            model.getPendingJobGroups().drainTo(tmpJobGroups, quota);
        }
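
For intuition, a small worked example of the revised quota computation above. The numbers are illustrative and assume an empty jobsQueue; they are not values taken from the commit. With integer division, the old expression could evaluate to zero even though a job group was pending and capacity was free, while the new expression admits at least one group whenever local capacity remains.

// Illustrative values only
int localCapacity = 2;       // free session slots on this worker
int pendingJobGroups = 1;    // job groups waiting in model.getPendingJobGroups()
int maxWorkers = 2;

// Old: Math.min(localCapacity - jobsQueue.size(), pendingJobGroups / maxWorkers)
//      = Math.min(2 - 0, 1 / 2) = 0, so nothing is drained this round
// New: Math.min(localCapacity, Math.max(1, pendingJobGroups / maxWorkers))
//      = Math.min(2, Math.max(1, 0)) = 1, so one pending group is admitted
int quota = Math.min(localCapacity, Math.max(1, pendingJobGroups / maxWorkers));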
@@ -120,6 +127,8 @@ private void cleanJobGroup(String jobGroupId) {
         logger.debug("Clean jobGroup: {}", jobGroupId);
         if (jobGroupId != null) {
             model.removeJobGroup(jobGroupId);
+            pollQueueTasks.remove(jobGroupId);
+            localCapacity.incrementAndGet();
         }
     }
 
@@ -176,6 +185,7 @@ public void shutdownExecutors() {
 
     private void addJobGroup(String jobGroupId) {
         if (jobGroupId != null) {
+            localCapacity.decrementAndGet();
             eventJobGroupIds.add(jobGroupId);
         }
     }
@@ -192,22 +202,39 @@ public void run() {
                     String jobGroupId =
                             eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
                     if (jobGroupId == null || jobGroupId.isEmpty()) {
-                        CompletableFuture.runAsync(
-                                () -> {
-                                    try {
-                                        pollJobGroup();
-                                    } catch (InterruptedException e) {
-                                        logger.error("Failed to poll a job group", e);
-                                    }
-                                },
-                                pollExecutors);
+                        // Skip fetching new job groups when no capacity is available
+                        if (localCapacity.get() <= 0) {
+                            continue;
+                        }
+                        // Avoid duplicate poll tasks in the executor queue
+                        if (pollQueueTasks.containsKey("pollJobGroup")
+                                && !pollQueueTasks.get("pollJobGroup").isDone()) {
+                            continue;
+                        }
+                        CompletableFuture<Void> pollTask =
+                                CompletableFuture.runAsync(
+                                        () -> {
+                                            try {
+                                                pollJobGroup();
+                                            } catch (InterruptedException e) {
+                                                logger.error("Failed to poll a job group", e);
+                                            }
+                                        },
+                                        pollExecutors);
+                        pollQueueTasks.put("pollJobGroup", pollTask);
                     } else {
-
-                        CompletableFuture.runAsync(
-                                () -> {
-                                    pollJobFromJobGroup(jobGroupId);
-                                },
-                                pollExecutors);
+                        // Avoid duplicate poll tasks in the executor queue
+                        if (pollQueueTasks.containsKey(jobGroupId)
+                                && !pollQueueTasks.get(jobGroupId).isDone()) {
+                            continue;
+                        }
+                        CompletableFuture<Void> pollTask =
+                                CompletableFuture.runAsync(
+                                        () -> {
+                                            pollJobFromJobGroup(jobGroupId);
+                                        },
+                                        pollExecutors);
+                        pollQueueTasks.put(jobGroupId, pollTask);
                     }
                 } catch (InterruptedException e) {
                     if (running.get()) {
@@ -224,7 +251,7 @@ private void pollJobFromJobGroup(String jobGroupId) {
         if (!jobGroup.isFinished()) {
             job = jobGroup.pollJob(model.getSequenceMaxIdleMSec());
         }
-        if (job == null) {
+        if (job == null || jobGroup.isFinished()) {
             // JobGroup expired, clean it.
             cleanJobGroup(jobGroupId);
             // intent to add new job groups.

test/pytest/test_example_stateful_sequence_continuous_batching_http.py

+44-32
@@ -22,7 +22,7 @@
 maxWorkers: 2
 batchSize: 1
 maxNumSequence: 2
-sequenceMaxIdleMSec: 5000
+sequenceMaxIdleMSec: 60000
 maxSequenceJobQueueSize: 10
 sequenceBatching: true
 continuousBatching: true
@@ -219,39 +219,51 @@ def test_infer_stateful_cancel(mar_file_path, model_store):
 
     try:
         test_utils.reg_resp = test_utils.register_model_with_params(params)
-        with requests.post(
-            url=f"http://localhost:8080/predictions/{model_name}",
-            data=str(2).encode(),
-        ) as response:
-            s_id = response.headers.get("ts_request_sequence_id")
-            headers = {
-                "ts_request_sequence_id": s_id,
-            }
 
-        t0 = threading.Thread(
-            target=__infer_stateful_cancel,
-            args=(
-                model_name,
-                False,
-                headers,
-                "5",
-            ),
-        )
-        t1 = threading.Thread(
-            target=__infer_stateful_cancel,
-            args=(
-                model_name,
-                True,
-                headers,
-                "-1",
-            ),
-        )
-
-        t0.start()
-        t1.start()
+        # Open and close sessions multiple times (>maxNumSequence) to test session cleanup after the stream response
+        for _ in range(4):
+            with requests.post(
+                url=f"http://localhost:8080/predictions/{model_name}",
+                data=str(2).encode(),
+            ) as response:
+                s_id = response.headers.get("ts_request_sequence_id")
+                headers = {
+                    "ts_request_sequence_id": s_id,
+                }
+
+            t0 = threading.Thread(
+                target=__infer_stateful_cancel,
+                args=(
+                    model_name,
+                    False,
+                    headers,
+                    "5",
+                ),
+            )
+            t1 = threading.Thread(
+                target=__infer_stateful_cancel,
+                args=(
+                    model_name,
+                    True,
+                    headers,
+                    "-1",
+                ),
+            )
+
+            t0.start()
+            t1.start()
+
+            t0.join()
+            t1.join()
+
+            # Close session after cancellation request to free up session capacity
+            with requests.post(
+                url=f"http://localhost:8080/predictions/{model_name}",
+                headers=headers,
+                data=str(0).encode(),
+            ) as response:
+                assert response.status_code == 200
 
-        t0.join()
-        t1.join()
     finally:
         test_utils.unregister_model(model_name)
