From 68bdf8b137489820b2b3fdf4d9abb489cc9a3ce5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 30 Jan 2025 15:22:57 +0100 Subject: [PATCH 1/3] Add stealing objective --- distributed/scheduler.py | 8 ++------ distributed/stealing.py | 41 +++++++++++++++++++++++++++------------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e6fb3278561..eef940279fc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3168,7 +3168,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: deps = {dep for dep in ts.dependencies if dep not in ws.has_what} else: deps = (ts.dependencies or set()).difference(ws.has_what) - nbytes = sum(dts.nbytes for dts in deps) + nbytes = sum(dts.get_nbytes() for dts in deps) return nbytes / self.bandwidth def get_task_duration(self, ts: TaskState) -> float: @@ -3284,12 +3284,8 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: Minimize expected start time. If a tie then break with data storage. 
""" - comm_bytes = sum( - dts.get_nbytes() for dts in ts.dependencies if ws not in (dts.who_has or ()) - ) - stack_time = ws.occupancy / ws.nthreads - start_time = stack_time + comm_bytes / self.bandwidth + start_time = stack_time + self.get_comm_cost(ts, ws) if ts.actor: return (len(ws.actors), start_time, ws.nbytes) diff --git a/distributed/stealing.py b/distributed/stealing.py index 7e3711b8e2b..9934344c084 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -450,7 +450,7 @@ def balance(self) -> None: stealable.discard(ts) continue i += 1 - if not (thief := _get_thief(s, ts, potential_thieves)): + if not (thief := self._get_thief(s, ts, potential_thieves)): continue occ_thief = self._combined_occupancy(thief) @@ -528,18 +528,33 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - -def _get_thief( - scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] -) -> WorkerState | None: - valid_workers = scheduler.valid_workers(ts) - if valid_workers is not None: - valid_thieves = potential_thieves & valid_workers - if valid_thieves: - potential_thieves = valid_thieves - elif not ts.loose_restrictions: - return None - return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) + def stealing_objective( + self, scheduler: SchedulerState, ts: TaskState, ws: WorkerState + ) -> tuple[float, ...]: + occupancy = self._combined_occupancy( + ws + ) / ws.nthreads + scheduler.get_comm_cost(ts, ws) + if ts.actor: + return (len(ws.actors), occupancy, ws.nbytes) + else: + return (occupancy, ws.nbytes) + + def _get_thief( + self, + scheduler: SchedulerState, + ts: TaskState, + potential_thieves: set[WorkerState], + ) -> WorkerState | None: + valid_workers = scheduler.valid_workers(ts) + if valid_workers is not None: + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves + elif not ts.loose_restrictions: + return None + return min( + 
potential_thieves, key=partial(self.stealing_objective, scheduler, ts) + ) fast_tasks = { From 1e1fdf96b5452c64ac7877e06c74a4e30a3ff241 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Feb 2025 08:11:19 +0100 Subject: [PATCH 2/3] Comments and test --- distributed/scheduler.py | 4 +++ distributed/stealing.py | 19 ++++++++++-- distributed/tests/test_steal.py | 52 +++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index eef940279fc..2c0ff5ef8b1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3283,6 +3283,10 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: """Objective function to determine which worker should get the task Minimize expected start time. If a tie then break with data storage. + + See Also + -------- + WorkStealing.stealing_objective """ stack_time = ws.occupancy / ws.nthreads start_time = stack_time + self.get_comm_cost(ts, ws) diff --git a/distributed/stealing.py b/distributed/stealing.py index 9934344c084..f869a59cfbe 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -529,11 +529,24 @@ def story(self, *keys_or_ts: str | TaskState) -> list: return out def stealing_objective( - self, scheduler: SchedulerState, ts: TaskState, ws: WorkerState + self, ts: TaskState, ws: WorkerState ) -> tuple[float, ...]: + """Objective function to determine which worker should get the task + + Minimize expected start time. If a tie then break with data storage. + + Notes + ----- + This method is a modified version of Scheduler.worker_objective that accounts + for in-flight requests. It must be kept in sync for work-stealing to work correctly. 
+ + See Also + -------- + Scheduler.worker_objective + """ occupancy = self._combined_occupancy( ws - ) / ws.nthreads + scheduler.get_comm_cost(ts, ws) + ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws) if ts.actor: return (len(ws.actors), occupancy, ws.nbytes) else: @@ -553,7 +566,7 @@ def _get_thief( elif not ts.loose_restrictions: return None return min( - potential_thieves, key=partial(self.stealing_objective, scheduler, ts) + potential_thieves, key=partial(self.stealing_objective, ts) ) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3976857c9e7..4eb9e358f5c 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -38,6 +38,7 @@ from distributed.utils_test import ( NO_AMM, BlockedGetData, + async_poll_for, captured_logger, freeze_batched_send, gen_cluster, @@ -1877,3 +1878,54 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): await c.gather(futs) events = s.get_events("stealing") assert len(events) == 0 + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + config={"distributed.scheduler.worker-saturation": "inf"}, +) +async def test_stealing_objective_accounts_for_in_flight(c, s, a): + """Regression test that work-stealing's objective correctly accounts for in-flight data requests + """ + in_event = Event() + block_event = Event() + + def block(i: int, in_event: Event, block_event: Event) -> int: + in_event.set() + block_event.wait() + return i + + # Stop stealing for deterministic testing + extension = s.extensions["stealing"] + await extension.stop() + + try: + futs = c.map(block, range(20), in_event=in_event, block_event=block_event) + await in_event.wait() + + async with Worker(s.address, nthreads=1) as b: + try: + await async_poll_for(lambda: s.idle, timeout=5) + wsA = s.workers[a.address] + wsB = s.workers[b.address] + ts = next(iter(wsA.processing)) + + # No in-flight requests, so both match + assert extension.stealing_objective(ts, wsA) == 
s.worker_objective(ts, wsA) + assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) + + extension.balance() + assert extension.in_flight + # We move tasks from a to b + assert extension.stealing_objective(ts, wsA) < s.worker_objective(ts, wsA) + assert extension.stealing_objective(ts, wsB) > s.worker_objective(ts, wsB) + + await async_poll_for(lambda: not extension.in_flight, timeout=5) + # No in-flight requests, so both match + assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA) + assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) + finally: + await block_event.set() + finally: + await block_event.set() From 2827f6f6a5266e80fac289a82caf67b570f04334 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Feb 2025 08:18:42 +0100 Subject: [PATCH 3/3] Comments and test --- distributed/stealing.py | 8 ++------ distributed/tests/test_steal.py | 27 +++++++++++++++++++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index f869a59cfbe..30f81a09a1e 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -528,9 +528,7 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - def stealing_objective( - self, ts: TaskState, ws: WorkerState - ) -> tuple[float, ...]: + def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]: """Objective function to determine which worker should get the task Minimize expected start time. If a tie then break with data storage. 
@@ -565,9 +563,7 @@ def _get_thief( potential_thieves = valid_thieves elif not ts.loose_restrictions: return None - return min( - potential_thieves, key=partial(self.stealing_objective, ts) - ) + return min(potential_thieves, key=partial(self.stealing_objective, ts)) fast_tasks = { diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 4eb9e358f5c..780cea5d647 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1886,8 +1886,7 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): config={"distributed.scheduler.worker-saturation": "inf"}, ) async def test_stealing_objective_accounts_for_in_flight(c, s, a): - """Regression test that work-stealing's objective correctly accounts for in-flight data requests - """ + """Regression test that work-stealing's objective correctly accounts for in-flight data requests""" in_event = Event() block_event = Event() @@ -1912,19 +1911,31 @@ def block(i: int, in_event: Event, block_event: Event) -> int: ts = next(iter(wsA.processing)) # No in-flight requests, so both match - assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA) - assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) + assert extension.stealing_objective(ts, wsA) == s.worker_objective( + ts, wsA + ) + assert extension.stealing_objective(ts, wsB) == s.worker_objective( + ts, wsB + ) extension.balance() assert extension.in_flight # We move tasks from a to b - assert extension.stealing_objective(ts, wsA) < s.worker_objective(ts, wsA) - assert extension.stealing_objective(ts, wsB) > s.worker_objective(ts, wsB) + assert extension.stealing_objective(ts, wsA) < s.worker_objective( + ts, wsA + ) + assert extension.stealing_objective(ts, wsB) > s.worker_objective( + ts, wsB + ) await async_poll_for(lambda: not extension.in_flight, timeout=5) # No in-flight requests, so both match - assert extension.stealing_objective(ts, wsA) == 
s.worker_objective(ts, wsA) - assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) + assert extension.stealing_objective(ts, wsA) == s.worker_objective( + ts, wsA + ) + assert extension.stealing_objective(ts, wsB) == s.worker_objective( + ts, wsB + ) finally: await block_event.set() finally: