diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index e6fb327856..d7e5aa32c8 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1674,9 +1674,6 @@ class SchedulerState:
     #: Subset of tasks that exist in memory on more than one worker
     replicated_tasks: set[TaskState]

-    #: Tasks with unknown duration, grouped by prefix
-    #: {task prefix: {ts, ts, ...}}
-    unknown_durations: dict[str, set[TaskState]]
     task_groups: dict[str, TaskGroup]
     task_prefixes: dict[str, TaskPrefix]
     task_metadata: dict[Key, Any]
@@ -1776,7 +1773,6 @@ def __init__(
         self.task_metadata = {}
         self.total_nthreads = 0
         self.total_nthreads_history = [(time(), 0)]
-        self.unknown_durations = {}
         self.queued = queued
         self.unrunnable = unrunnable
         self.validate = validate
@@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]:
             "unrunnable": self.unrunnable,
             "queued": self.queued,
             "n_tasks": self.n_tasks,
-            "unknown_durations": self.unknown_durations,
             "validate": self.validate,
             "tasks": self.tasks,
             "task_groups": self.task_groups,
@@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None:
             self.task_prefixes,
             self.task_groups,
             self.task_metadata,
-            self.unknown_durations,
             self.replicated_tasks,
         ):
             collection.clear()
@@ -1931,6 +1925,29 @@ def total_occupancy(self) -> float:
             self._network_occ_global,
         )

+    def _get_prefix_duration(self, prefix: TaskPrefix) -> float:
+        """Get the estimated computation cost of the given task prefix
+        (not including any communication cost).
+
+        If no data has been observed, the value of
+        `distributed.scheduler.default-task-durations` is used. If none is set
+        for this prefix, twice the longest observed execution time is used,
+        falling back to `distributed.scheduler.unknown-task-duration`.
+
+        See Also
+        --------
+        WorkStealing.steal_time_ratio
+        """
+        # TODO: Deal with unknown tasks better
+        assert prefix is not None
+        duration = prefix.duration_average
+        if duration < 0:
+            if prefix.max_exec_time > 0:
+                duration = 2 * prefix.max_exec_time
+            else:
+                duration = self.UNKNOWN_TASK_DURATION
+        return duration
+
     def _calc_occupancy(
         self,
         task_prefix_count: dict[str, int],
@@ -1938,15 +1955,7 @@ def _calc_occupancy(
     ) -> float:
         res = 0.0
         for prefix_name, count in task_prefix_count.items():
-            # TODO: Deal with unknown tasks better
-            prefix = self.task_prefixes[prefix_name]
-            assert prefix is not None
-            duration = prefix.duration_average
-            if duration < 0:
-                if prefix.max_exec_time > 0:
-                    duration = 2 * prefix.max_exec_time
-                else:
-                    duration = self.UNKNOWN_TASK_DURATION
+            duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
             res += duration * count
         occ = res + network_occ / self.bandwidth
         assert occ >= 0, (occ, res, network_occ, self.bandwidth)
@@ -2536,13 +2545,6 @@ def _transition_processing_memory(
                     action=startstop["action"],
                 )

-        s = self.unknown_durations.pop(ts.prefix.name, set())
-        steal = self.extensions.get("stealing")
-        if steal:
-            for tts in s:
-                if tts.processing_on:
-                    steal.recalculate_cost(tts)
-
         ############################
         # Update State Information #
         ############################
@@ -3171,26 +3173,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
         nbytes = sum(dts.nbytes for dts in deps)
         return nbytes / self.bandwidth

-    def get_task_duration(self, ts: TaskState) -> float:
-        """Get the estimated computation cost of the given task (not including
-        any communication cost).
-
-        If no data has been observed, value of
-        `distributed.scheduler.default-task-durations` are used. If none is set
-        for this task, `distributed.scheduler.unknown-task-duration` is used
-        instead.
-        """
-        prefix = ts.prefix
-        duration: float = prefix.duration_average
-        if duration >= 0:
-            return duration
-
-        s = self.unknown_durations.get(prefix.name)
-        if s is None:
-            self.unknown_durations[prefix.name] = s = set()
-        s.add(ts)
-        return self.UNKNOWN_TASK_DURATION
-
     def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
         """Return set of currently valid workers for key

@@ -3569,12 +3551,8 @@ def _client_releases_keys(
                     elif ts.state != "erred" and not ts.waiters:
                         recommendations[ts.key] = "released"

-    def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
+    def _task_to_msg(self, ts: TaskState) -> dict[str, Any]:
         """Convert a single computational task to a message"""
-        # FIXME: The duration attribute is not used on worker. We could save ourselves the
-        # time to compute and submit this
-        if duration < 0:
-            duration = self.get_task_duration(ts)
         ts.run_id = next(TaskState._run_id_iterator)
         assert ts.priority, ts
         msg: dict[str, Any] = {
@@ -3582,7 +3560,6 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
             "key": ts.key,
             "run_id": ts.run_id,
             "priority": ts.priority,
-            "duration": duration,
             "stimulus_id": f"compute-task-{time()}",
             "who_has": {
                 dts.key: tuple(ws.address for ws in (dts.who_has or ()))
@@ -6003,12 +5980,10 @@ async def remove_client_from_events() -> None:
             cleanup_delay, remove_client_from_events
         )

-    def send_task_to_worker(
-        self, worker: str, ts: TaskState, duration: float = -1
-    ) -> None:
+    def send_task_to_worker(self, worker: str, ts: TaskState) -> None:
         """Send a single computational task to a worker"""
         try:
-            msg = self._task_to_msg(ts, duration)
+            msg = self._task_to_msg(ts)
             self.worker_send(worker, msg)
         except Exception as e:
             logger.exception(e)
@@ -8859,10 +8834,7 @@ def adaptive_target(self, target_duration=None):
         queued = take(100, concat([self.queued, self.unrunnable.keys()]))
         queued_occupancy = 0
         for ts in queued:
-            if ts.prefix.duration_average == -1:
-                queued_occupancy += self.UNKNOWN_TASK_DURATION
-            else:
-                queued_occupancy += ts.prefix.duration_average
+            queued_occupancy += self._get_prefix_duration(ts.prefix)

         tasks_ready = len(self.queued) + len(self.unrunnable)
         if tasks_ready > 100:
diff --git a/distributed/stealing.py b/distributed/stealing.py
index 7e3711b8e2..cef9825530 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin):
     metrics: dict[str, dict[int, float]]
     _in_flight_event: asyncio.Event
     _request_counter: int
+    #: Tasks with unknown duration, grouped by prefix
+    #: {task prefix: {ts, ts, ...}}
+    unknown_durations: dict[str, set[TaskState]]

     def __init__(self, scheduler: Scheduler):
         self.scheduler = scheduler
@@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler):
         self.in_flight_occupancy = defaultdict(int)
         self.in_flight_tasks = defaultdict(int)
         self._in_flight_event = asyncio.Event()
+        self.unknown_durations = {}
         self.metrics = {
             "request_count_total": defaultdict(int),
             "request_cost_total": defaultdict(int),
@@ -188,6 +192,13 @@ def transition(
             ts = self.scheduler.tasks[key]
             self.remove_key_from_stealable(ts)
             self._remove_from_in_flight(ts)
+
+            if finish == "memory":
+                s = self.unknown_durations.pop(ts.prefix.name, set())
+                for tts in s:
+                    if tts.processing_on:
+                        self.recalculate_cost(tts)
+
         if finish == "processing":
             ts = self.scheduler.tasks[key]
             self.put_key_in_stealable(ts)
@@ -223,13 +234,27 @@ def recalculate_cost(self, ts: TaskState) -> None:

     def put_key_in_stealable(self, ts: TaskState) -> None:
         cost_multiplier, level = self.steal_time_ratio(ts)
-        if cost_multiplier is not None:
-            assert level is not None
-            assert ts.processing_on
-            ws = ts.processing_on
-            worker = ws.address
-            self.stealable[worker][level].add(ts)
-            self.key_stealable[ts] = (worker, level)
+
+        if cost_multiplier is None:
+            return
+
+        prefix = ts.prefix
+        duration = self.scheduler._get_prefix_duration(prefix)
+
+        assert level is not None
+        assert ts.processing_on
+        ws = ts.processing_on
+        worker = ws.address
+        self.stealable[worker][level].add(ts)
+        self.key_stealable[ts] = (worker, level)
+
+        if duration == ts.prefix.duration_average:
+            return
+
+        if prefix.name not in self.unknown_durations:
+            self.unknown_durations[prefix.name] = set()
+
+        self.unknown_durations[prefix.name].add(ts)

     def remove_key_from_stealable(self, ts: TaskState) -> None:
         result = self.key_stealable.pop(ts, None)
@@ -255,7 +280,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
         if not ts.dependencies:  # no dependencies fast path
             return 0, 0

-        compute_time = self.scheduler.get_task_duration(ts)
+        compute_time = self.scheduler._get_prefix_duration(ts.prefix)

         if not compute_time:
             # occupancy/ws.processing[ts] is only allowed to be zero for
@@ -301,12 +326,9 @@ def move_task_request(

             # TODO: occupancy no longer concats linearly so we can't easily
             # assume that the network cost would go down by that much
-            victim_duration = self.scheduler.get_task_duration(
-                ts
-            ) + self.scheduler.get_comm_cost(ts, victim)
-            thief_duration = self.scheduler.get_task_duration(
-                ts
-            ) + self.scheduler.get_comm_cost(ts, thief)
+            compute = self.scheduler._get_prefix_duration(ts.prefix)
+            victim_duration = compute + self.scheduler.get_comm_cost(ts, victim)
+            thief_duration = compute + self.scheduler.get_comm_cost(ts, thief)

             self.scheduler.stream_comms[victim.address].send(
                 {"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -457,8 +479,7 @@ def balance(self) -> None:
                     occ_victim = self._combined_occupancy(victim)
                     comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
                     comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
-                    compute = self.scheduler.get_task_duration(ts)
-
+                    compute = self.scheduler._get_prefix_duration(ts.prefix)
                     if (
                         occ_thief + comm_cost_thief + compute
                         <= occ_victim - (comm_cost_victim + compute) / 2
@@ -483,6 +504,8 @@ def balance(self) -> None:
                         occ_thief = self._combined_occupancy(thief)
                         nproc_thief = self._combined_nprocessing(thief)

+                        # FIXME: In the worst case, the victim may have 3x the amount of work
+                        # of the thief when this aborts balancing.
                         if not self.scheduler.is_unoccupied(
                             thief, occ_thief, nproc_thief
                         ):
@@ -514,6 +537,7 @@ def restart(self, scheduler: Any) -> None:
                 s.clear()

         self.key_stealable.clear()
+        self.unknown_durations.clear()

     def story(self, *keys_or_ts: str | TaskState) -> list:
         keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts}
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index e589ded793..f0af019e55 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -2788,24 +2788,26 @@ async def test_retire_workers_bad_params(c, s, a, b):
 @gen_cluster(
     client=True, config={"distributed.scheduler.default-task-durations": {"inc": 100}}
 )
-async def test_get_task_duration(c, s, a, b):
+async def test_get_prefix_duration(c, s, a, b):
     future = c.submit(inc, 1)
     await future
     assert 10 < s.task_prefixes["inc"].duration_average < 100

     ts_pref1 = s.new_task("inc-abcdefab", None, "released")
-    assert 10 < s.get_task_duration(ts_pref1) < 100
+    assert 10 < s._get_prefix_duration(ts_pref1.prefix) < 100

+    extension = s.extensions["stealing"]
     # make sure get_task_duration adds TaskStates to unknown dict
-    assert len(s.unknown_durations) == 0
+    assert len(extension.unknown_durations) == 0
     x = c.submit(slowinc, 1, delay=0.5)
     while len(s.tasks) < 3:
         await asyncio.sleep(0.01)
     ts = s.tasks[x.key]
-    assert s.get_task_duration(ts) == 0.5  # default
-    assert len(s.unknown_durations) == 1
-    assert len(s.unknown_durations["slowinc"]) == 1
+    assert s._get_prefix_duration(ts.prefix) == 0.5  # default
+
+    assert len(extension.unknown_durations) == 1
+    assert len(extension.unknown_durations["slowinc"]) == 1


 @gen_cluster(client=True)
@@ -3338,10 +3340,11 @@ async def test_unknown_task_duration_config(client, s, a, b):
     future = client.submit(slowinc, 1)
     while not s.tasks:
         await asyncio.sleep(0.001)
-    assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
-    assert len(s.unknown_durations) == 1
+    assert sum(s._get_prefix_duration(ts.prefix) for ts in s.tasks.values()) == 3600
+    extension = s.extensions["stealing"]
+    assert len(extension.unknown_durations) == 1
     await wait(future)
-    assert len(s.unknown_durations) == 0
+    assert len(extension.unknown_durations) == 0


 @gen_cluster()
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 3976857c9e..ee24fd907b 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -38,6 +38,7 @@ from distributed.utils_test import (
     NO_AMM,
     BlockedGetData,
+    async_poll_for,
     captured_logger,
     freeze_batched_send,
     gen_cluster,
@@ -277,7 +278,8 @@ async def test_eventually_steal_unknown_functions(c, s, a, b):
         slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True
     )
     await wait(futures)
-    assert not s.unknown_durations
+    extension = s.extensions["stealing"]
+    assert not extension.unknown_durations
     assert len(a.data) >= 3, [len(a.data), len(b.data)]
     assert len(b.data) >= 3, [len(a.data), len(b.data)]

@@ -1877,3 +1879,61 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
     await c.gather(futs)
     events = s.get_events("stealing")
     assert len(events) == 0
+
+
+@gen_cluster(
+    nthreads=[("", 1)],
+    client=True,
+    config={"distributed.scheduler.worker-saturation": "inf"},
+)
+async def test_do_not_ping_pong(c, s, a):
+    """Regression test that work-stealing does not continuously move all tasks between
+    two workers without reaching a stable state, eating up CPU time while doing so.
+    """
+    in_event = Event()
+    block_event = Event()
+
+    def block(i: int, in_event: Event, block_event: Event) -> int:
+        in_event.set()
+        block_event.wait()
+        return i
+
+    # Stop stealing for deterministic testing
+    extension = s.extensions["stealing"]
+    await extension.stop()
+
+    try:
+        futs = c.map(block, range(20), in_event=in_event, block_event=block_event)
+        await in_event.wait()
+
+        # This is the pre-condition for the observed problem:
+        # There are tasks that execute for a long time but do not have an average
+        s.task_prefixes["block"].add_exec_time(100)
+        assert s.task_prefixes["block"].duration_average == -1
+
+        async with Worker(s.address, nthreads=1) as b:
+            try:
+                await async_poll_for(lambda: s.idle, timeout=5)
+
+                wsB = s.workers[b.address]
+
+                extension.balance()
+                assert 10 >= len(extension.in_flight) >= 5
+                await async_poll_for(lambda: not extension.in_flight, timeout=5)
+                # On first try, we may try to balance the task executing on worker a
+                assert 10 >= len(wsB.processing) >= 5 - 1
+
+                extension.balance()
+                # On second try we may want to rebalance a single task if we failed to
+                # rebalance the task executing on worker a
+                assert len(extension.in_flight) <= 1
+                await async_poll_for(lambda: not extension.in_flight, timeout=5)
+                assert 10 >= len(wsB.processing) >= 5
+
+                # On third try, the balancing should be stable
+                extension.balance()
+                assert not extension.in_flight
+            finally:
+                await block_event.set()
+    finally:
+        await block_event.set()
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 1fe5ec58a0..145c211fc0 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -379,7 +379,6 @@ def f(arg):
         who_has={"y": ["w1"]},
         nbytes={"y": 123},
         priority=(0,),
-        duration=123.45,
         run_spec=(f, "arg", {}),
         resource_restrictions={},
         actor=False,
@@ -400,7 +399,6 @@ def f(arg):
         "nbytes": {"y": 123},
         "priority": [0],
         "run_spec": None,
-        "duration": 123.45,
         "resource_restrictions": {},
         "actor": False,
         "annotations": {},
@@ -422,7 +420,6 @@ def test_computetask_dummy():
         who_has={},
         nbytes={},
         priority=(0,),
-        duration=1.0,
         run_spec=ComputeTaskEvent.dummy_runspec("x"),
         resource_restrictions={},
         actor=False,
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 3507fb9379..5e34acd85e 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -236,8 +236,6 @@ class TaskState:
     #: The next state of the task. It is not None iff :attr:`state` == resumed.
     next: Literal["fetch", "waiting", None] = None

-    #: Expected duration of the task
-    duration: float | None = None
     #: The priority this task given by the scheduler. Determines run order.
     priority: tuple[int, ...] | None = None
     #: Addresses of workers that we believe have this data
@@ -736,7 +734,6 @@ class ComputeTaskEvent(StateMachineEvent):
     who_has: dict[Key, Collection[str]]
     nbytes: dict[Key, int]
     priority: tuple[int, ...]
-    duration: float
     run_spec: T_runspec | None
     resource_restrictions: dict[str, float]
     actor: bool
@@ -782,7 +779,6 @@ def dummy(
         who_has: dict[Key, Collection[str]] | None = None,
         nbytes: dict[Key, int] | None = None,
         priority: tuple[int, ...] = (0,),
-        duration: float = 1.0,
         resource_restrictions: dict[str, float] | None = None,
         actor: bool = False,
         annotations: dict | None = None,
@@ -797,7 +793,6 @@ def dummy(
             who_has=who_has or {},
             nbytes=nbytes or {k: 1 for k in who_has or ()},
             priority=priority,
-            duration=duration,
             run_spec=ComputeTaskEvent.dummy_runspec(key),
             resource_restrictions=resource_restrictions or {},
             actor=actor,
@@ -2863,7 +2858,6 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
         ts.exception_text = ""
         ts.traceback_text = ""
         ts.priority = priority
-        ts.duration = ev.duration
         ts.annotations = ev.annotations
         ts.span_id = ev.span_id
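For reference, a minimal standalone sketch of the estimation fallback that the new Scheduler._get_prefix_duration implements above and that occupancy and work stealing now share. PrefixStats and estimate_prefix_duration are hypothetical names used only for illustration (they stand in for TaskPrefix and the scheduler method), and the 0.5 s default assumes the stock distributed.scheduler.unknown-task-duration of 500ms:

    from dataclasses import dataclass


    @dataclass
    class PrefixStats:
        # Hypothetical stand-in for TaskPrefix: only the fields read by the patch above.
        duration_average: float = -1.0  # -1 means no completed task and no configured default
        max_exec_time: float = -1.0  # longest reported execution time of a still-running task


    def estimate_prefix_duration(prefix: PrefixStats, unknown_task_duration: float = 0.5) -> float:
        # Same fallback chain as Scheduler._get_prefix_duration in the diff above.
        if prefix.duration_average >= 0:
            return prefix.duration_average  # measured average or configured default
        if prefix.max_exec_time > 0:
            return 2 * prefix.max_exec_time  # pessimistic guess while a task is still running
        return unknown_task_duration  # distributed.scheduler.unknown-task-duration


    assert estimate_prefix_duration(PrefixStats()) == 0.5
    assert estimate_prefix_duration(PrefixStats(max_exec_time=10.0)) == 20.0
    assert estimate_prefix_duration(PrefixStats(duration_average=3.0)) == 3.0

Tasks whose estimate did not come from a measured average are the ones the WorkStealing extension now records in unknown_durations, so their placement in the stealable buckets is recomputed once a real duration is observed.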