Skip to content

Commit 0d980c2

Browse files
committed
refactor(work-stealing): separated balancing algorithm
1 parent 48c8984 commit 0d980c2

File tree

1 file changed

+148
-121
lines changed

1 file changed

+148
-121
lines changed

distributed/stealing.py

+148-121
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,150 @@
6060
}
6161

6262

63+
class WorkStealingBalancer:
    """Perform one work-stealing balancing pass for a ``WorkStealing`` instance.

    The balancer reads candidate thieves from ``scheduler.idle`` and candidate
    victims from ``scheduler.saturated`` (falling back to the most occupied
    workers), then walks the victims' stealable tasks level by level and asks
    ``work_stealing.move_task_request`` to move every task for which the
    cost comparison in :meth:`try_steal_task` holds.

    Parameters
    ----------
    scheduler
        Scheduler whose workers and tasks are inspected.
    work_stealing
        Owner of the stealing bookkeeping read and updated here
        (``stealable``, ``key_stealable``, ``cost_multipliers``, ``metrics``,
        ``in_flight_occupancy``, ``in_flight_tasks``, ``count``).
    """

    def __init__(self, scheduler: Scheduler, work_stealing: WorkStealing):
        self.scheduler = scheduler
        self.work_stealing = work_stealing

        # One tuple per issued steal request; see ``try_steal_task`` for layout.
        self._logs: list[tuple[Any, ...]] = []
        # Assigned at the start of every ``balance()`` call.
        self._start_time: float

    @log_errors
    def balance(self) -> None:
        """Run a balancing pass and record its requests and duration."""
        # Rebind rather than ``clear()``: the previous list may already have
        # been handed to ``work_stealing.log`` and must not be mutated.
        self._logs = []
        self._start_time = time()
        return self._run()

    def _run(self) -> None:
        """Match stealable tasks on victims with thieves, level by level.

        Returns early, without touching ``count`` or the digests, when there
        are no potential thieves or no potential victims.
        """
        s = self.scheduler
        # Paused and closing workers must never become thieves
        potential_thieves = self.get_potential_thieves()
        if not potential_thieves:
            return
        potential_victims = self.get_potential_victims(potential_thieves)
        if not potential_victims:
            return

        for level, _ in enumerate(self.work_stealing.cost_multipliers):
            if not potential_thieves:
                break
            for victim in list(potential_victims):
                stealable = self._get_stealable_tasks(victim, level)
                if not stealable or not potential_thieves:
                    continue

                # Iterate over a copy: entries are discarded while looping.
                for ts in list(stealable):
                    if not potential_thieves:
                        break
                    if (
                        ts not in self.work_stealing.key_stealable
                        or ts.processing_on is not victim
                        or ts not in victim.processing
                    ):
                        # FIXME: Instead of discarding here, clean up stealable properly
                        stealable.discard(ts)
                        continue
                    if not (thief := _get_thief(s, ts, potential_thieves)):
                        continue

                    self.try_steal_task(
                        ts, thief, victim, level, potential_thieves, stealable
                    )

                    s.check_idle_saturated(victim, occ=self._combined_occupancy(victim))

        if self._logs:
            self.work_stealing.log(("request", self._logs))
        self.work_stealing.count += 1
        stop = time()
        if s.digests:
            s.digests["steal-duration"].add(stop - self._start_time)

    def _get_stealable_tasks(self, victim: WorkerState, level: int) -> set[TaskState]:
        """Return the stealable-task set kept for ``victim`` at ``level``."""
        return self.work_stealing.stealable[victim.address][level]

    def try_steal_task(
        self,
        ts: TaskState,
        thief: WorkerState,
        victim: WorkerState,
        level: int,
        potential_thieves: set[WorkerState],
        stealable: set[TaskState],
    ) -> None:
        """Request moving ``ts`` from ``victim`` to ``thief`` if it pays off.

        The request is issued only when::

            occ_thief + comm_cost_thief + compute
                <= occ_victim - (comm_cost_victim + compute) / 2

        On success a log tuple ``(start_time, level, key, cost,
        victim.address, occ_victim, thief.address, occ_thief)`` is appended,
        the per-level request metrics are incremented, ``thief`` is dropped
        from ``potential_thieves`` once the scheduler no longer reports it as
        unoccupied, and ``ts`` is removed from ``stealable``.

        Both ``potential_thieves`` and ``stealable`` are mutated in place.
        """
        s = self.scheduler
        occ_thief = self._combined_occupancy(thief)
        occ_victim = self._combined_occupancy(victim)
        comm_cost_thief = s.get_comm_cost(ts, thief)
        comm_cost_victim = s.get_comm_cost(ts, victim)
        compute = s.get_task_duration(ts)
        if (
            occ_thief + comm_cost_thief + compute
            <= occ_victim - (comm_cost_victim + compute) / 2
        ):
            self.work_stealing.move_task_request(ts, victim, thief)
            cost = compute + comm_cost_victim
            # The logged occupancies are the values measured before the request.
            self._logs.append(
                (
                    self._start_time,
                    level,
                    ts.key,
                    cost,
                    victim.address,
                    occ_victim,
                    thief.address,
                    occ_thief,
                )
            )
            self.work_stealing.metrics["request_count_total"][level] += 1
            self.work_stealing.metrics["request_cost_total"][level] += cost
            # Re-read the thief's load now that the request has been issued.
            occ_thief = self._combined_occupancy(thief)
            nproc_thief = self._combined_nprocessing(thief)
            if not s.is_unoccupied(thief, occ_thief, nproc_thief):
                potential_thieves.discard(thief)
            # FIXME: move_task_request already implements some logic
            # for removing ts from stealable. If we made sure to
            # properly clean up, we would not need this
            stealable.discard(ts)

    def get_potential_thieves(self) -> set[WorkerState]:
        """Return the scheduler's idle workers as a new set.

        The result is empty when no worker is idle or when the number of idle
        workers equals the total number of workers.
        """
        s = self.scheduler
        potential_thieves = set(s.idle.values())
        if not potential_thieves or len(potential_thieves) == len(s.workers):
            return set()
        return potential_thieves

    def get_potential_victims(
        self, potential_thieves: set[WorkerState]
    ) -> set[WorkerState]:
        """Return the workers that tasks may be stolen from.

        ``scheduler.saturated`` is used when it is non-empty. Otherwise the
        ten workers with the highest combined occupancy are taken and kept
        only if their combined occupancy exceeds 0.2, they have more
        processing (plus in-flight) tasks than threads, and they are not in
        ``potential_thieves``.
        """
        s = self.scheduler
        potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
        if not potential_victims:
            potential_victims = topk(
                10, s.workers.values(), key=self._combined_occupancy
            )
            potential_victims = [
                ws
                for ws in potential_victims
                if self._combined_occupancy(ws) > 0.2
                and self._combined_nprocessing(ws) > ws.nthreads
                and ws not in potential_thieves
            ]
            if not potential_victims:
                return set()
        if len(potential_victims) < 20:
            potential_victims = sorted(
                potential_victims, key=self._combined_occupancy, reverse=True
            )
        # NOTE(review): converting to a set discards the ordering produced by
        # the sort above; confirm whether callers are meant to see victims in
        # descending occupancy order.
        return set(potential_victims)

    def _combined_occupancy(self, ws: WorkerState) -> float:
        """Return ``ws.occupancy`` plus its in-flight stealing occupancy."""
        return ws.occupancy + self.work_stealing.in_flight_occupancy[ws]

    def _combined_nprocessing(self, ws: WorkerState) -> int:
        """Return the number of processing tasks plus in-flight stolen tasks."""
        return len(ws.processing) + self.work_stealing.in_flight_tasks[ws]
205+
206+
63207
class InFlightInfo(TypedDict):
64208
victim: WorkerState
65209
thief: WorkerState
@@ -120,6 +264,8 @@ def __init__(self, scheduler: Scheduler):
120264
self._request_counter = 0
121265
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
122266

267+
self._balancer = WorkStealingBalancer
268+
123269
async def start(self, scheduler: Any = None) -> None:
124270
"""Start the background coroutine to balance the tasks on the cluster.
125271
Idempotent.
@@ -398,127 +544,8 @@ async def move_task_confirm(
398544
self.scheduler.check_idle_saturated(thief)
399545
self.scheduler.check_idle_saturated(victim)
400546

401-
@log_errors
402-
def balance(self) -> None:
403-
s = self.scheduler
404-
log = []
405-
start = time()
406-
407-
i = 0
408-
# Paused and closing workers must never become thieves
409-
if not (potential_thieves := self.get_potential_thieves()) or not (
410-
potential_victims := self.get_potential_victims(potential_thieves)
411-
):
412-
return
413-
assert potential_victims
414-
assert potential_thieves
415-
for level, _ in enumerate(self.cost_multipliers):
416-
if not potential_thieves:
417-
break
418-
for victim in list(potential_victims):
419-
stealable = self.stealable[victim.address][level]
420-
if not stealable or not potential_thieves:
421-
continue
422-
for ts in list(stealable):
423-
if not potential_thieves:
424-
break
425-
if (
426-
ts not in self.key_stealable
427-
or ts.processing_on is not victim
428-
or ts not in victim.processing
429-
):
430-
# FIXME: Instead of discarding here, clean up stealable properly
431-
stealable.discard(ts)
432-
continue
433-
i += 1
434-
if not (thief := _get_thief(s, ts, potential_thieves)):
435-
continue
436-
437-
occ_thief = self._combined_occupancy(thief)
438-
occ_victim = self._combined_occupancy(victim)
439-
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
440-
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
441-
compute = self.scheduler.get_task_duration(ts)
442-
443-
if (
444-
occ_thief + comm_cost_thief + compute
445-
<= occ_victim - (comm_cost_victim + compute) / 2
446-
):
447-
self.move_task_request(ts, victim, thief)
448-
cost = compute + comm_cost_victim
449-
log.append(
450-
(
451-
start,
452-
level,
453-
ts.key,
454-
cost,
455-
victim.address,
456-
occ_victim,
457-
thief.address,
458-
occ_thief,
459-
)
460-
)
461-
self.metrics["request_count_total"][level] += 1
462-
self.metrics["request_cost_total"][level] += cost
463-
464-
occ_thief = self._combined_occupancy(thief)
465-
nproc_thief = self._combined_nprocessing(thief)
466-
467-
if not self.scheduler.is_unoccupied(
468-
thief, occ_thief, nproc_thief
469-
):
470-
potential_thieves.discard(thief)
471-
# FIXME: move_task_request already implements some logic
472-
# for removing ts from stealable. If we made sure to
473-
# properly clean up, we would not need this
474-
stealable.discard(ts)
475-
self.scheduler.check_idle_saturated(
476-
victim, occ=self._combined_occupancy(victim)
477-
)
478-
479-
if log:
480-
self.log(("request", log))
481-
self.count += 1
482-
stop = time()
483-
if s.digests:
484-
s.digests["steal-duration"].add(stop - start)
485-
486-
def get_potential_thieves(self) -> set[WorkerState]:
487-
s = self.scheduler
488-
potential_thieves = set(s.idle.values())
489-
if not potential_thieves or len(potential_thieves) == len(s.workers):
490-
return set()
491-
return potential_thieves
492-
493-
def get_potential_victims(
494-
self, potential_thieves: set[WorkerState]
495-
) -> set[WorkerState]:
496-
s = self.scheduler
497-
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
498-
if not potential_victims:
499-
potential_victims = topk(
500-
10, s.workers.values(), key=self._combined_occupancy
501-
)
502-
potential_victims = [
503-
ws
504-
for ws in potential_victims
505-
if self._combined_occupancy(ws) > 0.2
506-
and self._combined_nprocessing(ws) > ws.nthreads
507-
and ws not in potential_thieves
508-
]
509-
if not potential_victims:
510-
return set()
511-
if len(potential_victims) < 20:
512-
potential_victims = sorted(
513-
potential_victims, key=self._combined_occupancy, reverse=True
514-
)
515-
return set(potential_victims)
516-
517-
def _combined_occupancy(self, ws: WorkerState) -> float:
518-
return ws.occupancy + self.in_flight_occupancy[ws]
519-
520-
def _combined_nprocessing(self, ws: WorkerState) -> int:
521-
return len(ws.processing) + self.in_flight_tasks[ws]
547+
def balance(self) -> None:
    """Run one balancing pass through a freshly created balancer.

    A new ``self._balancer`` object is constructed from ``self.scheduler``
    and this instance on every call, and its ``balance()`` method is invoked.
    """
    self._balancer(self.scheduler, self).balance()
522549

523550
def restart(self, scheduler: Any) -> None:
524551
for stealable in self.stealable.values():

0 commit comments

Comments
 (0)