diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e6fb327856..5944587309 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6057,26 +6057,47 @@ def release_worker_data(self, key: Key, worker: str, stimulus_id: str) -> None: self.transitions({key: "released"}, stimulus_id) def handle_long_running( - self, key: Key, worker: str, compute_duration: float | None, stimulus_id: str + self, + key: Key, + worker: str, + run_id: int, + compute_duration: float | None, + stimulus_id: str, ) -> None: """A task has seceded from the thread pool We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. """ + if worker not in self.workers: + logger.debug( + "Received long-running signal from unknown worker %s. Ignoring.", worker + ) + return + if key not in self.tasks: logger.debug("Skipping long_running since key %s was already released", key) return + ts = self.tasks[key] - steal = self.extensions.get("stealing") - if steal is not None: - steal.remove_key_from_stealable(ts) ws = ts.processing_on if ws is None: logger.debug("Received long-running signal from duplicate task. Ignoring.") return + if ws.address != worker or ts.run_id != run_id: + logger.debug( + "Received stale long-running signal from worker %s for task %s. Ignoring.", + worker, + ts, + ) + return + + steal = self.extensions.get("stealing") + if steal is not None: + steal.remove_key_from_stealable(ts) + if compute_duration is not None: old_duration = ts.prefix.duration_average if old_duration < 0: diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index ece5ef3fdc..a90d526fef 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import pytest @@ -14,6 +15,7 @@ _LockedCommPool, assert_story, async_poll_for, + captured_logger, freeze_batched_send, gen_cluster, inc, @@ -903,13 +905,15 @@ def test_workerstate_executing_to_executing(ws_with_running_task): instructions = ws.handle_stimulus( FreeKeysEvent(keys=["x"], stimulus_id="s1"), - ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"), + ComputeTaskEvent.dummy( + "x", run_id=0, resource_restrictions={"R": 1}, stimulus_id="s2" + ), ) if prev_state == "executing": assert not instructions else: assert instructions == [ - LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2") + LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2") ] assert ws.tasks["x"] is ts assert ts.state == prev_state @@ -1087,15 +1091,17 @@ def test_workerstate_resumed_fetch_to_cancelled_to_executing(ws_with_running_tas instructions = ws.handle_stimulus( FreeKeysEvent(keys=["x"], stimulus_id="s1"), - ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"), + ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"), FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"), - ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"), + ComputeTaskEvent.dummy( + "x", run_id=1, resource_restrictions={"R": 1}, stimulus_id="s4" + ), ) if prev_state == "executing": assert not instructions else: assert instructions == [ - LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4") + LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4") ] assert ws.tasks["x"].state == prev_state @@ -1111,16 +1117,16 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task): # x is released for whatever reason (e.g. client cancellation) FreeKeysEvent(keys=["x"], stimulus_id="s1"), # x was computed somewhere else - ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"), + ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"), # x was lost / no known replicas, therefore y is cancelled FreeKeysEvent(keys=["y"], stimulus_id="s3"), - ComputeTaskEvent.dummy("x", stimulus_id="s4"), + ComputeTaskEvent.dummy("x", run_id=1, stimulus_id="s4"), ) if prev_state == "executing": assert not instructions else: assert instructions == [ - LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4") + LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4") ] assert len(ws.tasks) == 1 assert ws.tasks["x"].state == prev_state @@ -1254,12 +1260,14 @@ def test_secede_cancelled_or_resumed_workerstate( """ ws2 = "127.0.0.1:2" ws.handle_stimulus( - ComputeTaskEvent.dummy("x", stimulus_id="s1"), + ComputeTaskEvent.dummy("x", run_id=0, stimulus_id="s1"), FreeKeysEvent(keys=["x"], stimulus_id="s2"), ) if resume_to_fetch: ws.handle_stimulus( - ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"), + ComputeTaskEvent.dummy( + "y", run_id=1, who_has={"x": [ws2]}, stimulus_id="s3" + ), ) ts = ws.tasks["x"] assert ts.previous == "executing" @@ -1277,11 +1285,11 @@ def test_secede_cancelled_or_resumed_workerstate( if resume_to_executing: instructions = ws.handle_stimulus( FreeKeysEvent(keys=["y"], stimulus_id="s5"), - ComputeTaskEvent.dummy("x", stimulus_id="s6"), + ComputeTaskEvent.dummy("x", run_id=2, stimulus_id="s6"), ) # Inform the scheduler of the SecedeEvent that happened in the past assert instructions == [ - LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6") + LongRunningMsg(key="x", run_id=2, compute_duration=None, stimulus_id="s6") ] assert ts.state == "long-running" assert ts not in ws.executing @@ -1292,6 +1300,223 @@ def test_secede_cancelled_or_resumed_workerstate( assert ts not in ws.long_running +@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], timeout=2) +async def test_secede_racing_cancellation_and_scheduling_on_other_worker(c, s, a, b): + """Regression test that ensures that we handle stale long-running messages correctly. + + This tests simulates a race condition where a task secedes on worker a, the task is then cancelled, and resubmitted to + run on worker b. The long-running message created on a only arrives on the scheduler after the task started executing on b + (but before a secede event arrives from worker b). The scheduler should then ignore the stale secede event from a. + """ + wsA = s.workers[a.address] + before_secede = Event() + block_secede = Event() + block_long_running = Event() + handled_long_running = Event() + + def f(before_secede, block_secede, block_long_running): + before_secede.set() + block_secede.wait() + distributed.secede() + block_long_running.wait() + return 123 + + # Instrument long-running handler + original_handler = s.stream_handlers["long-running"] + + async def instrumented_handle_long_running(*args, **kwargs): + try: + return original_handler(*args, **kwargs) + finally: + await handled_long_running.set() + + s.stream_handlers["long-running"] = instrumented_handle_long_running + + # Submit task and wait until it executes on a + x = c.submit( + f, + before_secede, + block_secede, + block_long_running, + key="x", + workers=[a.address], + ) + await before_secede.wait() + + # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails. + with captured_logger("distributed.scheduler", logging.ERROR) as caplog: + with freeze_batched_send(a.batched_stream): + # Let x secede (and later succeed) without informing the scheduler + await block_secede.set() + await wait_for_state("x", "long-running", a) + assert not a.state.executing + assert a.state.long_running + await block_long_running.set() + + await wait_for_state("x", "memory", a) + + # Cancel x while the scheduler does not know that it seceded + x.release() + await async_poll_for(lambda: not s.tasks, timeout=5) + assert not wsA.processing + assert not wsA.long_running + + # Reset all events + await before_secede.clear() + await block_secede.clear() + await block_long_running.clear() + + # Resubmit task and wait until it executes on b + x = c.submit( + f, + before_secede, + block_secede, + block_long_running, + key="x", + workers=[b.address], + ) + await before_secede.wait() + wsB = s.workers[b.address] + assert wsB.processing + assert not wsB.long_running + + # Unblock the stream from a to the scheduler and handle the long-running message + await handled_long_running.wait() + ts = b.state.tasks["x"] + assert ts.state == "executing" + + assert wsB.processing + assert wsB.task_prefix_count + assert not wsB.long_running + + assert not wsA.processing + assert not wsA.task_prefix_count + assert not wsA.long_running + + # Clear the handler and let x secede on b + await handled_long_running.clear() + + await block_secede.set() + await wait_for_state("x", "long-running", b) + + assert not b.state.executing + assert b.state.long_running + await handled_long_running.wait() + + # Assert that the handler did not fail and no state was corrupted + logs = caplog.getvalue() + assert not logs + assert not wsB.task_prefix_count + + await block_long_running.set() + assert await x.result() == 123 + + +@gen_cluster(client=True, nthreads=[("", 1)], timeout=2) +async def test_secede_racing_resuming_on_same_worker(c, s, a): + """Regression test that ensures that we handle stale long-running messages correctly. + + This tests simulates a race condition where a task secedes on worker a, the task is then cancelled, and resumed on + worker a. The first long-running message created on a only arrives on the scheduler after the task was resumed. + The scheduler should then ignore the stale first secede event from a and only handle the second one. + """ + wsA = s.workers[a.address] + before_secede = Event() + block_secede = Event() + block_long_running = Event() + handled_long_running = Event() + block_long_running_handler = Event() + + def f(before_secede, block_secede, block_long_running): + before_secede.set() + block_secede.wait() + distributed.secede() + block_long_running.wait() + return 123 + + # Instrument long-running handler + original_handler = s.stream_handlers["long-running"] + block_second_attempt = None + + async def instrumented_handle_long_running(*args, **kwargs): + nonlocal block_second_attempt + + if block_second_attempt is None: + block_second_attempt = True + elif block_second_attempt is True: + await block_long_running_handler.wait() + block_second_attempt = False + try: + return original_handler(*args, **kwargs) + finally: + await block_long_running_handler.clear() + await handled_long_running.set() + + s.stream_handlers["long-running"] = instrumented_handle_long_running + + # Submit task and wait until it executes on a + x = c.submit( + f, + before_secede, + block_secede, + block_long_running, + key="x", + ) + await before_secede.wait() + + # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails. + with captured_logger("distributed.scheduler", logging.ERROR) as caplog: + with freeze_batched_send(a.batched_stream): + # Let x secede (and later succeed) without informing the scheduler + await block_secede.set() + await wait_for_state("x", "long-running", a) + assert not a.state.executing + assert a.state.long_running + + # Cancel x while the scheduler does not know that it seceded + x.release() + await async_poll_for(lambda: not s.tasks, timeout=5) + assert not wsA.processing + assert not wsA.long_running + + # Resubmit task and wait until it is resumed on a + x = c.submit( + f, + before_secede, + block_secede, + block_long_running, + key="x", + ) + await wait_for_state("x", "long-running", a) + assert not a.state.executing + assert a.state.long_running + + assert wsA.processing + assert not wsA.long_running + + # Unblock the stream from a to the scheduler and handle the stale long-running message + await handled_long_running.wait() + + assert wsA.processing + assert wsA.task_prefix_count + assert not wsA.long_running + + # Clear the handler and let the scheduler handle the second long-running message + await handled_long_running.clear() + await block_long_running_handler.set() + await handled_long_running.wait() + + # Assert that the handler did not fail and no state was corrupted + logs = caplog.getvalue() + assert not logs + assert not wsA.task_prefix_count + assert wsA.processing + assert wsA.long_running + + await block_long_running.set() + assert await x.result() == 123 + + @gen_cluster(client=True, nthreads=[("", 1)], timeout=2) async def test_secede_cancelled_or_resumed_scheduler(c, s, a): """Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index c7001c0b00..2470c1c3fb 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -575,11 +575,13 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls): assert ws.available_resources == {"R": 0} instructions = ws.handle_stimulus( - ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4}) + ComputeTaskEvent.dummy( + "x", run_id=0, stimulus_id="s2", resource_restrictions={"R": 0.4} + ) ) if prev_state == "long-running": assert instructions == [ - LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2") + LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2") ] else: assert not instructions diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3507fb9379..dc796ffe13 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -509,8 +509,9 @@ class RescheduleMsg(SendMessageToScheduler): class LongRunningMsg(SendMessageToScheduler): op = "long-running" - __slots__ = ("key", "compute_duration") + __slots__ = ("key", "run_id", "compute_duration") key: Key + run_id: int compute_duration: float | None @@ -2171,7 +2172,10 @@ def _transition_cancelled_waiting( ts.state = "long-running" ts.previous = None smsg = LongRunningMsg( - key=ts.key, compute_duration=None, stimulus_id=stimulus_id + key=ts.key, + run_id=ts.run_id, + compute_duration=None, + stimulus_id=stimulus_id, ) return {}, [smsg] else: @@ -2276,7 +2280,10 @@ def _transition_executing_long_running( self.long_running.add(ts) smsg = LongRunningMsg( - key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id + key=ts.key, + run_id=ts.run_id, + compute_duration=compute_duration, + stimulus_id=stimulus_id, ) return merge_recs_instructions( ({}, [smsg]),