3
3
import java .util .LinkedHashSet ;
4
4
import java .util .LinkedList ;
5
5
import java .util .concurrent .CompletableFuture ;
6
+ import java .util .concurrent .ConcurrentHashMap ;
6
7
import java .util .concurrent .ExecutionException ;
7
8
import java .util .concurrent .ExecutorService ;
8
9
import java .util .concurrent .Executors ;
9
10
import java .util .concurrent .LinkedBlockingDeque ;
10
11
import java .util .concurrent .TimeUnit ;
11
12
import java .util .concurrent .atomic .AtomicBoolean ;
13
+ import java .util .concurrent .atomic .AtomicInteger ;
12
14
import org .pytorch .serve .job .Job ;
13
15
import org .pytorch .serve .job .JobGroup ;
14
16
import org .pytorch .serve .util .messages .BaseModelRequest ;
@@ -34,16 +36,20 @@ public class SequenceBatching extends BatchAggregator {
34
36
// A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added
35
37
// back to eventJobGroupIds once their jobs are processed by a batch.
36
38
protected LinkedList <String > currentJobGroupIds ;
37
- private int localCapacity ;
39
+ private AtomicInteger localCapacity ;
38
40
private AtomicBoolean running = new AtomicBoolean (true );
41
+ // HashMap to track poll queue tasks in the executor queue
42
+ private ConcurrentHashMap <String , CompletableFuture <Void >> pollQueueTasks =
43
+ new ConcurrentHashMap <String , CompletableFuture <Void >>();
39
44
40
45
public SequenceBatching (Model model ) {
41
46
super (model );
47
+ this .localCapacity =
48
+ new AtomicInteger (Math .max (1 , model .getMaxNumSequence () / model .getMinWorkers ()));
42
49
this .currentJobGroupIds = new LinkedList <>();
43
- this .pollExecutors = Executors .newFixedThreadPool (model . getBatchSize () + 1 );
50
+ this .pollExecutors = Executors .newFixedThreadPool (localCapacity . get () + 1 );
44
51
this .jobsQueue = new LinkedBlockingDeque <>();
45
52
this .isPollJobGroup = new AtomicBoolean (false );
46
- this .localCapacity = model .getMaxNumSequence () / model .getMinWorkers ();
47
53
this .eventJobGroupIds = new LinkedBlockingDeque <>();
48
54
this .eventJobGroupIds .add ("" );
49
55
this .eventDispatcher = new Thread (new EventDispatcher ());
@@ -70,8 +76,9 @@ private void pollJobGroup() throws InterruptedException {
70
76
71
77
int quota =
72
78
Math .min (
73
- this .localCapacity - jobsQueue .size (),
74
- model .getPendingJobGroups ().size () / model .getMaxWorkers ());
79
+ this .localCapacity .get (),
80
+ Math .max (
81
+ 1 , model .getPendingJobGroups ().size () / model .getMaxWorkers ()));
75
82
if (quota > 0 && model .getPendingJobGroups ().size () > 0 ) {
76
83
model .getPendingJobGroups ().drainTo (tmpJobGroups , quota );
77
84
}
@@ -120,6 +127,8 @@ private void cleanJobGroup(String jobGroupId) {
120
127
logger .debug ("Clean jobGroup: {}" , jobGroupId );
121
128
if (jobGroupId != null ) {
122
129
model .removeJobGroup (jobGroupId );
130
+ pollQueueTasks .remove (jobGroupId );
131
+ localCapacity .incrementAndGet ();
123
132
}
124
133
}
125
134
@@ -176,6 +185,7 @@ public void shutdownExecutors() {
176
185
177
186
private void addJobGroup (String jobGroupId ) {
178
187
if (jobGroupId != null ) {
188
+ localCapacity .decrementAndGet ();
179
189
eventJobGroupIds .add (jobGroupId );
180
190
}
181
191
}
@@ -192,22 +202,39 @@ public void run() {
192
202
String jobGroupId =
193
203
eventJobGroupIds .poll (model .getMaxBatchDelay (), TimeUnit .MILLISECONDS );
194
204
if (jobGroupId == null || jobGroupId .isEmpty ()) {
195
- CompletableFuture .runAsync (
196
- () -> {
197
- try {
198
- pollJobGroup ();
199
- } catch (InterruptedException e ) {
200
- logger .error ("Failed to poll a job group" , e );
201
- }
202
- },
203
- pollExecutors );
205
+ // Skip fetching new job groups when no capacity is available
206
+ if (localCapacity .get () <= 0 ) {
207
+ continue ;
208
+ }
209
+ // Avoid duplicate poll tasks in the executor queue
210
+ if (pollQueueTasks .containsKey ("pollJobGroup" )
211
+ && !pollQueueTasks .get ("pollJobGroup" ).isDone ()) {
212
+ continue ;
213
+ }
214
+ CompletableFuture <Void > pollTask =
215
+ CompletableFuture .runAsync (
216
+ () -> {
217
+ try {
218
+ pollJobGroup ();
219
+ } catch (InterruptedException e ) {
220
+ logger .error ("Failed to poll a job group" , e );
221
+ }
222
+ },
223
+ pollExecutors );
224
+ pollQueueTasks .put ("pollJobGroup" , pollTask );
204
225
} else {
205
-
206
- CompletableFuture .runAsync (
207
- () -> {
208
- pollJobFromJobGroup (jobGroupId );
209
- },
210
- pollExecutors );
226
+ // Avoid duplicate poll tasks in the executor queue
227
+ if (pollQueueTasks .containsKey (jobGroupId )
228
+ && !pollQueueTasks .get (jobGroupId ).isDone ()) {
229
+ continue ;
230
+ }
231
+ CompletableFuture <Void > pollTask =
232
+ CompletableFuture .runAsync (
233
+ () -> {
234
+ pollJobFromJobGroup (jobGroupId );
235
+ },
236
+ pollExecutors );
237
+ pollQueueTasks .put (jobGroupId , pollTask );
211
238
}
212
239
} catch (InterruptedException e ) {
213
240
if (running .get ()) {
@@ -224,7 +251,7 @@ private void pollJobFromJobGroup(String jobGroupId) {
224
251
if (!jobGroup .isFinished ()) {
225
252
job = jobGroup .pollJob (model .getSequenceMaxIdleMSec ());
226
253
}
227
- if (job == null ) {
254
+ if (job == null || jobGroup . isFinished () ) {
228
255
// JobGroup expired, clean it.
229
256
cleanJobGroup (jobGroupId );
230
257
// intent to add new job groups.
0 commit comments