Commit 55bb639
Avoid handling stale long-running messages on scheduler (#8991)
1 parent de18eda commit 55bb639

File tree

4 files changed: +276 -21 lines changed

distributed/scheduler.py

+25 -4

@@ -6032,26 +6032,47 @@ def release_worker_data(self, key: Key, worker: str, stimulus_id: str) -> None:
         self.transitions({key: "released"}, stimulus_id)
 
     def handle_long_running(
-        self, key: Key, worker: str, compute_duration: float | None, stimulus_id: str
+        self,
+        key: Key,
+        worker: str,
+        run_id: int,
+        compute_duration: float | None,
+        stimulus_id: str,
     ) -> None:
         """A task has seceded from the thread pool
 
         We stop the task from being stolen in the future, and change task
         duration accounting as if the task has stopped.
         """
+        if worker not in self.workers:
+            logger.debug(
+                "Received long-running signal from unknown worker %s. Ignoring.", worker
+            )
+            return
+
         if key not in self.tasks:
             logger.debug("Skipping long_running since key %s was already released", key)
             return
+
         ts = self.tasks[key]
-        steal = self.extensions.get("stealing")
-        if steal is not None:
-            steal.remove_key_from_stealable(ts)
 
         ws = ts.processing_on
         if ws is None:
             logger.debug("Received long-running signal from duplicate task. Ignoring.")
             return
 
+        if ws.address != worker or ts.run_id != run_id:
+            logger.debug(
+                "Received stale long-running signal from worker %s for task %s. Ignoring.",
+                worker,
+                ts,
+            )
+            return
+
+        steal = self.extensions.get("stealing")
+        if steal is not None:
+            steal.remove_key_from_stealable(ts)
+
         if compute_duration is not None:
             old_duration = ts.prefix.duration_average
             if old_duration < 0:
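For intuition, here is a minimal, self-contained sketch of the guard order the new handler enforces. The _Task dataclass and should_handle_long_running helper are illustrative stand-ins, not Distributed's actual TaskState or Scheduler.handle_long_running; they only mirror the checks in the diff above: unknown worker, already-released key, no assigned worker, and finally the worker-address/run_id staleness test.

from dataclasses import dataclass


@dataclass
class _Task:
    run_id: int
    processing_on: str | None  # worker address, or None if not processing


def should_handle_long_running(
    tasks: dict[str, _Task], workers: set[str], key: str, worker: str, run_id: int
) -> bool:
    if worker not in workers:
        return False  # signal from an unknown (e.g. departed) worker
    if key not in tasks:
        return False  # key was already released
    ts = tasks[key]
    if ts.processing_on is None:
        return False  # duplicate signal; the task is not processing anywhere
    if ts.processing_on != worker or ts.run_id != run_id:
        return False  # stale signal; the task has since been rescheduled
    return True


# After "x" is cancelled and resubmitted (run_id bumped to 1, now on worker b),
# the late signal from the first run on worker a is ignored:
tasks = {"x": _Task(run_id=1, processing_on="b")}
workers = {"a", "b"}
assert not should_handle_long_running(tasks, workers, "x", worker="a", run_id=0)
assert should_handle_long_running(tasks, workers, "x", worker="b", run_id=1)

Note the ordering in the diff: the stealing extension is only touched after every staleness check passes, which is why the remove_key_from_stealable call moves below the new guards.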

distributed/tests/test_cancelled_state.py

+237 -12

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 
 import pytest
 
@@ -14,6 +15,7 @@
     _LockedCommPool,
     assert_story,
     async_poll_for,
+    captured_logger,
     freeze_batched_send,
     gen_cluster,
     inc,
@@ -903,13 +905,15 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
 
     instructions = ws.handle_stimulus(
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
-        ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy(
+            "x", run_id=0, resource_restrictions={"R": 1}, stimulus_id="s2"
+        ),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+            LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2")
         ]
     assert ws.tasks["x"] is ts
     assert ts.state == prev_state
@@ -1087,15 +1091,17 @@ def test_workerstate_resumed_fetch_to_cancelled_to_executing(ws_with_running_task):
 
     instructions = ws.handle_stimulus(
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
-        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
         FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
-        ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
+        ComputeTaskEvent.dummy(
+            "x", run_id=1, resource_restrictions={"R": 1}, stimulus_id="s4"
+        ),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+            LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
         ]
     assert ws.tasks["x"].state == prev_state
 
@@ -1111,16 +1117,16 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
         # x is released for whatever reason (e.g. client cancellation)
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
         # x was computed somewhere else
-        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
         # x was lost / no known replicas, therefore y is cancelled
         FreeKeysEvent(keys=["y"], stimulus_id="s3"),
-        ComputeTaskEvent.dummy("x", stimulus_id="s4"),
+        ComputeTaskEvent.dummy("x", run_id=1, stimulus_id="s4"),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+            LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
         ]
     assert len(ws.tasks) == 1
     assert ws.tasks["x"].state == prev_state
@@ -1254,12 +1260,14 @@ def test_secede_cancelled_or_resumed_workerstate(
     """
     ws2 = "127.0.0.1:2"
     ws.handle_stimulus(
-        ComputeTaskEvent.dummy("x", stimulus_id="s1"),
+        ComputeTaskEvent.dummy("x", run_id=0, stimulus_id="s1"),
         FreeKeysEvent(keys=["x"], stimulus_id="s2"),
     )
     if resume_to_fetch:
         ws.handle_stimulus(
-            ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
+            ComputeTaskEvent.dummy(
+                "y", run_id=1, who_has={"x": [ws2]}, stimulus_id="s3"
+            ),
         )
     ts = ws.tasks["x"]
     assert ts.previous == "executing"
@@ -1277,11 +1285,11 @@ def test_secede_cancelled_or_resumed_workerstate(
     if resume_to_executing:
         instructions = ws.handle_stimulus(
             FreeKeysEvent(keys=["y"], stimulus_id="s5"),
-            ComputeTaskEvent.dummy("x", stimulus_id="s6"),
+            ComputeTaskEvent.dummy("x", run_id=2, stimulus_id="s6"),
         )
         # Inform the scheduler of the SecedeEvent that happened in the past
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
+            LongRunningMsg(key="x", run_id=2, compute_duration=None, stimulus_id="s6")
         ]
         assert ts.state == "long-running"
         assert ts not in ws.executing
@@ -1292,6 +1300,223 @@ def test_secede_cancelled_or_resumed_workerstate(
     assert ts not in ws.long_running
 
 
+@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], timeout=2)
+async def test_secede_racing_cancellation_and_scheduling_on_other_worker(c, s, a, b):
+    """Regression test that ensures that we handle stale long-running messages correctly.
+
+    This test simulates a race condition where a task secedes on worker a, the task is then cancelled, and resubmitted to
+    run on worker b. The long-running message created on a only arrives on the scheduler after the task started executing on b
+    (but before a secede event arrives from worker b). The scheduler should then ignore the stale secede event from a.
+    """
+    wsA = s.workers[a.address]
+    before_secede = Event()
+    block_secede = Event()
+    block_long_running = Event()
+    handled_long_running = Event()
+
+    def f(before_secede, block_secede, block_long_running):
+        before_secede.set()
+        block_secede.wait()
+        distributed.secede()
+        block_long_running.wait()
+        return 123
+
+    # Instrument long-running handler
+    original_handler = s.stream_handlers["long-running"]
+
+    async def instrumented_handle_long_running(*args, **kwargs):
+        try:
+            return original_handler(*args, **kwargs)
+        finally:
+            await handled_long_running.set()
+
+    s.stream_handlers["long-running"] = instrumented_handle_long_running
+
+    # Submit task and wait until it executes on a
+    x = c.submit(
+        f,
+        before_secede,
+        block_secede,
+        block_long_running,
+        key="x",
+        workers=[a.address],
+    )
+    await before_secede.wait()
+
+    # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
+    with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
+        with freeze_batched_send(a.batched_stream):
+            # Let x secede (and later succeed) without informing the scheduler
+            await block_secede.set()
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+            await block_long_running.set()
+
+            await wait_for_state("x", "memory", a)
+
+            # Cancel x while the scheduler does not know that it seceded
+            x.release()
+            await async_poll_for(lambda: not s.tasks, timeout=5)
+            assert not wsA.processing
+            assert not wsA.long_running
+
+            # Reset all events
+            await before_secede.clear()
+            await block_secede.clear()
+            await block_long_running.clear()
+
+            # Resubmit task and wait until it executes on b
+            x = c.submit(
+                f,
+                before_secede,
+                block_secede,
+                block_long_running,
+                key="x",
+                workers=[b.address],
+            )
+            await before_secede.wait()
+            wsB = s.workers[b.address]
+            assert wsB.processing
+            assert not wsB.long_running
+
+        # Unblock the stream from a to the scheduler and handle the long-running message
+        await handled_long_running.wait()
+        ts = b.state.tasks["x"]
+        assert ts.state == "executing"
+
+        assert wsB.processing
+        assert wsB.task_prefix_count
+        assert not wsB.long_running
+
+        assert not wsA.processing
+        assert not wsA.task_prefix_count
+        assert not wsA.long_running
+
+        # Clear the handler and let x secede on b
+        await handled_long_running.clear()
+
+        await block_secede.set()
+        await wait_for_state("x", "long-running", b)
+
+        assert not b.state.executing
+        assert b.state.long_running
+        await handled_long_running.wait()
+
+    # Assert that the handler did not fail and no state was corrupted
+    logs = caplog.getvalue()
+    assert not logs
+    assert not wsB.task_prefix_count
+
+    await block_long_running.set()
+    assert await x.result() == 123
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
+async def test_secede_racing_resuming_on_same_worker(c, s, a):
+    """Regression test that ensures that we handle stale long-running messages correctly.
+
+    This test simulates a race condition where a task secedes on worker a, the task is then cancelled, and resumed on
+    worker a. The first long-running message created on a only arrives on the scheduler after the task was resumed.
+    The scheduler should then ignore the stale first secede event from a and only handle the second one.
+    """
+    wsA = s.workers[a.address]
+    before_secede = Event()
+    block_secede = Event()
+    block_long_running = Event()
+    handled_long_running = Event()
+    block_long_running_handler = Event()
+
+    def f(before_secede, block_secede, block_long_running):
+        before_secede.set()
+        block_secede.wait()
+        distributed.secede()
+        block_long_running.wait()
+        return 123
+
+    # Instrument long-running handler
+    original_handler = s.stream_handlers["long-running"]
+    block_second_attempt = None
+
+    async def instrumented_handle_long_running(*args, **kwargs):
+        nonlocal block_second_attempt
+
+        if block_second_attempt is None:
+            block_second_attempt = True
+        elif block_second_attempt is True:
+            await block_long_running_handler.wait()
+            block_second_attempt = False
+        try:
+            return original_handler(*args, **kwargs)
+        finally:
+            await block_long_running_handler.clear()
+            await handled_long_running.set()
+
+    s.stream_handlers["long-running"] = instrumented_handle_long_running
+
+    # Submit task and wait until it executes on a
+    x = c.submit(
+        f,
+        before_secede,
+        block_secede,
+        block_long_running,
+        key="x",
+    )
+    await before_secede.wait()
+
+    # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
+    with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
+        with freeze_batched_send(a.batched_stream):
+            # Let x secede (and later succeed) without informing the scheduler
+            await block_secede.set()
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+
+            # Cancel x while the scheduler does not know that it seceded
+            x.release()
+            await async_poll_for(lambda: not s.tasks, timeout=5)
+            assert not wsA.processing
+            assert not wsA.long_running
+
+            # Resubmit task and wait until it is resumed on a
+            x = c.submit(
+                f,
+                before_secede,
+                block_secede,
+                block_long_running,
+                key="x",
+            )
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+
+            assert wsA.processing
+            assert not wsA.long_running
+
+        # Unblock the stream from a to the scheduler and handle the stale long-running message
+        await handled_long_running.wait()
+
+        assert wsA.processing
+        assert wsA.task_prefix_count
+        assert not wsA.long_running
+
+        # Clear the handler and let the scheduler handle the second long-running message
+        await handled_long_running.clear()
+        await block_long_running_handler.set()
+        await handled_long_running.wait()
+
+    # Assert that the handler did not fail and no state was corrupted
+    logs = caplog.getvalue()
+    assert not logs
+    assert not wsA.task_prefix_count
+    assert wsA.processing
+    assert wsA.long_running
+
+    await block_long_running.set()
+    assert await x.result() == 123
+
+
 @gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
 async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
     """Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction

distributed/tests/test_resources.py

+4 -2

@@ -575,11 +575,13 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
     assert ws.available_resources == {"R": 0}
 
     instructions = ws.handle_stimulus(
-        ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4})
+        ComputeTaskEvent.dummy(
+            "x", run_id=0, stimulus_id="s2", resource_restrictions={"R": 0.4}
+        )
    )
     if prev_state == "long-running":
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+            LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2")
         ]
     else:
         assert not instructions
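One subtlety the second regression test pins down: after a cancel/resume cycle that stays on the same worker, the address comparison in the new guard cannot detect staleness on its own; only the run_id mismatch can. A tiny self-contained illustration of that core comparison (stand-in _Task as in the earlier sketch, not Distributed's TaskState):

from dataclasses import dataclass


@dataclass
class _Task:
    run_id: int
    processing_on: str | None


def is_stale(ts: _Task, worker: str, run_id: int) -> bool:
    # The staleness test from handle_long_running, reduced to its core comparison.
    return ts.processing_on != worker or ts.run_id != run_id


# "x" was cancelled and resumed on the *same* worker a: the late signal from
# run 0 is stale even though the worker address still matches.
ts = _Task(run_id=1, processing_on="a")
assert is_stale(ts, worker="a", run_id=0)  # first, stale secede signal: ignored
assert not is_stale(ts, worker="a", run_id=1)  # second signal: handled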
