Skip to content

Commit 8efe8fb

Browse files
authored Feb 13, 2025
Hotfix: Ignore negative occupancy (#9012)
1 parent d156473 commit 8efe8fb

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed
 

‎distributed/scheduler.py

+16 -8
Original file line number | Diff line number | Diff line change
@@ -501,7 +501,7 @@ class WorkerState:
501501
# Reference to scheduler task_groups
502502
scheduler_ref: weakref.ref[SchedulerState] | None
503503
task_prefix_count: defaultdict[str, int]
504-
_network_occ: float
504+
_network_occ: int
505505
_occupancy_cache: float | None
506506

507507
#: Keys that may need to be fetched to this worker, and the number of tasks that need them.
@@ -822,8 +822,11 @@ def _dec_needs_replica(self, ts: TaskState) -> None:
822822
if self.needs_what[ts] == 0:
823823
del self.needs_what[ts]
824824
nbytes = ts.get_nbytes()
825-
self._network_occ -= nbytes
826-
self.scheduler._network_occ_global -= nbytes
825+
# FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
826+
self._network_occ -= min(nbytes, self._network_occ)
827+
self.scheduler._network_occ_global -= min(
828+
nbytes, self.scheduler._network_occ_global
829+
)
827830

828831
def add_replica(self, ts: TaskState) -> None:
829832
"""The worker acquired a replica of task"""
@@ -834,8 +837,11 @@ def add_replica(self, ts: TaskState) -> None:
834837
nbytes = ts.get_nbytes()
835838
if ts in self.needs_what:
836839
del self.needs_what[ts]
837-
self._network_occ -= nbytes
838-
self.scheduler._network_occ_global -= nbytes
840+
# FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
841+
self._network_occ -= min(nbytes, self._network_occ)
842+
self.scheduler._network_occ_global -= min(
843+
nbytes, self.scheduler._network_occ_global
844+
)
839845
ts.who_has.add(self)
840846
self.nbytes += nbytes
841847
self._has_what[ts] = None
@@ -1708,7 +1714,7 @@ class SchedulerState:
17081714
transition_counter_max: int | Literal[False]
17091715

17101716
_task_prefix_count_global: defaultdict[str, int]
1711-
_network_occ_global: float
1717+
_network_occ_global: int
17121718
######################
17131719
# Cached configuration
17141720
######################
@@ -1777,7 +1783,7 @@ def __init__(
17771783
self.validate = validate
17781784
self.workers = workers
17791785
self._task_prefix_count_global = defaultdict(int)
1780-
self._network_occ_global = 0.0
1786+
self._network_occ_global = 0
17811787
self.running = {
17821788
ws for ws in self.workers.values() if ws.status == Status.running
17831789
}
@@ -1957,7 +1963,9 @@ def _calc_occupancy(
19571963
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
19581964
res += duration * count
19591965
occ = res + network_occ / self.bandwidth
1960-
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
1966+
if self.validate:
1967+
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
1968+
occ = max(occ, 0)
19611969
return occ
19621970

19631971
#####################

0 commit comments

Comments
 (0)
Please sign in to comment.