From 14ac4788db818a6b5215fcfa1c2a71d2977b96b1 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 15:04:43 +0100 Subject: [PATCH 01/10] Move task duration calculation to stealing --- distributed/scheduler.py | 46 ++---------------------- distributed/stealing.py | 54 ++++++++++++++++++++++++----- distributed/worker_state_machine.py | 6 ---- 3 files changed, 49 insertions(+), 57 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e6fb3278561..2eae1e0cec4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1674,9 +1674,6 @@ class SchedulerState: #: Subset of tasks that exist in memory on more than one worker replicated_tasks: set[TaskState] - #: Tasks with unknown duration, grouped by prefix - #: {task prefix: {ts, ts, ...}} - unknown_durations: dict[str, set[TaskState]] task_groups: dict[str, TaskGroup] task_prefixes: dict[str, TaskPrefix] task_metadata: dict[Key, Any] @@ -1776,7 +1773,6 @@ def __init__( self.task_metadata = {} self.total_nthreads = 0 self.total_nthreads_history = [(time(), 0)] - self.unknown_durations = {} self.queued = queued self.unrunnable = unrunnable self.validate = validate @@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]: "unrunnable": self.unrunnable, "queued": self.queued, "n_tasks": self.n_tasks, - "unknown_durations": self.unknown_durations, "validate": self.validate, "tasks": self.tasks, "task_groups": self.task_groups, @@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None: self.task_prefixes, self.task_groups, self.task_metadata, - self.unknown_durations, self.replicated_tasks, ): collection.clear() @@ -2536,13 +2530,6 @@ def _transition_processing_memory( action=startstop["action"], ) - s = self.unknown_durations.pop(ts.prefix.name, set()) - steal = self.extensions.get("stealing") - if steal: - for tts in s: - if tts.processing_on: - steal.recalculate_cost(tts) - ############################ # Update State Information # ############################ @@ 
-3171,26 +3158,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: nbytes = sum(dts.nbytes for dts in deps) return nbytes / self.bandwidth - def get_task_duration(self, ts: TaskState) -> float: - """Get the estimated computation cost of the given task (not including - any communication cost). - - If no data has been observed, value of - `distributed.scheduler.default-task-durations` are used. If none is set - for this task, `distributed.scheduler.unknown-task-duration` is used - instead. - """ - prefix = ts.prefix - duration: float = prefix.duration_average - if duration >= 0: - return duration - - s = self.unknown_durations.get(prefix.name) - if s is None: - self.unknown_durations[prefix.name] = s = set() - s.add(ts) - return self.UNKNOWN_TASK_DURATION - def valid_workers(self, ts: TaskState) -> set[WorkerState] | None: """Return set of currently valid workers for key @@ -3569,12 +3536,8 @@ def _client_releases_keys( elif ts.state != "erred" and not ts.waiters: recommendations[ts.key] = "released" - def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: + def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: """Convert a single computational task to a message""" - # FIXME: The duration attribute is not used on worker. 
We could save ourselves the - # time to compute and submit this - if duration < 0: - duration = self.get_task_duration(ts) ts.run_id = next(TaskState._run_id_iterator) assert ts.priority, ts msg: dict[str, Any] = { @@ -3582,7 +3545,6 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: "key": ts.key, "run_id": ts.run_id, "priority": ts.priority, - "duration": duration, "stimulus_id": f"compute-task-{time()}", "who_has": { dts.key: tuple(ws.address for ws in (dts.who_has or ())) @@ -6003,12 +5965,10 @@ async def remove_client_from_events() -> None: cleanup_delay, remove_client_from_events ) - def send_task_to_worker( - self, worker: str, ts: TaskState, duration: float = -1 - ) -> None: + def send_task_to_worker(self, worker: str, ts: TaskState) -> None: """Send a single computational task to a worker""" try: - msg = self._task_to_msg(ts, duration) + msg = self._task_to_msg(ts) self.worker_send(worker, msg) except Exception as e: logger.exception(e) diff --git a/distributed/stealing.py b/distributed/stealing.py index 7e3711b8e2b..6fbdf973260 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin): metrics: dict[str, dict[int, float]] _in_flight_event: asyncio.Event _request_counter: int + #: Tasks with unknown duration, grouped by prefix + #: {task prefix: {ts, ts, ...}} + unknown_durations: dict[str, set[TaskState]] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler): self.in_flight_occupancy = defaultdict(int) self.in_flight_tasks = defaultdict(int) self._in_flight_event = asyncio.Event() + self.unknown_durations = {} self.metrics = { "request_count_total": defaultdict(int), "request_cost_total": defaultdict(int), @@ -188,6 +192,13 @@ def transition( ts = self.scheduler.tasks[key] self.remove_key_from_stealable(ts) self._remove_from_in_flight(ts) + + if finish == "memory": + s = 
self.unknown_durations.pop(ts.prefix.name, set()) + for tts in s: + if tts.processing_on: + self.recalculate_cost(tts) + if finish == "processing": ts = self.scheduler.tasks[key] self.put_key_in_stealable(ts) @@ -255,7 +266,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non if not ts.dependencies: # no dependencies fast path return 0, 0 - compute_time = self.scheduler.get_task_duration(ts) + compute_time = self.get_task_duration(ts) if not compute_time: # occupancy/ws.processing[ts] is only allowed to be zero for @@ -301,12 +312,12 @@ def move_task_request( # TODO: occupancy no longer concats linearly so we can't easily # assume that the network cost would go down by that much - victim_duration = self.scheduler.get_task_duration( - ts - ) + self.scheduler.get_comm_cost(ts, victim) - thief_duration = self.scheduler.get_task_duration( - ts - ) + self.scheduler.get_comm_cost(ts, thief) + victim_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost( + ts, victim + ) + thief_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost( + ts, thief + ) self.scheduler.stream_comms[victim.address].send( {"op": "steal-request", "key": key, "stimulus_id": stimulus_id} @@ -457,7 +468,7 @@ def balance(self) -> None: occ_victim = self._combined_occupancy(victim) comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) - compute = self.scheduler.get_task_duration(ts) + compute = self.get_task_duration(ts) if ( occ_thief + comm_cost_thief + compute @@ -483,6 +494,8 @@ def balance(self) -> None: occ_thief = self._combined_occupancy(thief) nproc_thief = self._combined_nprocessing(thief) + # FIXME: In the worst case, the victim may have 3x the amount of work + # of the thief when this aborts balancing. 
if not self.scheduler.is_unoccupied( thief, occ_thief, nproc_thief ): @@ -514,6 +527,7 @@ def restart(self, scheduler: Any) -> None: s.clear() self.key_stealable.clear() + self.unknown_durations.clear() def story(self, *keys_or_ts: str | TaskState) -> list: keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts} @@ -541,6 +555,30 @@ def _get_thief( return None return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) + def get_task_duration(self, ts: TaskState) -> float: + """Get the estimated computation cost of the given task (not including + any communication cost). + + If no data has been observed, value of + `distributed.scheduler.default-task-durations` are used. If none is set + for this task, `distributed.scheduler.unknown-task-duration` is used + instead. + """ + prefix = ts.prefix + duration: float = prefix.duration_average + if duration >= 0: + return duration + if prefix.max_exec_time > 0: + duration = 2 * prefix.max_exec_time + else: + duration = self.scheduler.UNKNOWN_TASK_DURATION + + s = self.unknown_durations.get(prefix.name) + if s is None: + self.unknown_durations[prefix.name] = s = set() + s.add(ts) + return duration + fast_tasks = { k diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3507fb93797..5e34acd85e3 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -236,8 +236,6 @@ class TaskState: #: The next state of the task. It is not None iff :attr:`state` == resumed. next: Literal["fetch", "waiting", None] = None - #: Expected duration of the task - duration: float | None = None #: The priority this task given by the scheduler. Determines run order. priority: tuple[int, ...] | None = None #: Addresses of workers that we believe have this data @@ -736,7 +734,6 @@ class ComputeTaskEvent(StateMachineEvent): who_has: dict[Key, Collection[str]] nbytes: dict[Key, int] priority: tuple[int, ...] 
- duration: float run_spec: T_runspec | None resource_restrictions: dict[str, float] actor: bool @@ -782,7 +779,6 @@ def dummy( who_has: dict[Key, Collection[str]] | None = None, nbytes: dict[Key, int] | None = None, priority: tuple[int, ...] = (0,), - duration: float = 1.0, resource_restrictions: dict[str, float] | None = None, actor: bool = False, annotations: dict | None = None, @@ -797,7 +793,6 @@ def dummy( who_has=who_has or {}, nbytes=nbytes or {k: 1 for k in who_has or ()}, priority=priority, - duration=duration, run_spec=ComputeTaskEvent.dummy_runspec(key), resource_restrictions=resource_restrictions or {}, actor=actor, @@ -2863,7 +2858,6 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: ts.exception_text = "" ts.traceback_text = "" ts.priority = priority - ts.duration = ev.duration ts.annotations = ev.annotations ts.span_id = ev.span_id From 78256d3d4f82e8e288e1b5d863dc059c43bb8650 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 15:09:02 +0100 Subject: [PATCH 02/10] add test --- distributed/tests/test_steal.py | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3976857c9e7..23188a971d2 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -38,6 +38,7 @@ from distributed.utils_test import ( NO_AMM, BlockedGetData, + async_poll_for, captured_logger, freeze_batched_send, gen_cluster, @@ -1877,3 +1878,53 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): await c.gather(futs) events = s.get_events("stealing") assert len(events) == 0 + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + config={"distributed.scheduler.worker-saturation": "inf"}, +) +async def test_do_not_ping_pong(c, s, a): + in_event = Event() + block_event = Event() + + def block(i: int, in_event: Event, block_event: Event) -> int: + in_event.set() + block_event.wait() + return i + + 
extension = s.extensions["stealing"] + await extension.stop() + + try: + futs = c.map(block, range(20), in_event=in_event, block_event=block_event) + await in_event.wait() + s.task_prefixes["block"].add_exec_time(100) + + async with Worker(s.address, nthreads=1) as b: + try: + await async_poll_for(lambda: s.idle, timeout=5) + + wsB = s.workers[b.address] + + extension.balance() + assert 10 >= len(extension.in_flight) >= 5 + await async_poll_for(lambda: not extension.in_flight, timeout=5) + # On first try, we may try to balance the task executing on a + assert 10 >= len(wsB.processing) >= 5 - 1 + + extension.balance() + # On second try we may want to rebalance a single task if we failed to + # rebalance the task executing on a + assert len(extension.in_flight) <= 1 + await async_poll_for(lambda: not extension.in_flight, timeout=5) + assert 10 >= len(wsB.processing) >= 5 + + # On third try, the balancing should be stable + extension.balance() + assert not extension.in_flight + finally: + await block_event.set() + finally: + await block_event.set() From 5b886564defb44fff7a9cfc883d93103be843049 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 16:03:30 +0100 Subject: [PATCH 03/10] move --- distributed/stealing.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 6fbdf973260..003d2defc0f 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -542,19 +542,6 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - -def _get_thief( - scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] -) -> WorkerState | None: - valid_workers = scheduler.valid_workers(ts) - if valid_workers is not None: - valid_thieves = potential_thieves & valid_workers - if valid_thieves: - potential_thieves = valid_thieves - elif not ts.loose_restrictions: - return None - return min(potential_thieves, 
key=partial(scheduler.worker_objective, ts)) - def get_task_duration(self, ts: TaskState) -> float: """Get the estimated computation cost of the given task (not including any communication cost). @@ -580,6 +567,19 @@ def get_task_duration(self, ts: TaskState) -> float: return duration +def _get_thief( + scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] +) -> WorkerState | None: + valid_workers = scheduler.valid_workers(ts) + if valid_workers is not None: + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves + elif not ts.loose_restrictions: + return None + return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) + + fast_tasks = { k for k, v in dask.config.get("distributed.scheduler.default-task-durations").items() From 49b916d54205fddc20a3c23160f86472220268b4 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 16:12:57 +0100 Subject: [PATCH 04/10] comments --- distributed/tests/test_steal.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 23188a971d2..9b5cb8ffcec 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1886,6 +1886,9 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): config={"distributed.scheduler.worker-saturation": "inf"}, ) async def test_do_not_ping_pong(c, s, a): + """Regression test that work-stealing does not continuously move all tasks between + two workers without reaching a stable state, eating up CPU time while doing so.
+ """ in_event = Event() block_event = Event() @@ -1894,13 +1897,18 @@ def block(i: int, in_event: Event, block_event: Event) -> int: block_event.wait() return i + # Stop stealing for deterministic testing extension = s.extensions["stealing"] await extension.stop() try: futs = c.map(block, range(20), in_event=in_event, block_event=block_event) await in_event.wait() + + # This is the pre-condition for the observed problem: + # There are tasks that execute for a long time but do not have an average s.task_prefixes["block"].add_exec_time(100) + assert s.task_prefixes["block"].duration_average == -1 async with Worker(s.address, nthreads=1) as b: try: From 54fdd21e3edf26b50b9fa4ad624d34e3bff72868 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 17:38:54 +0100 Subject: [PATCH 05/10] Refactor for maintainability --- distributed/scheduler.py | 33 ++++++++++++++------ distributed/stealing.py | 65 ++++++++++++++++------------------------ 2 files changed, 49 insertions(+), 49 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2eae1e0cec4..58889798652 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1925,6 +1925,29 @@ def total_occupancy(self) -> float: self._network_occ_global, ) + def _get_prefix_duration(self, prefix: TaskPrefix) -> float: + """Get the estimated computation cost of the given task prefix + (not including any communication cost). + + If no data has been observed, value of + `distributed.scheduler.default-task-durations` are used. If none is set + for this task, `distributed.scheduler.unknown-task-duration` is used + instead.
+ + See Also + -------- + WorkStealing.get_task_duration + """ + # TODO: Deal with unknown tasks better + assert prefix is not None + duration = prefix.duration_average + if duration < 0: + if prefix.max_exec_time > 0: + duration = 2 * prefix.max_exec_time + else: + duration = self.UNKNOWN_TASK_DURATION + return duration + def _calc_occupancy( self, task_prefix_count: dict[str, int], @@ -1932,15 +1955,7 @@ def _calc_occupancy( ) -> float: res = 0.0 for prefix_name, count in task_prefix_count.items(): - # TODO: Deal with unknown tasks better - prefix = self.task_prefixes[prefix_name] - assert prefix is not None - duration = prefix.duration_average - if duration < 0: - if prefix.max_exec_time > 0: - duration = 2 * prefix.max_exec_time - else: - duration = self.UNKNOWN_TASK_DURATION + duration = self._get_prefix_duration(self.task_prefixes[prefix_name]) res += duration * count occ = res + network_occ / self.bandwidth assert occ >= 0, (occ, res, network_occ, self.bandwidth) diff --git a/distributed/stealing.py b/distributed/stealing.py index 003d2defc0f..e06a94534eb 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -234,13 +234,26 @@ def recalculate_cost(self, ts: TaskState) -> None: def put_key_in_stealable(self, ts: TaskState) -> None: cost_multiplier, level = self.steal_time_ratio(ts) - if cost_multiplier is not None: - assert level is not None - assert ts.processing_on - ws = ts.processing_on - worker = ws.address - self.stealable[worker][level].add(ts) - self.key_stealable[ts] = (worker, level) + + prefix = ts.prefix + duration = self.scheduler._get_prefix_duration(prefix) + if cost_multiplier is None: + return + + assert level is not None + assert ts.processing_on + ws = ts.processing_on + worker = ws.address + self.stealable[worker][level].add(ts) + self.key_stealable[ts] = (worker, level) + + if duration == ts.prefix.duration_average: + return + + s = self.unknown_durations.get(prefix.name) + if s is None: + 
self.unknown_durations[prefix.name] = s = set() + s.add(ts) def remove_key_from_stealable(self, ts: TaskState) -> None: result = self.key_stealable.pop(ts, None) @@ -266,7 +279,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non if not ts.dependencies: # no dependencies fast path return 0, 0 - compute_time = self.get_task_duration(ts) + compute_time = self.scheduler._get_prefix_duration(ts.prefix) if not compute_time: # occupancy/ws.processing[ts] is only allowed to be zero for @@ -312,12 +325,9 @@ def move_task_request( # TODO: occupancy no longer concats linearly so we can't easily # assume that the network cost would go down by that much - victim_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost( - ts, victim - ) - thief_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost( - ts, thief - ) + compute = self.scheduler._get_prefix_duration(ts.prefix) + victim_duration = compute + self.scheduler.get_comm_cost(ts, victim) + thief_duration = compute + self.scheduler.get_comm_cost(ts, thief) self.scheduler.stream_comms[victim.address].send( {"op": "steal-request", "key": key, "stimulus_id": stimulus_id} @@ -468,8 +478,7 @@ def balance(self) -> None: occ_victim = self._combined_occupancy(victim) comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) - compute = self.get_task_duration(ts) - + compute = self.scheduler._get_prefix_duration(ts.prefix) if ( occ_thief + comm_cost_thief + compute <= occ_victim - (comm_cost_victim + compute) / 2 @@ -542,30 +551,6 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - def get_task_duration(self, ts: TaskState) -> float: - """Get the estimated computation cost of the given task (not including - any communication cost). - - If no data has been observed, value of - `distributed.scheduler.default-task-durations` are used. 
If none is set - for this task, `distributed.scheduler.unknown-task-duration` is used - instead. - """ - prefix = ts.prefix - duration: float = prefix.duration_average - if duration >= 0: - return duration - if prefix.max_exec_time > 0: - duration = 2 * prefix.max_exec_time - else: - duration = self.scheduler.UNKNOWN_TASK_DURATION - - s = self.unknown_durations.get(prefix.name) - if s is None: - self.unknown_durations[prefix.name] = s = set() - s.add(ts) - return duration - def _get_thief( scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] From 102dc19cd8a8debe2dada1686aac7aa64ff89dd7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 17:46:07 +0100 Subject: [PATCH 06/10] Fix tests --- distributed/tests/test_scheduler.py | 21 +++++++++++-------- distributed/tests/test_steal.py | 3 ++- .../tests/test_worker_state_machine.py | 2 -- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e589ded7937..f0af019e55d 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2788,24 +2788,26 @@ async def test_retire_workers_bad_params(c, s, a, b): @gen_cluster( client=True, config={"distributed.scheduler.default-task-durations": {"inc": 100}} ) -async def test_get_task_duration(c, s, a, b): +async def test_get_prefix_duration(c, s, a, b): future = c.submit(inc, 1) await future assert 10 < s.task_prefixes["inc"].duration_average < 100 ts_pref1 = s.new_task("inc-abcdefab", None, "released") - assert 10 < s.get_task_duration(ts_pref1) < 100 + assert 10 < s._get_prefix_duration(ts_pref1.prefix) < 100 + extension = s.extensions["stealing"] # make sure get_task_duration adds TaskStates to unknown dict - assert len(s.unknown_durations) == 0 + assert len(extension.unknown_durations) == 0 x = c.submit(slowinc, 1, delay=0.5) while len(s.tasks) < 3: await asyncio.sleep(0.01) ts = s.tasks[x.key] - assert 
s.get_task_duration(ts) == 0.5 # default - assert len(s.unknown_durations) == 1 - assert len(s.unknown_durations["slowinc"]) == 1 + assert s._get_prefix_duration(ts.prefix) == 0.5 # default + + assert len(extension.unknown_durations) == 1 + assert len(extension.unknown_durations["slowinc"]) == 1 @gen_cluster(client=True) @@ -3338,10 +3340,11 @@ async def test_unknown_task_duration_config(client, s, a, b): future = client.submit(slowinc, 1) while not s.tasks: await asyncio.sleep(0.001) - assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600 - assert len(s.unknown_durations) == 1 + assert sum(s._get_prefix_duration(ts.prefix) for ts in s.tasks.values()) == 3600 + extension = s.extensions["stealing"] + assert len(extension.unknown_durations) == 1 await wait(future) - assert len(s.unknown_durations) == 0 + assert len(extension.unknown_durations) == 0 @gen_cluster() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 9b5cb8ffcec..ee24fd907bd 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -278,7 +278,8 @@ async def test_eventually_steal_unknown_functions(c, s, a, b): slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True ) await wait(futures) - assert not s.unknown_durations + extension = s.extensions["stealing"] + assert not extension.unknown_durations assert len(a.data) >= 3, [len(a.data), len(b.data)] assert len(b.data) >= 3, [len(a.data), len(b.data)] diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1fe5ec58a07..4d44c7922a6 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -379,7 +379,6 @@ def f(arg): who_has={"y": ["w1"]}, nbytes={"y": 123}, priority=(0,), - duration=123.45, run_spec=(f, "arg", {}), resource_restrictions={}, actor=False, @@ -422,7 +421,6 @@ def test_computetask_dummy(): who_has={}, nbytes={}, priority=(0,), - 
duration=1.0, run_spec=ComputeTaskEvent.dummy_runspec("x"), resource_restrictions={}, actor=False, From fb8d8eca00c0e351fe59b2ebaf457f9c5db0e173 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Feb 2025 17:50:31 +0100 Subject: [PATCH 07/10] Align adaptive target --- distributed/scheduler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 58889798652..d7e5aa32c88 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8834,10 +8834,7 @@ def adaptive_target(self, target_duration=None): queued = take(100, concat([self.queued, self.unrunnable.keys()])) queued_occupancy = 0 for ts in queued: - if ts.prefix.duration_average == -1: - queued_occupancy += self.UNKNOWN_TASK_DURATION - else: - queued_occupancy += ts.prefix.duration_average + queued_occupancy += self._get_prefix_duration(ts.prefix) tasks_ready = len(self.queued) + len(self.unrunnable) if tasks_ready > 100: From f1243a3c450f7e496b0f8aaa93af2cf3ec0eb773 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Feb 2025 07:17:52 +0100 Subject: [PATCH 08/10] Trigger CI From 2bd934ee84213b1e0c4bdaa8412d8020e69a641f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Feb 2025 08:16:10 +0100 Subject: [PATCH 09/10] fix test --- distributed/tests/test_worker_state_machine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 4d44c7922a6..145c211fc0b 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -399,7 +399,6 @@ def f(arg): "nbytes": {"y": 123}, "priority": [0], "run_spec": None, - "duration": 123.45, "resource_restrictions": {}, "actor": False, "annotations": {}, From 40aa62a350c4b0bf9ab73e75ad64cafc53d81f49 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Feb 2025 13:00:35 +0100 Subject: [PATCH 10/10] PR review --- 
distributed/stealing.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index e06a94534eb..cef98255303 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -235,11 +235,12 @@ def recalculate_cost(self, ts: TaskState) -> None: def put_key_in_stealable(self, ts: TaskState) -> None: cost_multiplier, level = self.steal_time_ratio(ts) - prefix = ts.prefix - duration = self.scheduler._get_prefix_duration(prefix) if cost_multiplier is None: return + prefix = ts.prefix + duration = self.scheduler._get_prefix_duration(prefix) + assert level is not None assert ts.processing_on ws = ts.processing_on @@ -250,10 +251,10 @@ def put_key_in_stealable(self, ts: TaskState) -> None: if duration == ts.prefix.duration_average: return - s = self.unknown_durations.get(prefix.name) - if s is None: - self.unknown_durations[prefix.name] = s = set() - s.add(ts) + if prefix.name not in self.unknown_durations: + self.unknown_durations[prefix.name] = set() + + self.unknown_durations[prefix.name].add(ts) def remove_key_from_stealable(self, ts: TaskState) -> None: result = self.key_stealable.pop(ts, None)