Avoid handling stale long-running messages on scheduler #8991

Merged
29 changes: 25 additions & 4 deletions distributed/scheduler.py
@@ -6057,26 +6057,47 @@ def release_worker_data(self, key: Key, worker: str, stimulus_id: str) -> None:
self.transitions({key: "released"}, stimulus_id)

def handle_long_running(
self, key: Key, worker: str, compute_duration: float | None, stimulus_id: str
self,
key: Key,
worker: str,
run_id: int,
compute_duration: float | None,
stimulus_id: str,
) -> None:
"""A task has seceded from the thread pool

We stop the task from being stolen in the future, and change task
duration accounting as if the task has stopped.
"""
if worker not in self.workers:
logger.debug(
"Received long-running signal from unknown worker %s. Ignoring.", worker
)
return

Comment on lines +6072 to +6077 (Member Author):

This is mostly for good measure; I think the code should also work without this.

if key not in self.tasks:
logger.debug("Skipping long_running since key %s was already released", key)
return

ts = self.tasks[key]
steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)

ws = ts.processing_on
if ws is None:
logger.debug("Received long-running signal from duplicate task. Ignoring.")
return

if ws.address != worker or ts.run_id != run_id:
logger.debug(
"Received stale long-running signal from worker %s for task %s. Ignoring.",
worker,
ts,
)
return

steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)

Comment on lines +6097 to +6100 (Member Author):

I haven't tested the move of this code, but I'm certain that we should deal with staleness before taking any meaningful actions.

Reply (Member):

yes, absolutely

if compute_duration is not None:
old_duration = ts.prefix.duration_average
if old_duration < 0:
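Conceptually, the new guard compares the run identifier carried by the long-running message with the scheduler's current record for the task and drops the message on any mismatch. A minimal, self-contained sketch of that pattern (simplified types and names, not the actual distributed classes):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TaskRecord:
    key: str
    run_id: int                 # bumped every time the key is (re)scheduled
    processing_on: str | None   # address of the worker currently running it


def is_stale(record: TaskRecord | None, worker: str, run_id: int) -> bool:
    """Return True if a long-running signal should be ignored."""
    if record is None or record.processing_on is None:
        # Task was already released or is not processing anywhere.
        return True
    # Stale if the signal comes from a worker that no longer runs the task
    # or refers to an earlier run of the same key.
    return record.processing_on != worker or record.run_id != run_id


# A signal from the first run (run_id=0 on worker :1) arrives after the task
# was rescheduled as run_id=1 on worker :2 and must be ignored.
record = TaskRecord(key="x", run_id=1, processing_on="127.0.0.1:2")
assert is_stale(record, worker="127.0.0.1:1", run_id=0)
assert not is_stale(record, worker="127.0.0.1:2", run_id=1)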
249 changes: 237 additions & 12 deletions distributed/tests/test_cancelled_state.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging

import pytest

@@ -14,6 +15,7 @@
_LockedCommPool,
assert_story,
async_poll_for,
captured_logger,
freeze_batched_send,
gen_cluster,
inc,
@@ -903,13 +905,15 @@ def test_workerstate_executing_to_executing(ws_with_running_task):

instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
ComputeTaskEvent.dummy(
"x", run_id=0, resource_restrictions={"R": 1}, stimulus_id="s2"
),
)
if prev_state == "executing":
assert not instructions
else:
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2")
]
assert ws.tasks["x"] is ts
assert ts.state == prev_state
@@ -1087,15 +1091,17 @@ def test_workerstate_resumed_fetch_to_cancelled_to_executing(ws_with_running_task):

instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
ComputeTaskEvent.dummy(
"x", run_id=1, resource_restrictions={"R": 1}, stimulus_id="s4"
),
)
if prev_state == "executing":
assert not instructions
else:
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
]
assert ws.tasks["x"].state == prev_state

@@ -1111,16 +1117,16 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
# x is released for whatever reason (e.g. client cancellation)
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
# x was computed somewhere else
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
# x was lost / no known replicas, therefore y is cancelled
FreeKeysEvent(keys=["y"], stimulus_id="s3"),
ComputeTaskEvent.dummy("x", stimulus_id="s4"),
ComputeTaskEvent.dummy("x", run_id=1, stimulus_id="s4"),
)
if prev_state == "executing":
assert not instructions
else:
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
]
assert len(ws.tasks) == 1
assert ws.tasks["x"].state == prev_state
@@ -1254,12 +1260,14 @@ def test_secede_cancelled_or_resumed_workerstate(
"""
ws2 = "127.0.0.1:2"
ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s1"),
ComputeTaskEvent.dummy("x", run_id=0, stimulus_id="s1"),
FreeKeysEvent(keys=["x"], stimulus_id="s2"),
)
if resume_to_fetch:
ws.handle_stimulus(
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
ComputeTaskEvent.dummy(
"y", run_id=1, who_has={"x": [ws2]}, stimulus_id="s3"
),
)
ts = ws.tasks["x"]
assert ts.previous == "executing"
@@ -1277,11 +1285,11 @@ def test_secede_cancelled_or_resumed_workerstate(
if resume_to_executing:
instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["y"], stimulus_id="s5"),
ComputeTaskEvent.dummy("x", stimulus_id="s6"),
ComputeTaskEvent.dummy("x", run_id=2, stimulus_id="s6"),
)
# Inform the scheduler of the SecedeEvent that happened in the past
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
LongRunningMsg(key="x", run_id=2, compute_duration=None, stimulus_id="s6")
]
assert ts.state == "long-running"
assert ts not in ws.executing
@@ -1292,6 +1300,223 @@ def test_secede_cancelled_or_resumed_workerstate(
assert ts not in ws.long_running


@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], timeout=2)
async def test_secede_racing_cancellation_and_scheduling_on_other_worker(c, s, a, b):
"""Regression test that ensures that we handle stale long-running messages correctly.

This test simulates a race condition where a task secedes on worker a, the task is then cancelled, and resubmitted to
run on worker b. The long-running message created on a only arrives on the scheduler after the task started executing on b
(but before a secede event arrives from worker b). The scheduler should then ignore the stale secede event from a.
"""
wsA = s.workers[a.address]
before_secede = Event()
block_secede = Event()
block_long_running = Event()
handled_long_running = Event()

def f(before_secede, block_secede, block_long_running):
before_secede.set()
block_secede.wait()
distributed.secede()
Comment (Member):

Does this also trigger when using worker_client? secede is an API I typically discourage using, mostly because the counterpart rejoin is quite broken.

Reply (Member Author):

I strongly suppose it does. The original workload where this popped up had many clients connected to the scheduler.
block_long_running.wait()
return 123
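
For reference, the worker_client pattern the reviewer mentions secedes on entry and rejoins on exit, so it should exercise the same long-running code path. A minimal illustrative sketch (not part of this test):

from distributed import worker_client

def sum_of_squares(n):
    # Entering worker_client secedes this task from the worker's thread pool;
    # leaving the context rejoins it (the rejoin mentioned above).
    with worker_client() as client:
        futures = client.map(lambda i: i * i, range(n))
        return sum(client.gather(futures))

# Submitted like any other task, e.g. client.submit(sum_of_squares, 10)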

# Instrument long-running handler
original_handler = s.stream_handlers["long-running"]

async def instrumented_handle_long_running(*args, **kwargs):
try:
return original_handler(*args, **kwargs)
finally:
await handled_long_running.set()

s.stream_handlers["long-running"] = instrumented_handle_long_running

# Submit task and wait until it executes on a
x = c.submit(
f,
before_secede,
block_secede,
block_long_running,
key="x",
workers=[a.address],
)
await before_secede.wait()

# FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
with freeze_batched_send(a.batched_stream):
# Let x secede (and later succeed) without informing the scheduler
await block_secede.set()
await wait_for_state("x", "long-running", a)
assert not a.state.executing
assert a.state.long_running
await block_long_running.set()

await wait_for_state("x", "memory", a)

# Cancel x while the scheduler does not know that it seceded
x.release()
await async_poll_for(lambda: not s.tasks, timeout=5)
assert not wsA.processing
assert not wsA.long_running

# Reset all events
await before_secede.clear()
await block_secede.clear()
await block_long_running.clear()

# Resubmit task and wait until it executes on b
x = c.submit(
f,
before_secede,
block_secede,
block_long_running,
key="x",
workers=[b.address],
)
await before_secede.wait()
wsB = s.workers[b.address]
assert wsB.processing
assert not wsB.long_running

# Unblock the stream from a to the scheduler and handle the long-running message
await handled_long_running.wait()
ts = b.state.tasks["x"]
assert ts.state == "executing"

assert wsB.processing
assert wsB.task_prefix_count
assert not wsB.long_running

assert not wsA.processing
assert not wsA.task_prefix_count
assert not wsA.long_running

# Clear the handler and let x secede on b
await handled_long_running.clear()

await block_secede.set()
await wait_for_state("x", "long-running", b)

assert not b.state.executing
assert b.state.long_running
await handled_long_running.wait()

# Assert that the handler did not fail and no state was corrupted
logs = caplog.getvalue()
assert not logs
assert not wsB.task_prefix_count
Comment on lines +1406 to +1409 (Member):

I would prefer a test that does not rely on logging. Is this corruption detectable with validate? (If not, can it be made detectable with this?)

Reply (Member Author):

Good point, let me check.

Reply (Member Author):

This doesn't seem to work out of the box. We'd either have to log (or hard-fail) on errors in the stimulus or validate that scheduler and worker state don't drift apart.


await block_long_running.set()
assert await x.result() == 123
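
The drift discussed in the review thread above could in principle be asserted directly instead of relying on captured logs. A hypothetical helper sketch (the function name and the exact attribute sets are assumptions based on the attributes used in this test, not an existing validate hook):

def assert_scheduler_worker_consistent(scheduler, worker):
    # Hypothetical consistency check: the scheduler's bookkeeping for this
    # worker should mirror the worker's own task states.
    ws = scheduler.workers[worker.address]
    scheduler_long_running = {ts.key for ts in ws.long_running}
    scheduler_processing = {ts.key for ts in ws.processing}
    worker_long_running = {ts.key for ts in worker.state.long_running}
    worker_executing = {ts.key for ts in worker.state.executing}
    assert scheduler_long_running == worker_long_running
    # Everything the worker is running should still count as processing
    # on the scheduler.
    assert (worker_executing | worker_long_running) <= scheduler_processing

# e.g. called in place of the log assertions: assert_scheduler_worker_consistent(s, b)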


@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
async def test_secede_racing_resuming_on_same_worker(c, s, a):
"""Regression test that ensures that we handle stale long-running messages correctly.

This test simulates a race condition where a task secedes on worker a, the task is then cancelled, and resumed on
worker a. The first long-running message created on a only arrives on the scheduler after the task was resumed.
The scheduler should then ignore the stale first secede event from a and only handle the second one.
"""
wsA = s.workers[a.address]
before_secede = Event()
block_secede = Event()
block_long_running = Event()
handled_long_running = Event()
block_long_running_handler = Event()

def f(before_secede, block_secede, block_long_running):
before_secede.set()
block_secede.wait()
distributed.secede()
block_long_running.wait()
return 123

# Instrument long-running handler
original_handler = s.stream_handlers["long-running"]
block_second_attempt = None

async def instrumented_handle_long_running(*args, **kwargs):
nonlocal block_second_attempt

if block_second_attempt is None:
block_second_attempt = True
elif block_second_attempt is True:
await block_long_running_handler.wait()
block_second_attempt = False
try:
return original_handler(*args, **kwargs)
finally:
await block_long_running_handler.clear()
await handled_long_running.set()

s.stream_handlers["long-running"] = instrumented_handle_long_running

# Submit task and wait until it executes on a
x = c.submit(
f,
before_secede,
block_secede,
block_long_running,
key="x",
)
await before_secede.wait()

# FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
with freeze_batched_send(a.batched_stream):
# Let x secede (and later succeed) without informing the scheduler
await block_secede.set()
await wait_for_state("x", "long-running", a)
assert not a.state.executing
assert a.state.long_running

# Cancel x while the scheduler does not know that it seceded
x.release()
await async_poll_for(lambda: not s.tasks, timeout=5)
assert not wsA.processing
assert not wsA.long_running

# Resubmit task and wait until it is resumed on a
x = c.submit(
f,
before_secede,
block_secede,
block_long_running,
key="x",
)
await wait_for_state("x", "long-running", a)
assert not a.state.executing
assert a.state.long_running

assert wsA.processing
assert not wsA.long_running

# Unblock the stream from a to the scheduler and handle the stale long-running message
await handled_long_running.wait()

assert wsA.processing
assert wsA.task_prefix_count
assert not wsA.long_running

# Clear the handler and let the scheduler handle the second long-running message
await handled_long_running.clear()
await block_long_running_handler.set()
await handled_long_running.wait()

# Assert that the handler did not fail and no state was corrupted
logs = caplog.getvalue()
assert not logs
assert not wsA.task_prefix_count
assert wsA.processing
assert wsA.long_running

await block_long_running.set()
assert await x.result() == 123


@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
"""Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction
6 changes: 4 additions & 2 deletions distributed/tests/test_resources.py
@@ -575,11 +575,13 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
assert ws.available_resources == {"R": 0}

instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4})
ComputeTaskEvent.dummy(
"x", run_id=0, stimulus_id="s2", resource_restrictions={"R": 0.4}
)
)
if prev_state == "long-running":
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2")
]
else:
assert not instructions