Skip to content

Commit bbdd2ee

Browse files
authoredJan 16, 2025··
Use IO task marker in scheduling (#8950)
1 parent f28498e commit bbdd2ee

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed
 

‎distributed/scheduler.py

+3
Original file line numberDiff line numberDiff line change
@@ -3085,6 +3085,9 @@ def is_rootish(self, ts: TaskState) -> bool:
30853085
"""
30863086
if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
30873087
return False
3088+
# Check explicitly marked data producer tasks
3089+
if ts.run_spec and ts.run_spec.data_producer:
3090+
return True
30883091
tg = ts.group
30893092
# TODO short-circuit to True if `not ts.dependencies`?
30903093
return (

‎distributed/tests/test_scheduler.py

+32
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import dask
2626
from dask import bag, delayed
27+
from dask.base import DaskMethodsMixin
2728
from dask.core import flatten
2829
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
2930
from dask.utils import parse_timedelta, tmpfile, typename
@@ -5315,3 +5316,34 @@ async def test_alias_resolving_break_queuing(c, s, a):
53155316
while not s.tasks:
53165317
await asyncio.sleep(0.01)
53175318
assert sum([s.is_rootish(v) for v in s.tasks.values()]) == 18
5319+
5320+
5321+
@gen_cluster(client=True, nthreads=[("", 1)])
5322+
async def test_data_producers(c, s, a):
5323+
from dask._task_spec import DataNode, Task, TaskRef
5324+
5325+
def func(*args):
5326+
return 100
5327+
5328+
class MyArray(DaskMethodsMixin):
5329+
def __dask_graph__(self):
5330+
return {
5331+
"a": DataNode("a", 10),
5332+
"b": Task("b", func, TaskRef("a"), _data_producer=True),
5333+
"c": Task("c", func, TaskRef("b")),
5334+
"d": Task("d", func, TaskRef("c")),
5335+
}
5336+
5337+
def __dask_keys__(self):
5338+
return ["d"]
5339+
5340+
def __dask_postcompute__(self):
5341+
return func, ()
5342+
5343+
arr = MyArray()
5344+
x = c.compute(arr)
5345+
await async_poll_for(lambda: s.tasks, 5)
5346+
assert (
5347+
sum([s.is_rootish(v) and v.run_spec.data_producer for v in s.tasks.values()])
5348+
== 2
5349+
)

0 commit comments

Comments
 (0)
Please sign in to comment.