|
60 | 60 | }
|
61 | 61 |
|
62 | 62 |
|
| 63 | +class WorkStealingBalancer: |
| 64 | + def __init__(self, scheduler: Scheduler, work_stealing: WorkStealing): |
| 65 | + self.scheduler = scheduler |
| 66 | + self.work_stealing = work_stealing |
| 67 | + |
| 68 | + self._logs: list[tuple[Any, ...]] = [] |
| 69 | + self._start_time: float |
| 70 | + |
| 71 | + @log_errors |
| 72 | + def balance(self) -> None: |
| 73 | + self._start_time = time() |
| 74 | + return self._run() |
| 75 | + |
| 76 | + def _run(self) -> None: |
| 77 | + s = self.scheduler |
| 78 | + i = 0 |
| 79 | + # Paused and closing workers must never become thieves |
| 80 | + if not (potential_thieves := self.get_potential_thieves()) or not ( |
| 81 | + potential_victims := self.get_potential_victims(potential_thieves) |
| 82 | + ): |
| 83 | + return |
| 84 | + assert potential_victims |
| 85 | + assert potential_thieves |
| 86 | + for level, _ in enumerate(self.work_stealing.cost_multipliers): |
| 87 | + if not potential_thieves: |
| 88 | + break |
| 89 | + for victim in list(potential_victims): |
| 90 | + stealable = self._get_stealable_tasks(victim, level) |
| 91 | + if not stealable or not potential_thieves: |
| 92 | + continue |
| 93 | + |
| 94 | + for ts in list(stealable): |
| 95 | + if not potential_thieves: |
| 96 | + break |
| 97 | + if ( |
| 98 | + ts not in self.work_stealing.key_stealable |
| 99 | + or ts.processing_on is not victim |
| 100 | + or ts not in victim.processing |
| 101 | + ): |
| 102 | + # FIXME: Instead of discarding here, clean up stealable properly |
| 103 | + stealable.discard(ts) |
| 104 | + continue |
| 105 | + i += 1 |
| 106 | + if not (thief := _get_thief(s, ts, potential_thieves)): |
| 107 | + continue |
| 108 | + |
| 109 | + self.try_steal_task( |
| 110 | + ts, thief, victim, level, potential_thieves, stealable |
| 111 | + ) |
| 112 | + |
| 113 | + s.check_idle_saturated(victim, occ=self._combined_occupancy(victim)) |
| 114 | + |
| 115 | + if self._logs: |
| 116 | + self.work_stealing.log(("request", self._logs)) |
| 117 | + self.work_stealing.count += 1 |
| 118 | + stop = time() |
| 119 | + if s.digests: |
| 120 | + s.digests["steal-duration"].add(stop - self._start_time) |
| 121 | + |
| 122 | + def _get_stealable_tasks(self, victim: WorkerState, level: int) -> set[TaskState]: |
| 123 | + return self.work_stealing.stealable[victim.address][level] |
| 124 | + |
| 125 | + def try_steal_task( |
| 126 | + self, |
| 127 | + ts: TaskState, |
| 128 | + thief: WorkerState, |
| 129 | + victim: WorkerState, |
| 130 | + level: int, |
| 131 | + potential_thieves: set[WorkerState], |
| 132 | + stealable: set[TaskState], |
| 133 | + ) -> None: |
| 134 | + s = self.scheduler |
| 135 | + occ_thief = self._combined_occupancy(thief) |
| 136 | + occ_victim = self._combined_occupancy(victim) |
| 137 | + comm_cost_thief = s.get_comm_cost(ts, thief) |
| 138 | + comm_cost_victim = s.get_comm_cost(ts, victim) |
| 139 | + compute = s.get_task_duration(ts) |
| 140 | + if ( |
| 141 | + occ_thief + comm_cost_thief + compute |
| 142 | + <= occ_victim - (comm_cost_victim + compute) / 2 |
| 143 | + ): |
| 144 | + self.work_stealing.move_task_request(ts, victim, thief) |
| 145 | + cost = compute + comm_cost_victim |
| 146 | + self._logs.append( |
| 147 | + ( |
| 148 | + self._start_time, |
| 149 | + level, |
| 150 | + ts.key, |
| 151 | + cost, |
| 152 | + victim.address, |
| 153 | + occ_victim, |
| 154 | + thief.address, |
| 155 | + occ_thief, |
| 156 | + ) |
| 157 | + ) |
| 158 | + self.work_stealing.metrics["request_count_total"][level] += 1 |
| 159 | + self.work_stealing.metrics["request_cost_total"][level] += cost |
| 160 | + occ_thief = self._combined_occupancy(thief) |
| 161 | + nproc_thief = self._combined_nprocessing(thief) |
| 162 | + if not s.is_unoccupied(thief, occ_thief, nproc_thief): |
| 163 | + potential_thieves.discard(thief) |
| 164 | + # FIXME: move_task_request already implements some logic |
| 165 | + # for removing ts from stealable. If we made sure to |
| 166 | + # properly clean up, we would not need this |
| 167 | + stealable.discard(ts) |
| 168 | + |
| 169 | + def get_potential_thieves(self) -> set[WorkerState]: |
| 170 | + s = self.scheduler |
| 171 | + potential_thieves = set(s.idle.values()) |
| 172 | + if not potential_thieves or len(potential_thieves) == len(s.workers): |
| 173 | + return set() |
| 174 | + return potential_thieves |
| 175 | + |
| 176 | + def get_potential_victims( |
| 177 | + self, potential_thieves: set[WorkerState] |
| 178 | + ) -> set[WorkerState]: |
| 179 | + s = self.scheduler |
| 180 | + potential_victims: set[WorkerState] | list[WorkerState] = s.saturated |
| 181 | + if not potential_victims: |
| 182 | + potential_victims = topk( |
| 183 | + 10, s.workers.values(), key=self._combined_occupancy |
| 184 | + ) |
| 185 | + potential_victims = [ |
| 186 | + ws |
| 187 | + for ws in potential_victims |
| 188 | + if self._combined_occupancy(ws) > 0.2 |
| 189 | + and self._combined_nprocessing(ws) > ws.nthreads |
| 190 | + and ws not in potential_thieves |
| 191 | + ] |
| 192 | + if not potential_victims: |
| 193 | + return set() |
| 194 | + if len(potential_victims) < 20: |
| 195 | + potential_victims = sorted( |
| 196 | + potential_victims, key=self._combined_occupancy, reverse=True |
| 197 | + ) |
| 198 | + return set(potential_victims) |
| 199 | + |
| 200 | + def _combined_occupancy(self, ws: WorkerState) -> float: |
| 201 | + return ws.occupancy + self.work_stealing.in_flight_occupancy[ws] |
| 202 | + |
| 203 | + def _combined_nprocessing(self, ws: WorkerState) -> int: |
| 204 | + return len(ws.processing) + self.work_stealing.in_flight_tasks[ws] |
| 205 | + |
| 206 | + |
63 | 207 | class InFlightInfo(TypedDict):
|
64 | 208 | victim: WorkerState
|
65 | 209 | thief: WorkerState
|
@@ -120,6 +264,8 @@ def __init__(self, scheduler: Scheduler):
|
120 | 264 | self._request_counter = 0
|
121 | 265 | self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
|
122 | 266 |
|
| 267 | + self._balancer = WorkStealingBalancer |
| 268 | + |
123 | 269 | async def start(self, scheduler: Any = None) -> None:
|
124 | 270 | """Start the background coroutine to balance the tasks on the cluster.
|
125 | 271 | Idempotent.
|
@@ -398,127 +544,8 @@ async def move_task_confirm(
|
398 | 544 | self.scheduler.check_idle_saturated(thief)
|
399 | 545 | self.scheduler.check_idle_saturated(victim)
|
400 | 546 |
|
401 |
| - @log_errors |
402 |
| - def balance(self) -> None: |
403 |
| - s = self.scheduler |
404 |
| - log = [] |
405 |
| - start = time() |
406 |
| - |
407 |
| - i = 0 |
408 |
| - # Paused and closing workers must never become thieves |
409 |
| - if not (potential_thieves := self.get_potential_thieves()) or not ( |
410 |
| - potential_victims := self.get_potential_victims(potential_thieves) |
411 |
| - ): |
412 |
| - return |
413 |
| - assert potential_victims |
414 |
| - assert potential_thieves |
415 |
| - for level, _ in enumerate(self.cost_multipliers): |
416 |
| - if not potential_thieves: |
417 |
| - break |
418 |
| - for victim in list(potential_victims): |
419 |
| - stealable = self.stealable[victim.address][level] |
420 |
| - if not stealable or not potential_thieves: |
421 |
| - continue |
422 |
| - for ts in list(stealable): |
423 |
| - if not potential_thieves: |
424 |
| - break |
425 |
| - if ( |
426 |
| - ts not in self.key_stealable |
427 |
| - or ts.processing_on is not victim |
428 |
| - or ts not in victim.processing |
429 |
| - ): |
430 |
| - # FIXME: Instead of discarding here, clean up stealable properly |
431 |
| - stealable.discard(ts) |
432 |
| - continue |
433 |
| - i += 1 |
434 |
| - if not (thief := _get_thief(s, ts, potential_thieves)): |
435 |
| - continue |
436 |
| - |
437 |
| - occ_thief = self._combined_occupancy(thief) |
438 |
| - occ_victim = self._combined_occupancy(victim) |
439 |
| - comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) |
440 |
| - comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) |
441 |
| - compute = self.scheduler.get_task_duration(ts) |
442 |
| - |
443 |
| - if ( |
444 |
| - occ_thief + comm_cost_thief + compute |
445 |
| - <= occ_victim - (comm_cost_victim + compute) / 2 |
446 |
| - ): |
447 |
| - self.move_task_request(ts, victim, thief) |
448 |
| - cost = compute + comm_cost_victim |
449 |
| - log.append( |
450 |
| - ( |
451 |
| - start, |
452 |
| - level, |
453 |
| - ts.key, |
454 |
| - cost, |
455 |
| - victim.address, |
456 |
| - occ_victim, |
457 |
| - thief.address, |
458 |
| - occ_thief, |
459 |
| - ) |
460 |
| - ) |
461 |
| - self.metrics["request_count_total"][level] += 1 |
462 |
| - self.metrics["request_cost_total"][level] += cost |
463 |
| - |
464 |
| - occ_thief = self._combined_occupancy(thief) |
465 |
| - nproc_thief = self._combined_nprocessing(thief) |
466 |
| - |
467 |
| - if not self.scheduler.is_unoccupied( |
468 |
| - thief, occ_thief, nproc_thief |
469 |
| - ): |
470 |
| - potential_thieves.discard(thief) |
471 |
| - # FIXME: move_task_request already implements some logic |
472 |
| - # for removing ts from stealable. If we made sure to |
473 |
| - # properly clean up, we would not need this |
474 |
| - stealable.discard(ts) |
475 |
| - self.scheduler.check_idle_saturated( |
476 |
| - victim, occ=self._combined_occupancy(victim) |
477 |
| - ) |
478 |
| - |
479 |
| - if log: |
480 |
| - self.log(("request", log)) |
481 |
| - self.count += 1 |
482 |
| - stop = time() |
483 |
| - if s.digests: |
484 |
| - s.digests["steal-duration"].add(stop - start) |
485 |
| - |
486 |
| - def get_potential_thieves(self) -> set[WorkerState]: |
487 |
| - s = self.scheduler |
488 |
| - potential_thieves = set(s.idle.values()) |
489 |
| - if not potential_thieves or len(potential_thieves) == len(s.workers): |
490 |
| - return set() |
491 |
| - return potential_thieves |
492 |
| - |
493 |
| - def get_potential_victims( |
494 |
| - self, potential_thieves: set[WorkerState] |
495 |
| - ) -> set[WorkerState]: |
496 |
| - s = self.scheduler |
497 |
| - potential_victims: set[WorkerState] | list[WorkerState] = s.saturated |
498 |
| - if not potential_victims: |
499 |
| - potential_victims = topk( |
500 |
| - 10, s.workers.values(), key=self._combined_occupancy |
501 |
| - ) |
502 |
| - potential_victims = [ |
503 |
| - ws |
504 |
| - for ws in potential_victims |
505 |
| - if self._combined_occupancy(ws) > 0.2 |
506 |
| - and self._combined_nprocessing(ws) > ws.nthreads |
507 |
| - and ws not in potential_thieves |
508 |
| - ] |
509 |
| - if not potential_victims: |
510 |
| - return set() |
511 |
| - if len(potential_victims) < 20: |
512 |
| - potential_victims = sorted( |
513 |
| - potential_victims, key=self._combined_occupancy, reverse=True |
514 |
| - ) |
515 |
| - return set(potential_victims) |
516 |
| - |
517 |
| - def _combined_occupancy(self, ws: WorkerState) -> float: |
518 |
| - return ws.occupancy + self.in_flight_occupancy[ws] |
519 |
| - |
520 |
| - def _combined_nprocessing(self, ws: WorkerState) -> int: |
521 |
| - return len(ws.processing) + self.in_flight_tasks[ws] |
| 547 | + def balance(self): |
| 548 | + self._balancer(self.scheduler, self).balance() |
522 | 549 |
|
523 | 550 | def restart(self, scheduler: Any) -> None:
|
524 | 551 | for stealable in self.stealable.values():
|
|
0 commit comments