Skip to content

Commit 348082f

Browse files
Homogeneous balancing by accounting for in-flight requests (#9003)
1 parent b86b714 commit 348082f

File tree

3 files changed

+104
-18
lines changed

3 files changed

+104
-18
lines changed

distributed/scheduler.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -3170,7 +3170,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
31703170
deps = {dep for dep in ts.dependencies if dep not in ws.has_what}
31713171
else:
31723172
deps = (ts.dependencies or set()).difference(ws.has_what)
3173-
nbytes = sum(dts.nbytes for dts in deps)
3173+
nbytes = sum(dts.get_nbytes() for dts in deps)
31743174
return nbytes / self.bandwidth
31753175

31763176
def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
@@ -3265,13 +3265,13 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
32653265
"""Objective function to determine which worker should get the task
32663266
32673267
Minimize expected start time. If a tie then break with data storage.
3268-
"""
3269-
comm_bytes = sum(
3270-
dts.get_nbytes() for dts in ts.dependencies if ws not in (dts.who_has or ())
3271-
)
32723268
3269+
See Also
3270+
--------
3271+
WorkStealing.stealing_objective
3272+
"""
32733273
stack_time = ws.occupancy / ws.nthreads
3274-
start_time = stack_time + comm_bytes / self.bandwidth
3274+
start_time = stack_time + self.get_comm_cost(ts, ws)
32753275

32763276
if ts.actor:
32773277
return (len(ws.actors), start_time, ws.nbytes)

distributed/stealing.py

+36-12
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def balance(self) -> None:
472472
stealable.discard(ts)
473473
continue
474474
i += 1
475-
if not (thief := _get_thief(s, ts, potential_thieves)):
475+
if not (thief := self._get_thief(s, ts, potential_thieves)):
476476
continue
477477

478478
occ_thief = self._combined_occupancy(thief)
@@ -552,18 +552,42 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
552552
out.append(t)
553553
return out
554554

555+
def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]:
556+
"""Objective function to determine which worker should get the task
555557
556-
def _get_thief(
557-
scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
558-
) -> WorkerState | None:
559-
valid_workers = scheduler.valid_workers(ts)
560-
if valid_workers is not None:
561-
valid_thieves = potential_thieves & valid_workers
562-
if valid_thieves:
563-
potential_thieves = valid_thieves
564-
elif not ts.loose_restrictions:
565-
return None
566-
return min(potential_thieves, key=partial(scheduler.worker_objective, ts))
558+
Minimize expected start time. If a tie then break with data storage.
559+
560+
Notes
561+
-----
562+
This method is a modified version of Scheduler.worker_objective that accounts
563+
for in-flight requests. It must be kept in sync with Scheduler.worker_objective for work-stealing to work correctly.
564+
565+
See Also
566+
--------
567+
Scheduler.worker_objective
568+
"""
569+
occupancy = self._combined_occupancy(
570+
ws
571+
) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
572+
if ts.actor:
573+
return (len(ws.actors), occupancy, ws.nbytes)
574+
else:
575+
return (occupancy, ws.nbytes)
576+
577+
def _get_thief(
578+
self,
579+
scheduler: SchedulerState,
580+
ts: TaskState,
581+
potential_thieves: set[WorkerState],
582+
) -> WorkerState | None:
583+
valid_workers = scheduler.valid_workers(ts)
584+
if valid_workers is not None:
585+
valid_thieves = potential_thieves & valid_workers
586+
if valid_thieves:
587+
potential_thieves = valid_thieves
588+
elif not ts.loose_restrictions:
589+
return None
590+
return min(potential_thieves, key=partial(self.stealing_objective, ts))
567591

568592

569593
fast_tasks = {

distributed/tests/test_steal.py

+62
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,68 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
19431943
assert len(events) == 0
19441944

19451945

1946+
@gen_cluster(
1947+
nthreads=[("", 1)],
1948+
client=True,
1949+
config={"distributed.scheduler.worker-saturation": "inf"},
1950+
)
1951+
async def test_stealing_objective_accounts_for_in_flight(c, s, a):
1952+
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
1953+
in_event = Event()
1954+
block_event = Event()
1955+
1956+
def block(i: int, in_event: Event, block_event: Event) -> int:
1957+
in_event.set()
1958+
block_event.wait()
1959+
return i
1960+
1961+
# Stop stealing for deterministic testing
1962+
extension = s.extensions["stealing"]
1963+
await extension.stop()
1964+
1965+
try:
1966+
futs = c.map(block, range(20), in_event=in_event, block_event=block_event)
1967+
await in_event.wait()
1968+
1969+
async with Worker(s.address, nthreads=1) as b:
1970+
try:
1971+
await async_poll_for(lambda: s.idle, timeout=5)
1972+
wsA = s.workers[a.address]
1973+
wsB = s.workers[b.address]
1974+
ts = next(iter(wsA.processing))
1975+
1976+
# No in-flight requests, so both match
1977+
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
1978+
ts, wsA
1979+
)
1980+
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
1981+
ts, wsB
1982+
)
1983+
1984+
extension.balance()
1985+
assert extension.in_flight
1986+
# We move tasks from a to b
1987+
assert extension.stealing_objective(ts, wsA) < s.worker_objective(
1988+
ts, wsA
1989+
)
1990+
assert extension.stealing_objective(ts, wsB) > s.worker_objective(
1991+
ts, wsB
1992+
)
1993+
1994+
await async_poll_for(lambda: not extension.in_flight, timeout=5)
1995+
# No in-flight requests, so both match
1996+
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
1997+
ts, wsA
1998+
)
1999+
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
2000+
ts, wsB
2001+
)
2002+
finally:
2003+
await block_event.set()
2004+
finally:
2005+
await block_event.set()
2006+
2007+
19462008
@gen_cluster(
19472009
nthreads=[("", 1)],
19482010
client=True,

0 commit comments

Comments
 (0)