Skip to content

Commit 54fdd21

Browse files
committed
Refactor for maintainability
1 parent 49b916d commit 54fdd21

File tree

2 files changed

+49
-49
lines changed

2 files changed

+49
-49
lines changed

distributed/scheduler.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -1925,22 +1925,37 @@ def total_occupancy(self) -> float:
19251925
self._network_occ_global,
19261926
)
19271927

1928+
def _get_prefix_duration(self, prefix: TaskPrefix) -> float:
1929+
"""Get the estimated computation cost of the given task prefix
1930+
(not including any communication cost).
1931+
1932+
If no data has been observed, the value of
1933+
`distributed.scheduler.default-task-durations` is used. If none is set
1934+
for this task, `distributed.scheduler.unknown-task-duration` is used
1935+
instead.
1936+
1937+
See Also
1938+
--------
1939+
WorkStealing.get_task_duration
1940+
"""
1941+
# TODO: Deal with unknown tasks better
1942+
assert prefix is not None
1943+
duration = prefix.duration_average
1944+
if duration < 0:
1945+
if prefix.max_exec_time > 0:
1946+
duration = 2 * prefix.max_exec_time
1947+
else:
1948+
duration = self.UNKNOWN_TASK_DURATION
1949+
return duration
1950+
19281951
def _calc_occupancy(
19291952
self,
19301953
task_prefix_count: dict[str, int],
19311954
network_occ: float,
19321955
) -> float:
19331956
res = 0.0
19341957
for prefix_name, count in task_prefix_count.items():
1935-
# TODO: Deal with unknown tasks better
1936-
prefix = self.task_prefixes[prefix_name]
1937-
assert prefix is not None
1938-
duration = prefix.duration_average
1939-
if duration < 0:
1940-
if prefix.max_exec_time > 0:
1941-
duration = 2 * prefix.max_exec_time
1942-
else:
1943-
duration = self.UNKNOWN_TASK_DURATION
1958+
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
19441959
res += duration * count
19451960
occ = res + network_occ / self.bandwidth
19461961
assert occ >= 0, (occ, res, network_occ, self.bandwidth)

distributed/stealing.py

+25-40
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,26 @@ def recalculate_cost(self, ts: TaskState) -> None:
234234

235235
def put_key_in_stealable(self, ts: TaskState) -> None:
236236
cost_multiplier, level = self.steal_time_ratio(ts)
237-
if cost_multiplier is not None:
238-
assert level is not None
239-
assert ts.processing_on
240-
ws = ts.processing_on
241-
worker = ws.address
242-
self.stealable[worker][level].add(ts)
243-
self.key_stealable[ts] = (worker, level)
237+
238+
prefix = ts.prefix
239+
duration = self.scheduler._get_prefix_duration(prefix)
240+
if cost_multiplier is None:
241+
return
242+
243+
assert level is not None
244+
assert ts.processing_on
245+
ws = ts.processing_on
246+
worker = ws.address
247+
self.stealable[worker][level].add(ts)
248+
self.key_stealable[ts] = (worker, level)
249+
250+
if duration == ts.prefix.duration_average:
251+
return
252+
253+
s = self.unknown_durations.get(prefix.name)
254+
if s is None:
255+
self.unknown_durations[prefix.name] = s = set()
256+
s.add(ts)
244257

245258
def remove_key_from_stealable(self, ts: TaskState) -> None:
246259
result = self.key_stealable.pop(ts, None)
@@ -266,7 +279,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
266279
if not ts.dependencies: # no dependencies fast path
267280
return 0, 0
268281

269-
compute_time = self.get_task_duration(ts)
282+
compute_time = self.scheduler._get_prefix_duration(ts.prefix)
270283

271284
if not compute_time:
272285
# occupancy/ws.processing[ts] is only allowed to be zero for
@@ -312,12 +325,9 @@ def move_task_request(
312325

313326
# TODO: occupancy no longer concats linearly so we can't easily
314327
# assume that the network cost would go down by that much
315-
victim_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost(
316-
ts, victim
317-
)
318-
thief_duration = self.get_task_duration(ts) + self.scheduler.get_comm_cost(
319-
ts, thief
320-
)
328+
compute = self.scheduler._get_prefix_duration(ts.prefix)
329+
victim_duration = compute + self.scheduler.get_comm_cost(ts, victim)
330+
thief_duration = compute + self.scheduler.get_comm_cost(ts, thief)
321331

322332
self.scheduler.stream_comms[victim.address].send(
323333
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -468,8 +478,7 @@ def balance(self) -> None:
468478
occ_victim = self._combined_occupancy(victim)
469479
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
470480
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
471-
compute = self.get_task_duration(ts)
472-
481+
compute = self.scheduler._get_prefix_duration(ts.prefix)
473482
if (
474483
occ_thief + comm_cost_thief + compute
475484
<= occ_victim - (comm_cost_victim + compute) / 2
@@ -542,30 +551,6 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
542551
out.append(t)
543552
return out
544553

545-
def get_task_duration(self, ts: TaskState) -> float:
546-
"""Get the estimated computation cost of the given task (not including
547-
any communication cost).
548-
549-
If no data has been observed, value of
550-
`distributed.scheduler.default-task-durations` are used. If none is set
551-
for this task, `distributed.scheduler.unknown-task-duration` is used
552-
instead.
553-
"""
554-
prefix = ts.prefix
555-
duration: float = prefix.duration_average
556-
if duration >= 0:
557-
return duration
558-
if prefix.max_exec_time > 0:
559-
duration = 2 * prefix.max_exec_time
560-
else:
561-
duration = self.scheduler.UNKNOWN_TASK_DURATION
562-
563-
s = self.unknown_durations.get(prefix.name)
564-
if s is None:
565-
self.unknown_durations[prefix.name] = s = set()
566-
s.add(ts)
567-
return duration
568-
569554

570555
def _get_thief(
571556
scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]

0 commit comments

Comments
 (0)