Skip to content

Commit c3958d9

Browse files
committed
refactor(adaptive-core): improved readability recommendations
1 parent c98cd09 commit c3958d9

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

distributed/deploy/adaptive.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import dask.config
99
from dask.utils import parse_timedelta
1010

11-
from distributed.deploy.adaptive_core import AdaptiveCore
11+
from distributed.deploy.adaptive_core import AdaptiveCore, Recommendation
1212
from distributed.protocol import pickle
1313
from distributed.utils import log_errors
1414

@@ -152,7 +152,7 @@ async def target(self):
152152
target_duration=self.target_duration
153153
)
154154

155-
async def recommendations(self, target: int) -> dict:
155+
async def recommendations(self, target: int) -> Recommendation:
156156
if len(self.plan) != len(self.requested):
157157
# Ensure that the number of planned and requested workers
158158
# are in sync before making recommendations.

distributed/deploy/adaptive_core.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from collections import defaultdict, deque
66
from collections.abc import Iterable
77
from datetime import timedelta
8-
from typing import TYPE_CHECKING, cast
8+
from typing import TYPE_CHECKING, Literal, TypedDict, cast
99

1010
import tlz as toolz
1111
from tornado.ioloop import IOLoop
12+
from typing_extensions import NotRequired
1213

1314
import dask.config
1415
from dask.utils import parse_timedelta
@@ -23,6 +24,15 @@
2324
logger = logging.getLogger(__name__)
2425

2526

27+
RecommendationStatus = Literal["up", "down", "same"]
28+
29+
30+
class Recommendation(TypedDict):
31+
status: RecommendationStatus
32+
workers: NotRequired[set[WorkerState]]
33+
n: NotRequired[int]
34+
35+
2636
class AdaptiveCore:
2737
"""
2838
The core logic for adaptive deployments, with none of the cluster details
@@ -169,13 +179,13 @@ async def safe_target(self) -> int:
169179

170180
return n
171181

172-
async def scale_down(self, n: int) -> None:
182+
async def scale_down(self, workers: Iterable) -> None:
173183
raise NotImplementedError()
174184

175-
async def scale_up(self, workers: Iterable) -> None:
185+
async def scale_up(self, n: int) -> None:
176186
raise NotImplementedError()
177187

178-
async def recommendations(self, target: int) -> dict:
188+
async def recommendations(self, target: int) -> Recommendation:
179189
"""
180190
Make scale up/down recommendations based on current state and target
181191
"""
@@ -185,11 +195,11 @@ async def recommendations(self, target: int) -> dict:
185195

186196
if target == len(plan):
187197
self.close_counts.clear()
188-
return {"status": "same"}
198+
return Recommendation(status="same")
189199

190200
if target > len(plan):
191201
self.close_counts.clear()
192-
return {"status": "up", "n": target}
202+
return Recommendation(status="up", n=target)
193203

194204
# target < len(plan)
195205
not_yet_arrived = requested - observed
@@ -212,9 +222,9 @@ async def recommendations(self, target: int) -> dict:
212222
del self.close_counts[k]
213223

214224
if firmly_close:
215-
return {"status": "down", "workers": list(firmly_close)}
225+
return Recommendation(status="down", workers=firmly_close)
216226
else:
217-
return {"status": "same"}
227+
return Recommendation(status="same")
218228

219229
async def adapt(self) -> None:
220230
"""
@@ -229,18 +239,16 @@ async def adapt(self) -> None:
229239

230240
try:
231241
target = await self.safe_target()
232-
recommendations = await self.recommendations(target)
233-
234-
if recommendations["status"] != "same":
235-
self.log.append((time(), dict(recommendations)))
242+
recommendation = await self.recommendations(target)
236243

237-
status = recommendations.pop("status")
238-
if status == "same":
244+
if recommendation["status"] == "same":
239245
return
240-
if status == "up":
241-
await self.scale_up(**recommendations)
242-
if status == "down":
243-
await self.scale_down(**recommendations)
246+
else:
247+
self.log.append((time(), cast(dict, recommendation)))
248+
if recommendation["status"] == "up":
249+
await self.scale_up(recommendation["n"])
250+
elif recommendation["status"] == "down":
251+
await self.scale_down(recommendation["workers"])
244252
except OSError:
245253
if status != "down":
246254
logger.error("Adaptive stopping due to error", exc_info=True)

0 commit comments

Comments
 (0)