Homogeneous balancing by accounting for in-flight requests #9003

Merged · 4 commits · Feb 5, 2025
12 changes: 6 additions & 6 deletions distributed/scheduler.py
@@ -3170,7 +3170,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
             deps = {dep for dep in ts.dependencies if dep not in ws.has_what}
         else:
             deps = (ts.dependencies or set()).difference(ws.has_what)
-        nbytes = sum(dts.nbytes for dts in deps)
+        nbytes = sum(dts.get_nbytes() for dts in deps)
         return nbytes / self.bandwidth
 
     def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
@@ -3265,13 +3265,13 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
         """Objective function to determine which worker should get the task
 
         Minimize expected start time. If a tie then break with data storage.
-        """
-        comm_bytes = sum(
-            dts.get_nbytes() for dts in ts.dependencies if ws not in (dts.who_has or ())
-        )
 
+        See Also
+        --------
+        WorkStealing.stealing_objective
+        """
         stack_time = ws.occupancy / ws.nthreads
-        start_time = stack_time + comm_bytes / self.bandwidth
+        start_time = stack_time + self.get_comm_cost(ts, ws)
 
         if ts.actor:
             return (len(ws.actors), start_time, ws.nbytes)
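
For intuition: after this change, the objective estimates a task's start time on a worker as queued-occupancy-per-thread plus the time to transfer whichever dependencies that worker is missing, with get_nbytes() falling back to a default size estimate for dependencies whose actual size is not yet known. Below is a minimal sketch of that calculation, using hypothetical stand-in values rather than real TaskState/WorkerState objects:

# Sketch of the expected-start-time estimate behind worker_objective.
# All inputs are hypothetical stand-ins, not the scheduler's real state.

def estimated_start_time(
    occupancy: float,        # seconds of work already queued on the worker
    nthreads: int,           # worker thread count
    missing_dep_bytes: int,  # bytes of dependencies the worker does not hold
    bandwidth: float,        # scheduler's bandwidth estimate (bytes/s)
) -> float:
    stack_time = occupancy / nthreads          # when a thread frees up
    comm_time = missing_dep_bytes / bandwidth  # stands in for get_comm_cost(ts, ws)
    return stack_time + comm_time

# A worker with 2s of queued work on 4 threads that must fetch 100 MB
# over a 1 GB/s link is expected to start the task after ~0.6s:
print(estimated_start_time(2.0, 4, 100_000_000, 1_000_000_000))  # 0.6

Ties on start time are then broken by ws.nbytes, preferring the worker that stores less data.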
48 changes: 36 additions & 12 deletions distributed/stealing.py
@@ -472,7 +472,7 @@
                     stealable.discard(ts)
                     continue
                 i += 1
-                if not (thief := _get_thief(s, ts, potential_thieves)):
+                if not (thief := self._get_thief(s, ts, potential_thieves)):
                     continue
 
                 occ_thief = self._combined_occupancy(thief)
@@ -552,18 +552,42 @@
                 out.append(t)
         return out
 
-
-def _get_thief(
-    scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
-) -> WorkerState | None:
-    valid_workers = scheduler.valid_workers(ts)
-    if valid_workers is not None:
-        valid_thieves = potential_thieves & valid_workers
-        if valid_thieves:
-            potential_thieves = valid_thieves
-        elif not ts.loose_restrictions:
-            return None
-    return min(potential_thieves, key=partial(scheduler.worker_objective, ts))
+    def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]:
+        """Objective function to determine which worker should get the task
+
+        Minimize expected start time. If a tie then break with data storage.
+
+        Notes
+        -----
+        This method is a modified version of Scheduler.worker_objective that accounts
+        for in-flight requests. It must be kept in sync for work-stealing to work correctly.
+
+        See Also
+        --------
+        Scheduler.worker_objective
+        """
+        occupancy = self._combined_occupancy(
+            ws
+        ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
+        if ts.actor:
+            return (len(ws.actors), occupancy, ws.nbytes)
+        else:
+            return (occupancy, ws.nbytes)
+
+    def _get_thief(
+        self,
+        scheduler: SchedulerState,
+        ts: TaskState,
+        potential_thieves: set[WorkerState],
+    ) -> WorkerState | None:
+        valid_workers = scheduler.valid_workers(ts)
+        if valid_workers is not None:
+            valid_thieves = potential_thieves & valid_workers
+            if valid_thieves:
+                potential_thieves = valid_thieves
+            elif not ts.loose_restrictions:
+                return None
+        return min(potential_thieves, key=partial(self.stealing_objective, ts))
 
 
 fast_tasks = {

Codecov / codecov/patch: added line 573 in distributed/stealing.py (the actor branch of stealing_objective) was not covered by tests.
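
The occupancy term is the crux: _combined_occupancy folds the occupancy deltas of steals that are still in flight into the worker's plain occupancy, so a victim already shedding tasks looks less busy to the stealing objective while a thief already receiving them looks busier. A simplified model of that accounting, assuming the extension tracks per-worker deltas in an in_flight_occupancy mapping (names and structure here are illustrative, not the extension's exact internals):

from collections import defaultdict

# Simplified model of the stealing extension's occupancy accounting:
# negative deltas for victims shedding tasks, positive for thieves.
in_flight_occupancy: defaultdict[str, float] = defaultdict(float)

def combined_occupancy(worker: str, occupancy: dict[str, float]) -> float:
    return occupancy[worker] + in_flight_occupancy[worker]

occupancy = {"a": 10.0, "b": 0.0}  # seconds of queued work per worker

# Steal a 1-second task from "a" to "b": record the in-flight move.
in_flight_occupancy["a"] -= 1.0
in_flight_occupancy["b"] += 1.0

# Until the move completes, "a" is treated as less busy and "b" as busier.
print(combined_occupancy("a", occupancy))  # 9.0
print(combined_occupancy("b", occupancy))  # 1.0

This keeps repeated balance() passes from piling every steal onto the same thief before the first move has completed.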
62 changes: 62 additions & 0 deletions distributed/tests/test_steal.py
@@ -1943,6 +1943,68 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
     assert len(events) == 0
 
 
+@gen_cluster(
+    nthreads=[("", 1)],
+    client=True,
+    config={"distributed.scheduler.worker-saturation": "inf"},
+)
+async def test_stealing_objective_accounts_for_in_flight(c, s, a):
+    """Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
+    in_event = Event()
+    block_event = Event()
+
+    def block(i: int, in_event: Event, block_event: Event) -> int:
+        in_event.set()
+        block_event.wait()
+        return i
+
+    # Stop stealing for deterministic testing
+    extension = s.extensions["stealing"]
+    await extension.stop()
+
+    try:
+        futs = c.map(block, range(20), in_event=in_event, block_event=block_event)
+        await in_event.wait()
+
+        async with Worker(s.address, nthreads=1) as b:
+            try:
+                await async_poll_for(lambda: s.idle, timeout=5)
+                wsA = s.workers[a.address]
+                wsB = s.workers[b.address]
+                ts = next(iter(wsA.processing))
+
+                # No in-flight requests, so both match
+                assert extension.stealing_objective(ts, wsA) == s.worker_objective(
+                    ts, wsA
+                )
+                assert extension.stealing_objective(ts, wsB) == s.worker_objective(
+                    ts, wsB
+                )
+
+                extension.balance()
+                assert extension.in_flight
+                # We move tasks from a to b
+                assert extension.stealing_objective(ts, wsA) < s.worker_objective(
+                    ts, wsA
+                )
+                assert extension.stealing_objective(ts, wsB) > s.worker_objective(
+                    ts, wsB
+                )
+
+                await async_poll_for(lambda: not extension.in_flight, timeout=5)
+                # No in-flight requests, so both match
+                assert extension.stealing_objective(ts, wsA) == s.worker_objective(
+                    ts, wsA
+                )
+                assert extension.stealing_objective(ts, wsB) == s.worker_objective(
+                    ts, wsB
+                )
+            finally:
+                await block_event.set()
+    finally:
+        await block_event.set()
+
+
 @gen_cluster(
     nthreads=[("", 1)],
     client=True,
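
Why the middle block of assertions holds: balance() records each pending move in the extension's in-flight accounting before the stolen tasks actually arrive, crediting the victim a with a negative occupancy delta and the thief b with a positive one. A hypothetical numeric illustration (values invented for this note, not taken from the test run):

# Invented numbers; nthreads=1 on both workers, comm cost ignored.
occupancy_a, delta_a = 20.0, -5.0  # victim: 5s of tasks in flight away
occupancy_b, delta_b = 0.0, 5.0    # thief: those 5s are inbound

# Scheduler.worker_objective sees plain occupancy; the stealing
# extension sees occupancy plus the in-flight delta.
assert occupancy_a + delta_a < occupancy_a  # stealing_objective(ts, wsA) < worker_objective(ts, wsA)
assert occupancy_b + delta_b > occupancy_b  # stealing_objective(ts, wsB) > worker_objective(ts, wsB)

Once the moves complete, the deltas are cleared and the two objectives agree again, which is exactly what the final block of assertions checks.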