@@ -1674,9 +1674,6 @@ class SchedulerState:
1674
1674
#: Subset of tasks that exist in memory on more than one worker
1675
1675
replicated_tasks : set [TaskState ]
1676
1676
1677
- #: Tasks with unknown duration, grouped by prefix
1678
- #: {task prefix: {ts, ts, ...}}
1679
- unknown_durations : dict [str , set [TaskState ]]
1680
1677
task_groups : dict [str , TaskGroup ]
1681
1678
task_prefixes : dict [str , TaskPrefix ]
1682
1679
task_metadata : dict [Key , Any ]
@@ -1776,7 +1773,6 @@ def __init__(
1776
1773
self .task_metadata = {}
1777
1774
self .total_nthreads = 0
1778
1775
self .total_nthreads_history = [(time (), 0 )]
1779
- self .unknown_durations = {}
1780
1776
self .queued = queued
1781
1777
self .unrunnable = unrunnable
1782
1778
self .validate = validate
@@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]:
1855
1851
"unrunnable" : self .unrunnable ,
1856
1852
"queued" : self .queued ,
1857
1853
"n_tasks" : self .n_tasks ,
1858
- "unknown_durations" : self .unknown_durations ,
1859
1854
"validate" : self .validate ,
1860
1855
"tasks" : self .tasks ,
1861
1856
"task_groups" : self .task_groups ,
@@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None:
1907
1902
self .task_prefixes ,
1908
1903
self .task_groups ,
1909
1904
self .task_metadata ,
1910
- self .unknown_durations ,
1911
1905
self .replicated_tasks ,
1912
1906
):
1913
1907
collection .clear ()
@@ -1931,22 +1925,37 @@ def total_occupancy(self) -> float:
1931
1925
self ._network_occ_global ,
1932
1926
)
1933
1927
1928
+ def _get_prefix_duration (self , prefix : TaskPrefix ) -> float :
1929
+ """Get the estimated computation cost of the given task prefix
1930
+ (not including any communication cost).
1931
+
1932
+ If no data has been observed, value of
1933
+ `distributed.scheduler.default-task-durations` are used. If none is set
1934
+ for this task, `distributed.scheduler.unknown-task-duration` is used
1935
+ instead.
1936
+
1937
+ See Also
1938
+ --------
1939
+ WorkStealing.get_task_duration
1940
+ """
1941
+ # TODO: Deal with unknown tasks better
1942
+ assert prefix is not None
1943
+ duration = prefix .duration_average
1944
+ if duration < 0 :
1945
+ if prefix .max_exec_time > 0 :
1946
+ duration = 2 * prefix .max_exec_time
1947
+ else :
1948
+ duration = self .UNKNOWN_TASK_DURATION
1949
+ return duration
1950
+
1934
1951
def _calc_occupancy (
1935
1952
self ,
1936
1953
task_prefix_count : dict [str , int ],
1937
1954
network_occ : float ,
1938
1955
) -> float :
1939
1956
res = 0.0
1940
1957
for prefix_name , count in task_prefix_count .items ():
1941
- # TODO: Deal with unknown tasks better
1942
- prefix = self .task_prefixes [prefix_name ]
1943
- assert prefix is not None
1944
- duration = prefix .duration_average
1945
- if duration < 0 :
1946
- if prefix .max_exec_time > 0 :
1947
- duration = 2 * prefix .max_exec_time
1948
- else :
1949
- duration = self .UNKNOWN_TASK_DURATION
1958
+ duration = self ._get_prefix_duration (self .task_prefixes [prefix_name ])
1950
1959
res += duration * count
1951
1960
occ = res + network_occ / self .bandwidth
1952
1961
assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
@@ -2536,13 +2545,6 @@ def _transition_processing_memory(
2536
2545
action = startstop ["action" ],
2537
2546
)
2538
2547
2539
- s = self .unknown_durations .pop (ts .prefix .name , set ())
2540
- steal = self .extensions .get ("stealing" )
2541
- if steal :
2542
- for tts in s :
2543
- if tts .processing_on :
2544
- steal .recalculate_cost (tts )
2545
-
2546
2548
############################
2547
2549
# Update State Information #
2548
2550
############################
@@ -3171,26 +3173,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
3171
3173
nbytes = sum (dts .nbytes for dts in deps )
3172
3174
return nbytes / self .bandwidth
3173
3175
3174
- def get_task_duration (self , ts : TaskState ) -> float :
3175
- """Get the estimated computation cost of the given task (not including
3176
- any communication cost).
3177
-
3178
- If no data has been observed, value of
3179
- `distributed.scheduler.default-task-durations` are used. If none is set
3180
- for this task, `distributed.scheduler.unknown-task-duration` is used
3181
- instead.
3182
- """
3183
- prefix = ts .prefix
3184
- duration : float = prefix .duration_average
3185
- if duration >= 0 :
3186
- return duration
3187
-
3188
- s = self .unknown_durations .get (prefix .name )
3189
- if s is None :
3190
- self .unknown_durations [prefix .name ] = s = set ()
3191
- s .add (ts )
3192
- return self .UNKNOWN_TASK_DURATION
3193
-
3194
3176
def valid_workers (self , ts : TaskState ) -> set [WorkerState ] | None :
3195
3177
"""Return set of currently valid workers for key
3196
3178
@@ -3569,20 +3551,15 @@ def _client_releases_keys(
3569
3551
elif ts .state != "erred" and not ts .waiters :
3570
3552
recommendations [ts .key ] = "released"
3571
3553
3572
- def _task_to_msg (self , ts : TaskState , duration : float = - 1 ) -> dict [str , Any ]:
3554
+ def _task_to_msg (self , ts : TaskState ) -> dict [str , Any ]:
3573
3555
"""Convert a single computational task to a message"""
3574
- # FIXME: The duration attribute is not used on worker. We could save ourselves the
3575
- # time to compute and submit this
3576
- if duration < 0 :
3577
- duration = self .get_task_duration (ts )
3578
3556
ts .run_id = next (TaskState ._run_id_iterator )
3579
3557
assert ts .priority , ts
3580
3558
msg : dict [str , Any ] = {
3581
3559
"op" : "compute-task" ,
3582
3560
"key" : ts .key ,
3583
3561
"run_id" : ts .run_id ,
3584
3562
"priority" : ts .priority ,
3585
- "duration" : duration ,
3586
3563
"stimulus_id" : f"compute-task-{ time ()} " ,
3587
3564
"who_has" : {
3588
3565
dts .key : tuple (ws .address for ws in (dts .who_has or ()))
@@ -6003,12 +5980,10 @@ async def remove_client_from_events() -> None:
6003
5980
cleanup_delay , remove_client_from_events
6004
5981
)
6005
5982
6006
- def send_task_to_worker (
6007
- self , worker : str , ts : TaskState , duration : float = - 1
6008
- ) -> None :
5983
+ def send_task_to_worker (self , worker : str , ts : TaskState ) -> None :
6009
5984
"""Send a single computational task to a worker"""
6010
5985
try :
6011
- msg = self ._task_to_msg (ts , duration )
5986
+ msg = self ._task_to_msg (ts )
6012
5987
self .worker_send (worker , msg )
6013
5988
except Exception as e :
6014
5989
logger .exception (e )
@@ -8859,10 +8834,7 @@ def adaptive_target(self, target_duration=None):
8859
8834
queued = take (100 , concat ([self .queued , self .unrunnable .keys ()]))
8860
8835
queued_occupancy = 0
8861
8836
for ts in queued :
8862
- if ts .prefix .duration_average == - 1 :
8863
- queued_occupancy += self .UNKNOWN_TASK_DURATION
8864
- else :
8865
- queued_occupancy += ts .prefix .duration_average
8837
+ queued_occupancy += self ._get_prefix_duration (ts .prefix )
8866
8838
8867
8839
tasks_ready = len (self .queued ) + len (self .unrunnable )
8868
8840
if tasks_ready > 100 :
0 commit comments