@@ -501,7 +501,7 @@ class WorkerState:
501
501
# Reference to scheduler task_groups
502
502
scheduler_ref : weakref .ref [SchedulerState ] | None
503
503
task_prefix_count : defaultdict [str , int ]
504
- _network_occ : float
504
+ _network_occ : int
505
505
_occupancy_cache : float | None
506
506
507
507
#: Keys that may need to be fetched to this worker, and the number of tasks that need them.
@@ -822,8 +822,11 @@ def _dec_needs_replica(self, ts: TaskState) -> None:
822
822
if self .needs_what [ts ] == 0 :
823
823
del self .needs_what [ts ]
824
824
nbytes = ts .get_nbytes ()
825
- self ._network_occ -= nbytes
826
- self .scheduler ._network_occ_global -= nbytes
825
+ # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
826
+ self ._network_occ -= min (nbytes , self ._network_occ )
827
+ self .scheduler ._network_occ_global -= min (
828
+ nbytes , self .scheduler ._network_occ_global
829
+ )
827
830
828
831
def add_replica (self , ts : TaskState ) -> None :
829
832
"""The worker acquired a replica of task"""
@@ -834,8 +837,11 @@ def add_replica(self, ts: TaskState) -> None:
834
837
nbytes = ts .get_nbytes ()
835
838
if ts in self .needs_what :
836
839
del self .needs_what [ts ]
837
- self ._network_occ -= nbytes
838
- self .scheduler ._network_occ_global -= nbytes
840
+ # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
841
+ self ._network_occ -= min (nbytes , self ._network_occ )
842
+ self .scheduler ._network_occ_global -= min (
843
+ nbytes , self .scheduler ._network_occ_global
844
+ )
839
845
ts .who_has .add (self )
840
846
self .nbytes += nbytes
841
847
self ._has_what [ts ] = None
@@ -1708,7 +1714,7 @@ class SchedulerState:
1708
1714
transition_counter_max : int | Literal [False ]
1709
1715
1710
1716
_task_prefix_count_global : defaultdict [str , int ]
1711
- _network_occ_global : float
1717
+ _network_occ_global : int
1712
1718
######################
1713
1719
# Cached configuration
1714
1720
######################
@@ -1777,7 +1783,7 @@ def __init__(
1777
1783
self .validate = validate
1778
1784
self .workers = workers
1779
1785
self ._task_prefix_count_global = defaultdict (int )
1780
- self ._network_occ_global = 0.0
1786
+ self ._network_occ_global = 0
1781
1787
self .running = {
1782
1788
ws for ws in self .workers .values () if ws .status == Status .running
1783
1789
}
@@ -1957,7 +1963,9 @@ def _calc_occupancy(
1957
1963
duration = self ._get_prefix_duration (self .task_prefixes [prefix_name ])
1958
1964
res += duration * count
1959
1965
occ = res + network_occ / self .bandwidth
1960
- assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
1966
+ if self .validate :
1967
+ assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
1968
+ occ = max (occ , 0 )
1961
1969
return occ
1962
1970
1963
1971
#####################
0 commit comments