@@ -426,6 +426,10 @@ def balance(self) -> None:
426
426
log = []
427
427
start = time ()
428
428
429
+ # Pre-calculate all occupancies once, they don't change during balancing
430
+ occupancies = {ws : ws .occupancy for ws in s .workers .values ()}
431
+ combined_occupancy = partial (self ._combined_occupancy , occupancies = occupancies )
432
+
429
433
i = 0
430
434
# Paused and closing workers must never become thieves
431
435
potential_thieves = set (s .idle .values ())
@@ -434,21 +438,19 @@ def balance(self) -> None:
434
438
victim : WorkerState | None
435
439
potential_victims : set [WorkerState ] | list [WorkerState ] = s .saturated
436
440
if not potential_victims :
437
- potential_victims = topk (
438
- 10 , s .workers .values (), key = self ._combined_occupancy
439
- )
441
+ potential_victims = topk (10 , s .workers .values (), key = combined_occupancy )
440
442
potential_victims = [
441
443
ws
442
444
for ws in potential_victims
443
- if self . _combined_occupancy (ws ) > 0.2
445
+ if combined_occupancy (ws ) > 0.2
444
446
and self ._combined_nprocessing (ws ) > ws .nthreads
445
447
and ws not in potential_thieves
446
448
]
447
449
if not potential_victims :
448
450
return
449
451
if len (potential_victims ) < 20 :
450
452
potential_victims = sorted (
451
- potential_victims , key = self . _combined_occupancy , reverse = True
453
+ potential_victims , key = combined_occupancy , reverse = True
452
454
)
453
455
assert potential_victims
454
456
assert potential_thieves
@@ -472,11 +474,15 @@ def balance(self) -> None:
472
474
stealable .discard (ts )
473
475
continue
474
476
i += 1
475
- if not (thief := self ._get_thief (s , ts , potential_thieves )):
477
+ if not (
478
+ thief := self ._get_thief (
479
+ s , ts , potential_thieves , occupancies = occupancies
480
+ )
481
+ ):
476
482
continue
477
483
478
- occ_thief = self . _combined_occupancy (thief )
479
- occ_victim = self . _combined_occupancy (victim )
484
+ occ_thief = combined_occupancy (thief )
485
+ occ_victim = combined_occupancy (victim )
480
486
comm_cost_thief = self .scheduler .get_comm_cost (ts , thief )
481
487
comm_cost_victim = self .scheduler .get_comm_cost (ts , victim )
482
488
compute = self .scheduler ._get_prefix_duration (ts .prefix )
@@ -501,7 +507,7 @@ def balance(self) -> None:
501
507
self .metrics ["request_count_total" ][level ] += 1
502
508
self .metrics ["request_cost_total" ][level ] += cost
503
509
504
- occ_thief = self . _combined_occupancy (thief )
510
+ occ_thief = combined_occupancy (thief )
505
511
nproc_thief = self ._combined_nprocessing (thief )
506
512
507
513
# FIXME: In the worst case, the victim may have 3x the amount of work
@@ -515,7 +521,7 @@ def balance(self) -> None:
515
521
# properly clean up, we would not need this
516
522
stealable .discard (ts )
517
523
self .scheduler .check_idle_saturated (
518
- victim , occ = self . _combined_occupancy (victim )
524
+ victim , occ = combined_occupancy (victim )
519
525
)
520
526
521
527
if log :
@@ -525,8 +531,10 @@ def balance(self) -> None:
525
531
if s .digests :
526
532
s .digests ["steal-duration" ].add (stop - start )
527
533
528
- def _combined_occupancy (self , ws : WorkerState ) -> float :
529
- return ws .occupancy + self .in_flight_occupancy [ws ]
534
+ def _combined_occupancy (
535
+ self , ws : WorkerState , * , occupancies : dict [WorkerState , float ]
536
+ ) -> float :
537
+ return occupancies [ws ] + self .in_flight_occupancy [ws ]
530
538
531
539
def _combined_nprocessing (self , ws : WorkerState ) -> int :
532
540
return len (ws .processing ) + self .in_flight_tasks [ws ]
@@ -552,7 +560,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
552
560
out .append (t )
553
561
return out
554
562
555
- def stealing_objective (self , ts : TaskState , ws : WorkerState ) -> tuple [float , ...]:
563
+ def stealing_objective (
564
+ self , ts : TaskState , ws : WorkerState , * , occupancies : dict [WorkerState , float ]
565
+ ) -> tuple [float , ...]:
556
566
"""Objective function to determine which worker should get the task
557
567
558
568
Minimize expected start time. If a tie then break with data storage.
@@ -567,7 +577,8 @@ def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...
567
577
Scheduler.worker_objective
568
578
"""
569
579
occupancy = self ._combined_occupancy (
570
- ws
580
+ ws ,
581
+ occupancies = occupancies ,
571
582
) / ws .nthreads + self .scheduler .get_comm_cost (ts , ws )
572
583
if ts .actor :
573
584
return (len (ws .actors ), occupancy , ws .nbytes )
@@ -579,6 +590,8 @@ def _get_thief(
579
590
scheduler : SchedulerState ,
580
591
ts : TaskState ,
581
592
potential_thieves : set [WorkerState ],
593
+ * ,
594
+ occupancies : dict [WorkerState , float ],
582
595
) -> WorkerState | None :
583
596
valid_workers = scheduler .valid_workers (ts )
584
597
if valid_workers is not None :
@@ -587,7 +600,10 @@ def _get_thief(
587
600
potential_thieves = valid_thieves
588
601
elif not ts .loose_restrictions :
589
602
return None
590
- return min (potential_thieves , key = partial (self .stealing_objective , ts ))
603
+ return min (
604
+ potential_thieves ,
605
+ key = partial (self .stealing_objective , ts , occupancies = occupancies ),
606
+ )
591
607
592
608
593
609
fast_tasks = {
0 commit comments