Commit 197d64d

AMM: test incremental retirements (#8501)
1 parent 774874e
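The new test_RetireWorker_incremental covers retiring a worker while a previous retirement is still replicating its keys. A minimal sketch of the call pattern being exercised, assuming an asynchronous distributed.Client `client` and two worker addresses `addr_a` and `addr_b` (names here are hypothetical, not part of the commit):

    import asyncio

    async def retire_incrementally(client, addr_a, addr_b):
        # Start retiring the first worker; the Active Memory Manager begins
        # copying its keys to another worker in the background.
        first = asyncio.create_task(client.retire_workers([addr_a]))
        # Retire a second worker before the first retirement has finished; any
        # keys that were headed to it must be redirected to a surviving worker.
        await client.retire_workers([addr_b])
        await first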

1 file changed: +64 -5 lines

distributed/tests/test_active_memory_manager.py

@@ -22,6 +22,7 @@
 from distributed.utils_test import (
     NO_AMM,
     BlockedGatherDep,
+    BlockedGetData,
     assert_story,
     async_poll_for,
     captured_logger,
@@ -1129,6 +1130,60 @@ async def test_RetireWorker_faulty_recipient(c, s, w1, w2):
     assert dict(w2.data) == {"x": 123, clutter.key: 456}


+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 10,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.05,
+        "distributed.scheduler.active-memory-manager.measure": "managed",
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_mass(c, s, *workers):
+    """Retire 90% of a cluster at once."""
+    # Note: by using scatter instead of submit/map, we're also testing that tasks
+    # aren't being recomputed
+    data = await c.scatter(range(100))
+    for w in workers:
+        assert len(w.data) == 10
+
+    await c.retire_workers([w.address for w in workers[:-1]])
+    assert set(s.workers) == {workers[-1].address}
+    assert len(workers[-1].data) == 100
+
+
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.05,
+        "distributed.scheduler.active-memory-manager.measure": "managed",
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_incremental(c, s, w2, w3):
+    """Retire worker w1; this causes its keys to be replicated onto w2.
+    Before that can happen, retire w2 too.
+    """
+    async with BlockedGetData(s.address) as w1:
+        # Note: by using scatter instead of submit/map, we're also testing that tasks
+        # aren't being recomputed
+        x = await c.scatter({"x": 1}, workers=[w1.address])
+        y = await c.scatter({"y": 2}, workers=[w3.address])
+
+        # Because w2's memory is lower than w3, AMM will choose w2
+        retire1 = asyncio.create_task(c.retire_workers([w1.address]))
+        await w1.in_get_data.wait()
+        assert w2.state.tasks["x"].state == "flight"
+        await c.retire_workers([w2.address])
+
+        w1.block_get_data.set()
+        await retire1
+        assert set(s.workers) == {w3.address}
+        assert set(w3.data) == {"x", "y"}
+
+
 class Counter:
     def __init__(self):
         self.n = 0
@@ -1223,7 +1278,7 @@ def run(self):
         self.manager.policies.remove(self)


-async def tensordot_stress(c):
+async def tensordot_stress(c, s):
     da = pytest.importorskip("dask.array")

     rng = da.random.RandomState(0)
@@ -1234,6 +1289,10 @@ async def tensordot_stress(c):
     b = (a @ a.T).sum().round(3)
     assert await c.compute(b) == 245.394

+    # Test that we didn't recompute any tasks during the stress test
+    await async_poll_for(lambda: not s.tasks, timeout=5)
+    assert sum(t.start == "memory" for t in s.transition_log) == 1639
+

 @pytest.mark.slow
 @gen_cluster(
@@ -1245,7 +1304,7 @@ async def test_noamm_stress(c, s, *workers):
     """Test the tensordot_stress helper without AMM. This is to figure out if a
     stability issue is AMM-specific or not.
     """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


 @pytest.mark.slow
@@ -1267,7 +1326,7 @@ async def test_drop_stress(c, s, *workers):

     See also: test_ReduceReplicas_stress
     """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


 @pytest.mark.slow
@@ -1288,7 +1347,7 @@ async def test_ReduceReplicas_stress(c, s, *workers):
     test_drop_stress above, this test does not stop running after a few seconds - the
     policy must not disrupt the computation too much.
     """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


 @pytest.mark.slow
@@ -1316,7 +1375,7 @@ async def test_RetireWorker_stress(c, s, *workers, use_ReduceReplicas):
     random.shuffle(addrs)
     print(f"Removing all workers except {addrs[9]}")

-    tasks = [asyncio.create_task(tensordot_stress(c))]
+    tasks = [asyncio.create_task(tensordot_stress(c, s))]
     await asyncio.sleep(1)
     tasks.append(asyncio.create_task(c.retire_workers(addrs[0:2])))
     await asyncio.sleep(1)
