Skip to content

Commit e5711c6

Browse files
Cache occupancy in WorkStealing.balance() (#9005)
1 parent 348082f commit e5711c6

File tree

2 files changed

+53
-35
lines changed

2 files changed

+53
-35
lines changed

distributed/stealing.py

+31 -15
Original file line number | Diff line number | Diff line change
@@ -426,6 +426,10 @@ def balance(self) -> None:
426426
log = []
427427
start = time()
428428

429+
# Pre-calculate all occupancies once, they don't change during balancing
430+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
431+
combined_occupancy = partial(self._combined_occupancy, occupancies=occupancies)
432+
429433
i = 0
430434
# Paused and closing workers must never become thieves
431435
potential_thieves = set(s.idle.values())
@@ -434,21 +438,19 @@ def balance(self) -> None:
434438
victim: WorkerState | None
435439
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
436440
if not potential_victims:
437-
potential_victims = topk(
438-
10, s.workers.values(), key=self._combined_occupancy
439-
)
441+
potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
440442
potential_victims = [
441443
ws
442444
for ws in potential_victims
443-
if self._combined_occupancy(ws) > 0.2
445+
if combined_occupancy(ws) > 0.2
444446
and self._combined_nprocessing(ws) > ws.nthreads
445447
and ws not in potential_thieves
446448
]
447449
if not potential_victims:
448450
return
449451
if len(potential_victims) < 20:
450452
potential_victims = sorted(
451-
potential_victims, key=self._combined_occupancy, reverse=True
453+
potential_victims, key=combined_occupancy, reverse=True
452454
)
453455
assert potential_victims
454456
assert potential_thieves
@@ -472,11 +474,15 @@ def balance(self) -> None:
472474
stealable.discard(ts)
473475
continue
474476
i += 1
475-
if not (thief := self._get_thief(s, ts, potential_thieves)):
477+
if not (
478+
thief := self._get_thief(
479+
s, ts, potential_thieves, occupancies=occupancies
480+
)
481+
):
476482
continue
477483

478-
occ_thief = self._combined_occupancy(thief)
479-
occ_victim = self._combined_occupancy(victim)
484+
occ_thief = combined_occupancy(thief)
485+
occ_victim = combined_occupancy(victim)
480486
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
481487
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
482488
compute = self.scheduler._get_prefix_duration(ts.prefix)
@@ -501,7 +507,7 @@ def balance(self) -> None:
501507
self.metrics["request_count_total"][level] += 1
502508
self.metrics["request_cost_total"][level] += cost
503509

504-
occ_thief = self._combined_occupancy(thief)
510+
occ_thief = combined_occupancy(thief)
505511
nproc_thief = self._combined_nprocessing(thief)
506512

507513
# FIXME: In the worst case, the victim may have 3x the amount of work
@@ -515,7 +521,7 @@ def balance(self) -> None:
515521
# properly clean up, we would not need this
516522
stealable.discard(ts)
517523
self.scheduler.check_idle_saturated(
518-
victim, occ=self._combined_occupancy(victim)
524+
victim, occ=combined_occupancy(victim)
519525
)
520526

521527
if log:
@@ -525,8 +531,10 @@ def balance(self) -> None:
525531
if s.digests:
526532
s.digests["steal-duration"].add(stop - start)
527533

528-
def _combined_occupancy(self, ws: WorkerState) -> float:
529-
return ws.occupancy + self.in_flight_occupancy[ws]
534+
def _combined_occupancy(
535+
self, ws: WorkerState, *, occupancies: dict[WorkerState, float]
536+
) -> float:
537+
return occupancies[ws] + self.in_flight_occupancy[ws]
530538

531539
def _combined_nprocessing(self, ws: WorkerState) -> int:
532540
return len(ws.processing) + self.in_flight_tasks[ws]
@@ -552,7 +560,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
552560
out.append(t)
553561
return out
554562

555-
def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]:
563+
def stealing_objective(
564+
self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float]
565+
) -> tuple[float, ...]:
556566
"""Objective function to determine which worker should get the task
557567
558568
Minimize expected start time. If a tie then break with data storage.
@@ -567,7 +577,8 @@ def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...
567577
Scheduler.worker_objective
568578
"""
569579
occupancy = self._combined_occupancy(
570-
ws
580+
ws,
581+
occupancies=occupancies,
571582
) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
572583
if ts.actor:
573584
return (len(ws.actors), occupancy, ws.nbytes)
@@ -579,6 +590,8 @@ def _get_thief(
579590
scheduler: SchedulerState,
580591
ts: TaskState,
581592
potential_thieves: set[WorkerState],
593+
*,
594+
occupancies: dict[WorkerState, float],
582595
) -> WorkerState | None:
583596
valid_workers = scheduler.valid_workers(ts)
584597
if valid_workers is not None:
@@ -587,7 +600,10 @@ def _get_thief(
587600
potential_thieves = valid_thieves
588601
elif not ts.loose_restrictions:
589602
return None
590-
return min(potential_thieves, key=partial(self.stealing_objective, ts))
603+
return min(
604+
potential_thieves,
605+
key=partial(self.stealing_objective, ts, occupancies=occupancies),
606+
)
591607

592608

593609
fast_tasks = {

distributed/tests/test_steal.py

+22 -20
Original file line number | Diff line number | Diff line change
@@ -1948,7 +1948,7 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
19481948
client=True,
19491949
config={"distributed.scheduler.worker-saturation": "inf"},
19501950
)
1951-
async def test_stealing_ogjective_accounts_for_in_flight(c, s, a):
1951+
async def test_stealing_objective_accounts_for_in_flight(c, s, a):
19521952
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
19531953
in_event = Event()
19541954
block_event = Event()
@@ -1973,32 +1973,34 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
19731973
wsB = s.workers[b.address]
19741974
ts = next(iter(wsA.processing))
19751975

1976+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
19761977
# No in-flight requests, so both match
1977-
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
1978-
ts, wsA
1979-
)
1980-
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
1981-
ts, wsB
1982-
)
1978+
assert extension.stealing_objective(
1979+
ts, wsA, occupancies=occupancies
1980+
) == s.worker_objective(ts, wsA)
1981+
assert extension.stealing_objective(
1982+
ts, wsB, occupancies=occupancies
1983+
) == s.worker_objective(ts, wsB)
19831984

19841985
extension.balance()
19851986
assert extension.in_flight
19861987
# We move tasks from a to b
1987-
assert extension.stealing_objective(ts, wsA) < s.worker_objective(
1988-
ts, wsA
1989-
)
1990-
assert extension.stealing_objective(ts, wsB) > s.worker_objective(
1991-
ts, wsB
1992-
)
1988+
assert extension.stealing_objective(
1989+
ts, wsA, occupancies=occupancies
1990+
) < s.worker_objective(ts, wsA)
1991+
assert extension.stealing_objective(
1992+
ts, wsB, occupancies=occupancies
1993+
) > s.worker_objective(ts, wsB)
19931994

19941995
await async_poll_for(lambda: not extension.in_flight, timeout=5)
1996+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
19951997
# No in-flight requests, so both match
1996-
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
1997-
ts, wsA
1998-
)
1999-
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
2000-
ts, wsB
2001-
)
1998+
assert extension.stealing_objective(
1999+
ts, wsA, occupancies=occupancies
2000+
) == s.worker_objective(ts, wsA)
2001+
assert extension.stealing_objective(
2002+
ts, wsB, occupancies=occupancies
2003+
) == s.worker_objective(ts, wsB)
20022004
finally:
20032005
await block_event.set()
20042006
finally:
@@ -2031,7 +2033,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
20312033
await in_event.wait()
20322034

20332035
# This is the pre-condition for the observed problem:
2034-
# There are tasks that execute fox a long time but do not have an average
2036+
# There are tasks that execute for a long time but do not have an average
20352037
s.task_prefixes["block"].add_exec_time(100)
20362038
assert s.task_prefixes["block"].duration_average == -1
20372039

0 commit comments

Comments
 (0)