@@ -234,13 +234,26 @@ def recalculate_cost(self, ts: TaskState) -> None:
234
234
235
235
def put_key_in_stealable(self, ts: TaskState) -> None:
    """Register ``ts`` as a candidate for work stealing.

    The task is filed under the address of the worker it is processing on
    and the level reported by ``steal_time_ratio``. The same
    ``(worker, level)`` pair is remembered in ``key_stealable``.

    Nothing is recorded when ``steal_time_ratio`` yields no cost
    multiplier for the task.

    Tasks whose prefix duration, as returned by the scheduler, differs from
    the prefix's ``duration_average`` are also collected in
    ``unknown_durations``, keyed by prefix name.

    Parameters
    ----------
    ts
        The task to register; it must currently be processing on a worker.
    """
    cost_multiplier, level = self.steal_time_ratio(ts)
    if cost_multiplier is None:
        # Bail out before the duration lookup: a task that is not
        # stealable needs no further bookkeeping.
        return

    prefix = ts.prefix
    duration = self.scheduler._get_prefix_duration(prefix)

    assert level is not None
    assert ts.processing_on
    ws = ts.processing_on
    worker = ws.address
    self.stealable[worker][level].add(ts)
    self.key_stealable[ts] = (worker, level)

    if duration == prefix.duration_average:
        return

    # NOTE(review): a mismatch presumably means the scheduler returned an
    # estimate rather than a measured average - confirm against
    # ``_get_prefix_duration``.
    self.unknown_durations.setdefault(prefix.name, set()).add(ts)
244
257
245
258
def remove_key_from_stealable (self , ts : TaskState ) -> None :
246
259
result = self .key_stealable .pop (ts , None )
@@ -266,7 +279,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
266
279
if not ts .dependencies : # no dependencies fast path
267
280
return 0 , 0
268
281
269
- compute_time = self .get_task_duration (ts )
282
+ compute_time = self .scheduler . _get_prefix_duration (ts . prefix )
270
283
271
284
if not compute_time :
272
285
# occupancy/ws.processing[ts] is only allowed to be zero for
@@ -312,12 +325,9 @@ def move_task_request(
312
325
313
326
# TODO: occupancy no longer concats linearly so we can't easily
314
327
# assume that the network cost would go down by that much
315
- victim_duration = self .get_task_duration (ts ) + self .scheduler .get_comm_cost (
316
- ts , victim
317
- )
318
- thief_duration = self .get_task_duration (ts ) + self .scheduler .get_comm_cost (
319
- ts , thief
320
- )
328
+ compute = self .scheduler ._get_prefix_duration (ts .prefix )
329
+ victim_duration = compute + self .scheduler .get_comm_cost (ts , victim )
330
+ thief_duration = compute + self .scheduler .get_comm_cost (ts , thief )
321
331
322
332
self .scheduler .stream_comms [victim .address ].send (
323
333
{"op" : "steal-request" , "key" : key , "stimulus_id" : stimulus_id }
@@ -468,8 +478,7 @@ def balance(self) -> None:
468
478
occ_victim = self ._combined_occupancy (victim )
469
479
comm_cost_thief = self .scheduler .get_comm_cost (ts , thief )
470
480
comm_cost_victim = self .scheduler .get_comm_cost (ts , victim )
471
- compute = self .get_task_duration (ts )
472
-
481
+ compute = self .scheduler ._get_prefix_duration (ts .prefix )
473
482
if (
474
483
occ_thief + comm_cost_thief + compute
475
484
<= occ_victim - (comm_cost_victim + compute ) / 2
@@ -542,30 +551,6 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
542
551
out .append (t )
543
552
return out
544
553
545
def get_task_duration(self, ts: TaskState) -> float:
    """Return the estimated compute time of ``ts``, excluding communication.

    A non-negative ``duration_average`` on the task's prefix is returned
    as is. Otherwise the estimate is twice the prefix's ``max_exec_time``
    when that is positive, else the scheduler's ``UNKNOWN_TASK_DURATION``.
    In both fallback cases the task is also added to ``unknown_durations``
    under its prefix name.
    """
    prefix = ts.prefix
    measured: float = prefix.duration_average
    if measured >= 0:
        return measured

    # No measurement yet: work out the fallback estimate first, then record
    # the task among those with unknown duration.
    estimate: float = (
        2 * prefix.max_exec_time
        if prefix.max_exec_time > 0
        else self.scheduler.UNKNOWN_TASK_DURATION
    )
    self.unknown_durations.setdefault(prefix.name, set()).add(ts)
    return estimate
568
-
569
554
570
555
def _get_thief (
571
556
scheduler : SchedulerState , ts : TaskState , potential_thieves : set [WorkerState ]
0 commit comments