Skip to content

Commit b86b714

Browse files
Consistent estimation of task duration between stealing, adaptive and occupancy calculation (#9000)
1 parent 5589049 commit b86b714

6 files changed

+141
-91
lines changed

distributed/scheduler.py

+28-56
Original file line numberDiff line numberDiff line change
@@ -1674,9 +1674,6 @@ class SchedulerState:
16741674
#: Subset of tasks that exist in memory on more than one worker
16751675
replicated_tasks: set[TaskState]
16761676

1677-
#: Tasks with unknown duration, grouped by prefix
1678-
#: {task prefix: {ts, ts, ...}}
1679-
unknown_durations: dict[str, set[TaskState]]
16801677
task_groups: dict[str, TaskGroup]
16811678
task_prefixes: dict[str, TaskPrefix]
16821679
task_metadata: dict[Key, Any]
@@ -1776,7 +1773,6 @@ def __init__(
17761773
self.task_metadata = {}
17771774
self.total_nthreads = 0
17781775
self.total_nthreads_history = [(time(), 0)]
1779-
self.unknown_durations = {}
17801776
self.queued = queued
17811777
self.unrunnable = unrunnable
17821778
self.validate = validate
@@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]:
18551851
"unrunnable": self.unrunnable,
18561852
"queued": self.queued,
18571853
"n_tasks": self.n_tasks,
1858-
"unknown_durations": self.unknown_durations,
18591854
"validate": self.validate,
18601855
"tasks": self.tasks,
18611856
"task_groups": self.task_groups,
@@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None:
19071902
self.task_prefixes,
19081903
self.task_groups,
19091904
self.task_metadata,
1910-
self.unknown_durations,
19111905
self.replicated_tasks,
19121906
):
19131907
collection.clear()
@@ -1931,22 +1925,37 @@ def total_occupancy(self) -> float:
19311925
self._network_occ_global,
19321926
)
19331927

1928+
def _get_prefix_duration(self, prefix: TaskPrefix) -> float:
1929+
"""Get the estimated computation cost of the given task prefix
1930+
(not including any communication cost).
1931+
1932+
If no data has been observed, value of
1933+
`distributed.scheduler.default-task-durations` are used. If none is set
1934+
for this task, `distributed.scheduler.unknown-task-duration` is used
1935+
instead.
1936+
1937+
See Also
1938+
--------
1939+
WorkStealing.get_task_duration
1940+
"""
1941+
# TODO: Deal with unknown tasks better
1942+
assert prefix is not None
1943+
duration = prefix.duration_average
1944+
if duration < 0:
1945+
if prefix.max_exec_time > 0:
1946+
duration = 2 * prefix.max_exec_time
1947+
else:
1948+
duration = self.UNKNOWN_TASK_DURATION
1949+
return duration
1950+
19341951
def _calc_occupancy(
19351952
self,
19361953
task_prefix_count: dict[str, int],
19371954
network_occ: float,
19381955
) -> float:
19391956
res = 0.0
19401957
for prefix_name, count in task_prefix_count.items():
1941-
# TODO: Deal with unknown tasks better
1942-
prefix = self.task_prefixes[prefix_name]
1943-
assert prefix is not None
1944-
duration = prefix.duration_average
1945-
if duration < 0:
1946-
if prefix.max_exec_time > 0:
1947-
duration = 2 * prefix.max_exec_time
1948-
else:
1949-
duration = self.UNKNOWN_TASK_DURATION
1958+
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
19501959
res += duration * count
19511960
occ = res + network_occ / self.bandwidth
19521961
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
@@ -2536,13 +2545,6 @@ def _transition_processing_memory(
25362545
action=startstop["action"],
25372546
)
25382547

2539-
s = self.unknown_durations.pop(ts.prefix.name, set())
2540-
steal = self.extensions.get("stealing")
2541-
if steal:
2542-
for tts in s:
2543-
if tts.processing_on:
2544-
steal.recalculate_cost(tts)
2545-
25462548
############################
25472549
# Update State Information #
25482550
############################
@@ -3171,26 +3173,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
31713173
nbytes = sum(dts.nbytes for dts in deps)
31723174
return nbytes / self.bandwidth
31733175

3174-
def get_task_duration(self, ts: TaskState) -> float:
3175-
"""Get the estimated computation cost of the given task (not including
3176-
any communication cost).
3177-
3178-
If no data has been observed, value of
3179-
`distributed.scheduler.default-task-durations` are used. If none is set
3180-
for this task, `distributed.scheduler.unknown-task-duration` is used
3181-
instead.
3182-
"""
3183-
prefix = ts.prefix
3184-
duration: float = prefix.duration_average
3185-
if duration >= 0:
3186-
return duration
3187-
3188-
s = self.unknown_durations.get(prefix.name)
3189-
if s is None:
3190-
self.unknown_durations[prefix.name] = s = set()
3191-
s.add(ts)
3192-
return self.UNKNOWN_TASK_DURATION
3193-
31943176
def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
31953177
"""Return set of currently valid workers for key
31963178
@@ -3569,20 +3551,15 @@ def _client_releases_keys(
35693551
elif ts.state != "erred" and not ts.waiters:
35703552
recommendations[ts.key] = "released"
35713553

3572-
def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
3554+
def _task_to_msg(self, ts: TaskState) -> dict[str, Any]:
35733555
"""Convert a single computational task to a message"""
3574-
# FIXME: The duration attribute is not used on worker. We could save ourselves the
3575-
# time to compute and submit this
3576-
if duration < 0:
3577-
duration = self.get_task_duration(ts)
35783556
ts.run_id = next(TaskState._run_id_iterator)
35793557
assert ts.priority, ts
35803558
msg: dict[str, Any] = {
35813559
"op": "compute-task",
35823560
"key": ts.key,
35833561
"run_id": ts.run_id,
35843562
"priority": ts.priority,
3585-
"duration": duration,
35863563
"stimulus_id": f"compute-task-{time()}",
35873564
"who_has": {
35883565
dts.key: tuple(ws.address for ws in (dts.who_has or ()))
@@ -6003,12 +5980,10 @@ async def remove_client_from_events() -> None:
60035980
cleanup_delay, remove_client_from_events
60045981
)
60055982

6006-
def send_task_to_worker(
6007-
self, worker: str, ts: TaskState, duration: float = -1
6008-
) -> None:
5983+
def send_task_to_worker(self, worker: str, ts: TaskState) -> None:
60095984
"""Send a single computational task to a worker"""
60105985
try:
6011-
msg = self._task_to_msg(ts, duration)
5986+
msg = self._task_to_msg(ts)
60125987
self.worker_send(worker, msg)
60135988
except Exception as e:
60145989
logger.exception(e)
@@ -8859,10 +8834,7 @@ def adaptive_target(self, target_duration=None):
88598834
queued = take(100, concat([self.queued, self.unrunnable.keys()]))
88608835
queued_occupancy = 0
88618836
for ts in queued:
8862-
if ts.prefix.duration_average == -1:
8863-
queued_occupancy += self.UNKNOWN_TASK_DURATION
8864-
else:
8865-
queued_occupancy += ts.prefix.duration_average
8837+
queued_occupancy += self._get_prefix_duration(ts.prefix)
88668838

88678839
tasks_ready = len(self.queued) + len(self.unrunnable)
88688840
if tasks_ready > 100:

distributed/stealing.py

+40-16
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin):
8888
metrics: dict[str, dict[int, float]]
8989
_in_flight_event: asyncio.Event
9090
_request_counter: int
91+
#: Tasks with unknown duration, grouped by prefix
92+
#: {task prefix: {ts, ts, ...}}
93+
unknown_durations: dict[str, set[TaskState]]
9194

9295
def __init__(self, scheduler: Scheduler):
9396
self.scheduler = scheduler
@@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler):
111114
self.in_flight_occupancy = defaultdict(int)
112115
self.in_flight_tasks = defaultdict(int)
113116
self._in_flight_event = asyncio.Event()
117+
self.unknown_durations = {}
114118
self.metrics = {
115119
"request_count_total": defaultdict(int),
116120
"request_cost_total": defaultdict(int),
@@ -188,6 +192,13 @@ def transition(
188192
ts = self.scheduler.tasks[key]
189193
self.remove_key_from_stealable(ts)
190194
self._remove_from_in_flight(ts)
195+
196+
if finish == "memory":
197+
s = self.unknown_durations.pop(ts.prefix.name, set())
198+
for tts in s:
199+
if tts.processing_on:
200+
self.recalculate_cost(tts)
201+
191202
if finish == "processing":
192203
ts = self.scheduler.tasks[key]
193204
self.put_key_in_stealable(ts)
@@ -223,13 +234,27 @@ def recalculate_cost(self, ts: TaskState) -> None:
223234

224235
def put_key_in_stealable(self, ts: TaskState) -> None:
225236
cost_multiplier, level = self.steal_time_ratio(ts)
226-
if cost_multiplier is not None:
227-
assert level is not None
228-
assert ts.processing_on
229-
ws = ts.processing_on
230-
worker = ws.address
231-
self.stealable[worker][level].add(ts)
232-
self.key_stealable[ts] = (worker, level)
237+
238+
if cost_multiplier is None:
239+
return
240+
241+
prefix = ts.prefix
242+
duration = self.scheduler._get_prefix_duration(prefix)
243+
244+
assert level is not None
245+
assert ts.processing_on
246+
ws = ts.processing_on
247+
worker = ws.address
248+
self.stealable[worker][level].add(ts)
249+
self.key_stealable[ts] = (worker, level)
250+
251+
if duration == ts.prefix.duration_average:
252+
return
253+
254+
if prefix.name not in self.unknown_durations:
255+
self.unknown_durations[prefix.name] = set()
256+
257+
self.unknown_durations[prefix.name].add(ts)
233258

234259
def remove_key_from_stealable(self, ts: TaskState) -> None:
235260
result = self.key_stealable.pop(ts, None)
@@ -255,7 +280,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
255280
if not ts.dependencies: # no dependencies fast path
256281
return 0, 0
257282

258-
compute_time = self.scheduler.get_task_duration(ts)
283+
compute_time = self.scheduler._get_prefix_duration(ts.prefix)
259284

260285
if not compute_time:
261286
# occupancy/ws.processing[ts] is only allowed to be zero for
@@ -301,12 +326,9 @@ def move_task_request(
301326

302327
# TODO: occupancy no longer concats linearly so we can't easily
303328
# assume that the network cost would go down by that much
304-
victim_duration = self.scheduler.get_task_duration(
305-
ts
306-
) + self.scheduler.get_comm_cost(ts, victim)
307-
thief_duration = self.scheduler.get_task_duration(
308-
ts
309-
) + self.scheduler.get_comm_cost(ts, thief)
329+
compute = self.scheduler._get_prefix_duration(ts.prefix)
330+
victim_duration = compute + self.scheduler.get_comm_cost(ts, victim)
331+
thief_duration = compute + self.scheduler.get_comm_cost(ts, thief)
310332

311333
self.scheduler.stream_comms[victim.address].send(
312334
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -457,8 +479,7 @@ def balance(self) -> None:
457479
occ_victim = self._combined_occupancy(victim)
458480
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
459481
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
460-
compute = self.scheduler.get_task_duration(ts)
461-
482+
compute = self.scheduler._get_prefix_duration(ts.prefix)
462483
if (
463484
occ_thief + comm_cost_thief + compute
464485
<= occ_victim - (comm_cost_victim + compute) / 2
@@ -483,6 +504,8 @@ def balance(self) -> None:
483504
occ_thief = self._combined_occupancy(thief)
484505
nproc_thief = self._combined_nprocessing(thief)
485506

507+
# FIXME: In the worst case, the victim may have 3x the amount of work
508+
# of the thief when this aborts balancing.
486509
if not self.scheduler.is_unoccupied(
487510
thief, occ_thief, nproc_thief
488511
):
@@ -514,6 +537,7 @@ def restart(self, scheduler: Any) -> None:
514537
s.clear()
515538

516539
self.key_stealable.clear()
540+
self.unknown_durations.clear()
517541

518542
def story(self, *keys_or_ts: str | TaskState) -> list:
519543
keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts}

distributed/tests/test_scheduler.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -2788,24 +2788,26 @@ async def test_retire_workers_bad_params(c, s, a, b):
27882788
@gen_cluster(
27892789
client=True, config={"distributed.scheduler.default-task-durations": {"inc": 100}}
27902790
)
2791-
async def test_get_task_duration(c, s, a, b):
2791+
async def test_get_prefix_duration(c, s, a, b):
27922792
future = c.submit(inc, 1)
27932793
await future
27942794
assert 10 < s.task_prefixes["inc"].duration_average < 100
27952795

27962796
ts_pref1 = s.new_task("inc-abcdefab", None, "released")
2797-
assert 10 < s.get_task_duration(ts_pref1) < 100
2797+
assert 10 < s._get_prefix_duration(ts_pref1.prefix) < 100
27982798

2799+
extension = s.extensions["stealing"]
27992800
# make sure get_task_duration adds TaskStates to unknown dict
2800-
assert len(s.unknown_durations) == 0
2801+
assert len(extension.unknown_durations) == 0
28012802
x = c.submit(slowinc, 1, delay=0.5)
28022803
while len(s.tasks) < 3:
28032804
await asyncio.sleep(0.01)
28042805

28052806
ts = s.tasks[x.key]
2806-
assert s.get_task_duration(ts) == 0.5 # default
2807-
assert len(s.unknown_durations) == 1
2808-
assert len(s.unknown_durations["slowinc"]) == 1
2807+
assert s._get_prefix_duration(ts.prefix) == 0.5 # default
2808+
2809+
assert len(extension.unknown_durations) == 1
2810+
assert len(extension.unknown_durations["slowinc"]) == 1
28092811

28102812

28112813
@gen_cluster(client=True)
@@ -3338,10 +3340,11 @@ async def test_unknown_task_duration_config(client, s, a, b):
33383340
future = client.submit(slowinc, 1)
33393341
while not s.tasks:
33403342
await asyncio.sleep(0.001)
3341-
assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
3342-
assert len(s.unknown_durations) == 1
3343+
assert sum(s._get_prefix_duration(ts.prefix) for ts in s.tasks.values()) == 3600
3344+
extension = s.extensions["stealing"]
3345+
assert len(extension.unknown_durations) == 1
33433346
await wait(future)
3344-
assert len(s.unknown_durations) == 0
3347+
assert len(extension.unknown_durations) == 0
33453348

33463349

33473350
@gen_cluster()

0 commit comments

Comments
 (0)