from distributed.utils_test import (
    NO_AMM,
    BlockedGatherDep,
+    BlockedGetData,
    assert_story,
    async_poll_for,
    captured_logger,
@@ -1129,6 +1130,60 @@ async def test_RetireWorker_faulty_recipient(c, s, w1, w2):
    assert dict(w2.data) == {"x": 123, clutter.key: 456}


+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 10,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.05,
+        "distributed.scheduler.active-memory-manager.measure": "managed",
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_mass(c, s, *workers):
+    """Retire 90% of a cluster at once."""
+    # Note: by using scatter instead of submit/map, we're also testing that tasks
+    # aren't being recomputed
+    data = await c.scatter(range(100))
+    for w in workers:
+        assert len(w.data) == 10
+
+    await c.retire_workers([w.address for w in workers[:-1]])
+    assert set(s.workers) == {workers[-1].address}
+    assert len(workers[-1].data) == 100
+
+
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,
+        "distributed.scheduler.active-memory-manager.interval": 0.05,
+        "distributed.scheduler.active-memory-manager.measure": "managed",
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+)
+async def test_RetireWorker_incremental(c, s, w2, w3):
+    """Retire worker w1; this causes its keys to be replicated onto w2.
+    Before that can happen, retire w2 too.
+    """
+    async with BlockedGetData(s.address) as w1:
+        # Note: by using scatter instead of submit/map, we're also testing that tasks
+        # aren't being recomputed
+        x = await c.scatter({"x": 1}, workers=[w1.address])
+        y = await c.scatter({"y": 2}, workers=[w3.address])
+
+        # Because w2's memory is lower than w3's, AMM will choose w2
+        retire1 = asyncio.create_task(c.retire_workers([w1.address]))
+        await w1.in_get_data.wait()
+        assert w2.state.tasks["x"].state == "flight"
+        await c.retire_workers([w2.address])
+
+        w1.block_get_data.set()
+        await retire1
+        assert set(s.workers) == {w3.address}
+        assert set(w3.data) == {"x", "y"}
+
+
 class Counter:
     def __init__(self):
         self.n = 0
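The new `test_RetireWorker_incremental` leans on `BlockedGetData` from `distributed.utils_test`, whose `in_get_data` and `block_get_data` events let the test freeze w1 while its key is still in flight to w2. As a rough, hypothetical sketch of that pattern (an assumption about how such a helper is typically structured, not the actual `utils_test` implementation), a `Worker` subclass can gate `get_data` on two `asyncio.Event`s:

```python
import asyncio

from distributed import Worker


class PausableGetDataWorker(Worker):
    """Hypothetical helper mirroring the BlockedGetData usage above:
    ``in_get_data`` fires when a peer requests data, and the request
    only proceeds once the test sets ``block_get_data``."""

    def __init__(self, *args, **kwargs):
        self.in_get_data = asyncio.Event()
        self.block_get_data = asyncio.Event()
        super().__init__(*args, **kwargs)

    async def get_data(self, comm, *args, **kwargs):
        # Signal the test that a transfer has started, then wait for permission
        self.in_get_data.set()
        await self.block_get_data.wait()
        return await super().get_data(comm, *args, **kwargs)
```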
@@ -1223,7 +1278,7 @@ def run(self):
         self.manager.policies.remove(self)


-async def tensordot_stress(c):
+async def tensordot_stress(c, s):
    da = pytest.importorskip("dask.array")

    rng = da.random.RandomState(0)
@@ -1234,6 +1289,10 @@ async def tensordot_stress(c):
    b = (a @ a.T).sum().round(3)
    assert await c.compute(b) == 245.394

+    # Test that we didn't recompute any tasks during the stress test
+    await async_poll_for(lambda: not s.tasks, timeout=5)
+    assert sum(t.start == "memory" for t in s.transition_log) == 1639
+

@pytest.mark.slow
@gen_cluster(
@@ -1245,7 +1304,7 @@ async def test_noamm_stress(c, s, *workers):
    """Test the tensordot_stress helper without AMM. This is to figure out if a
    stability issue is AMM-specific or not.
    """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


@pytest.mark.slow
@@ -1267,7 +1326,7 @@ async def test_drop_stress(c, s, *workers):

    See also: test_ReduceReplicas_stress
    """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


@pytest.mark.slow
@@ -1288,7 +1347,7 @@ async def test_ReduceReplicas_stress(c, s, *workers):
    test_drop_stress above, this test does not stop running after a few seconds - the
    policy must not disrupt the computation too much.
    """
-    await tensordot_stress(c)
+    await tensordot_stress(c, s)


@pytest.mark.slow
@@ -1316,7 +1375,7 @@ async def test_RetireWorker_stress(c, s, *workers, use_ReduceReplicas):
    random.shuffle(addrs)
    print(f"Removing all workers except {addrs[9]}")

-    tasks = [asyncio.create_task(tensordot_stress(c))]
+    tasks = [asyncio.create_task(tensordot_stress(c, s))]
    await asyncio.sleep(1)
    tasks.append(asyncio.create_task(c.retire_workers(addrs[0:2])))
    await asyncio.sleep(1)
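For readers reproducing the retirement path outside the test harness: the new tests enable the Active Memory Manager purely through configuration and then drive retirement with `Client.retire_workers`, as in the hunks above. Below is a minimal standalone sketch of that usage; the cluster size, scattered payload, and the interval/measure values are illustrative assumptions, not part of this patch.

```python
import asyncio

import dask
from distributed import Client, LocalCluster

# AMM settings mirroring the test configuration above; the interval and
# measure values here are assumptions chosen for local experimentation.
dask.config.set(
    {
        "distributed.scheduler.active-memory-manager.start": True,
        "distributed.scheduler.active-memory-manager.interval": 0.05,
        "distributed.scheduler.active-memory-manager.measure": "managed",
    }
)


async def main():
    async with LocalCluster(
        n_workers=4, threads_per_worker=1, processes=False, asynchronous=True
    ) as cluster, Client(cluster, asynchronous=True) as client:
        # Scatter data so that retirement has keys to replicate rather than recompute
        await client.scatter(list(range(40)))
        # Retire all but one worker; the AMM replicates their keys onto the survivor
        workers = list(cluster.workers.values())
        await client.retire_workers([w.address for w in workers[:-1]])


asyncio.run(main())
```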