5
5
from collections import defaultdict , deque
6
6
from collections .abc import Iterable
7
7
from datetime import timedelta
8
- from typing import TYPE_CHECKING , cast
8
+ from typing import TYPE_CHECKING , Literal , TypedDict , cast
9
9
10
10
import tlz as toolz
11
11
from tornado .ioloop import IOLoop
12
+ from typing_extensions import NotRequired
12
13
13
14
import dask .config
14
15
from dask .utils import parse_timedelta
23
24
logger = logging .getLogger (__name__ )
24
25
25
26
27
+ RecommendationStatus = Literal ["up" , "down" , "same" ]
28
+
29
+
30
+ class Recommendation (TypedDict ):
31
+ status : RecommendationStatus
32
+ workers : NotRequired [set [WorkerState ]]
33
+ n : NotRequired [int ]
34
+
35
+
26
36
class AdaptiveCore :
27
37
"""
28
38
The core logic for adaptive deployments, with none of the cluster details
@@ -169,13 +179,13 @@ async def safe_target(self) -> int:
169
179
170
180
return n
171
181
172
- async def scale_down (self , n : int ) -> None :
182
+ async def scale_down (self , workers : Iterable ) -> None :
173
183
raise NotImplementedError ()
174
184
175
- async def scale_up (self , workers : Iterable ) -> None :
185
+ async def scale_up (self , n : int ) -> None :
176
186
raise NotImplementedError ()
177
187
178
- async def recommendations (self , target : int ) -> dict :
188
+ async def recommendations (self , target : int ) -> Recommendation :
179
189
"""
180
190
Make scale up/down recommendations based on current state and target
181
191
"""
@@ -185,11 +195,11 @@ async def recommendations(self, target: int) -> dict:
185
195
186
196
if target == len (plan ):
187
197
self .close_counts .clear ()
188
- return { " status" : " same"}
198
+ return Recommendation ( status = " same")
189
199
190
200
if target > len (plan ):
191
201
self .close_counts .clear ()
192
- return { " status" : " up" , "n" : target }
202
+ return Recommendation ( status = " up" , n = target )
193
203
194
204
# target < len(plan)
195
205
not_yet_arrived = requested - observed
@@ -212,9 +222,9 @@ async def recommendations(self, target: int) -> dict:
212
222
del self .close_counts [k ]
213
223
214
224
if firmly_close :
215
- return { " status" : " down" , " workers" : list ( firmly_close )}
225
+ return Recommendation ( status = " down" , workers = firmly_close )
216
226
else :
217
- return { " status" : " same"}
227
+ return Recommendation ( status = " same")
218
228
219
229
async def adapt (self ) -> None :
220
230
"""
@@ -229,18 +239,16 @@ async def adapt(self) -> None:
229
239
230
240
try :
231
241
target = await self .safe_target ()
232
- recommendations = await self .recommendations (target )
233
-
234
- if recommendations ["status" ] != "same" :
235
- self .log .append ((time (), dict (recommendations )))
242
+ recommendation = await self .recommendations (target )
236
243
237
- status = recommendations .pop ("status" )
238
- if status == "same" :
244
+ if recommendation ["status" ] == "same" :
239
245
return
240
- if status == "up" :
241
- await self .scale_up (** recommendations )
242
- if status == "down" :
243
- await self .scale_down (** recommendations )
246
+ else :
247
+ self .log .append ((time (), cast (dict , recommendation )))
248
+ if recommendation ["status" ] == "up" :
249
+ await self .scale_up (recommendation ["n" ])
250
+ elif recommendation ["status" ] == "down" :
251
+ await self .scale_down (recommendation ["workers" ])
244
252
except OSError :
245
253
if status != "down" :
246
254
logger .error ("Adaptive stopping due to error" , exc_info = True )
0 commit comments