|
24 | 24 |
|
25 | 25 | import dask
|
26 | 26 | from dask import bag, delayed
|
| 27 | +from dask.base import DaskMethodsMixin |
27 | 28 | from dask.core import flatten
|
28 | 29 | from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
|
29 | 30 | from dask.utils import parse_timedelta, tmpfile, typename
|
@@ -5315,3 +5316,34 @@ async def test_alias_resolving_break_queuing(c, s, a):
|
5315 | 5316 | while not s.tasks:
|
5316 | 5317 | await asyncio.sleep(0.01)
|
5317 | 5318 | assert sum([s.is_rootish(v) for v in s.tasks.values()]) == 18
|
| 5319 | + |
| 5320 | + |
| 5321 | +@gen_cluster(client=True, nthreads=[("", 1)]) |
| 5322 | +async def test_data_producers(c, s, a): |
| 5323 | + from dask._task_spec import DataNode, Task, TaskRef |
| 5324 | + |
| 5325 | + def func(*args): |
| 5326 | + return 100 |
| 5327 | + |
| 5328 | + class MyArray(DaskMethodsMixin): |
| 5329 | + def __dask_graph__(self): |
| 5330 | + return { |
| 5331 | + "a": DataNode("a", 10), |
| 5332 | + "b": Task("b", func, TaskRef("a"), _data_producer=True), |
| 5333 | + "c": Task("c", func, TaskRef("b")), |
| 5334 | + "d": Task("d", func, TaskRef("c")), |
| 5335 | + } |
| 5336 | + |
| 5337 | + def __dask_keys__(self): |
| 5338 | + return ["d"] |
| 5339 | + |
| 5340 | + def __dask_postcompute__(self): |
| 5341 | + return func, () |
| 5342 | + |
| 5343 | + arr = MyArray() |
| 5344 | + x = c.compute(arr) |
| 5345 | + await async_poll_for(lambda: s.tasks, 5) |
| 5346 | + assert ( |
| 5347 | + sum([s.is_rootish(v) and v.run_spec.data_producer for v in s.tasks.values()]) |
| 5348 | + == 2 |
| 5349 | + ) |
0 commit comments