Skip to content

Commit 7034a1b

Browse files
committed
fix(adaptive): fixed comparison in recommendations when scaling down
1 parent 0111087 commit 7034a1b

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

distributed/deploy/adaptive.py

+21
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
self.cluster = cluster
9696
self.worker_key = worker_key
9797
self._workers_to_close_kwargs = kwargs
98+
self._worker_name_mapping = {}
9899

99100
if interval is None:
100101
interval = dask.config.get("distributed.adaptive.interval")
@@ -131,6 +132,26 @@ def requested(self):
131132
def observed(self):
132133
return self.cluster.observed
133134

135+
@property
def observed_name_mapped(self):
    """Return the worker-name -> worker-address mapping after refreshing it.

    The mapping is brought up to date through
    ``self._assign_hosts_to_names()`` before it is handed out.  The object
    returned is the internal ``self._worker_name_mapping`` dictionary
    itself, not a copy.
    """
    # Refresh first so the caller sees the mapping as of this access.
    self._assign_hosts_to_names()
    name_to_address = self._worker_name_mapping
    return name_to_address
139+
140+
def _assign_hosts_to_names(self) -> None:
    """Synchronise ``self._worker_name_mapping`` with the scheduler's workers.

    The mapping is keyed by worker name (names come from
    ``self._unassigned_worker_names()``) and stores the address of the
    worker currently holding that name.

    Entries whose address is no longer listed in
    ``self.cluster.scheduler_info["workers"]`` are dropped first, so the
    names they held are available again within this same pass.  Every
    listed address that has no name yet then receives an arbitrary
    unassigned name.  Addresses left over once the names run out stay
    unmapped; they are picked up by a later call when a name is free.
    """
    live_addresses = self.cluster.scheduler_info["workers"].keys()

    # Prune before assigning: a name still bound to a departed worker is
    # excluded from the unassigned pool, which would otherwise starve a
    # replacement worker that has just connected.
    stale_names = [
        name
        for name, address in self._worker_name_mapping.items()
        if address not in live_addresses
    ]
    for name in stale_names:
        del self._worker_name_mapping[name]

    unassigned_worker_names = self._unassigned_worker_names()
    # Snapshot as a set so each membership test is O(1) instead of a scan
    # over the dict values for every worker.
    mapped_addresses = set(self._worker_name_mapping.values())
    for worker_address in live_addresses:
        if worker_address in mapped_addresses:
            continue
        if not unassigned_worker_names:
            # More workers than names: leave the rest unmapped instead of
            # relying on ``assert``, which is stripped under ``python -O``.
            break
        self._worker_name_mapping[unassigned_worker_names.pop()] = worker_address
151+
152+
def _unassigned_worker_names(self) -> set:
    """Return the names in ``self.requested`` that have no mapping entry yet.

    A name counts as assigned when it is a key of
    ``self._worker_name_mapping``.  A new set is built on every call.
    """
    already_named = self._worker_name_mapping
    return {name for name in self.requested if name not in already_named}
154+
134155
async def target(self):
135156
"""
136157
Determine target number of workers that should exist.

distributed/deploy/adaptive_core.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class AdaptiveCore:
8787
plan: set[WorkerState]
8888
requested: set[WorkerState]
8989
observed: set[WorkerState]
90+
observed_name_mapped: dict[str, str]
9091
close_counts: defaultdict[WorkerState, int]
9192
_adapting: bool
9293
log: deque[tuple[float, dict]]
@@ -130,6 +131,7 @@ async def _adapt():
130131
self.plan = set()
131132
self.requested = set()
132133
self.observed = set()
134+
self.observed_name_mapped = {}
133135
except Exception:
134136
pass
135137

@@ -181,7 +183,7 @@ async def recommendations(self, target: int) -> dict:
181183
"""
182184
plan = self.plan
183185
requested = self.requested
184-
observed = self.observed
186+
observed = self.observed_name_mapped
185187

186188
if target == len(plan):
187189
self.close_counts.clear()
@@ -192,14 +194,16 @@ async def recommendations(self, target: int) -> dict:
192194
return {"status": "up", "n": target}
193195

194196
# target < len(plan)
195-
not_yet_arrived = requested - observed
197+
not_yet_arrived = requested - observed.keys()
196198
to_close = set()
197199
if not_yet_arrived:
198200
to_close.update(toolz.take(len(plan) - target, not_yet_arrived))
199201

200202
if target < len(plan) - len(to_close):
201203
L = await self.workers_to_close(target=target)
202-
to_close.update(L)
204+
to_close.update(
205+
[key for key, value in observed.items() for name in L if value == name]
206+
)
203207

204208
firmly_close = set()
205209
for w in to_close:

0 commit comments

Comments
 (0)