
Commit 48cd0f1

fix annotations
1 parent ce62eee commit 48cd0f1

3 files changed: +83 -98 lines

distributed/client.py

+58 -90
@@ -39,11 +39,15 @@
 
 from packaging.version import parse as parse_version
 from tlz import first, groupby, merge, partition_all, valmap
+from tornado import gen
+from tornado.ioloop import IOLoop
 
 import dask
+from dask._expr import Expr, HLGExpr, LLGExpr
+from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input
 from dask.base import collections_to_dsk
 from dask.core import flatten, validate_key
-from dask.layers import Layer
+from dask.highlevelgraph import HighLevelGraph
 from dask.tokenize import tokenize
 from dask.typing import Key, NestedKeys, NoDefault, no_default
 from dask.utils import (
@@ -57,19 +61,6 @@
 )
 from dask.widgets import get_template
 
-from distributed.core import OKMessage
-from distributed.protocol.serialize import _is_dumpable
-from distributed.utils import Deadline, wait_for
-
-try:
-    from dask.delayed import single_key
-except ImportError:
-    single_key = first
-from tornado import gen
-from tornado.ioloop import IOLoop
-
-from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input
-
 import distributed.utils
 from distributed import cluster_dump, preloading
 from distributed import versions as version_module
@@ -79,6 +70,7 @@
 from distributed.core import (
     CommClosedError,
     ConnectionPool,
+    OKMessage,
     PooledRPCCall,
     Status,
     clean_exception,
@@ -98,6 +90,7 @@
 from distributed.objects import HasWhat, SchedulerInfo, WhoHas
 from distributed.protocol import to_serialize
 from distributed.protocol.pickle import dumps, loads
+from distributed.protocol.serialize import _is_dumpable
 from distributed.publish import Datasets
 from distributed.pubsub import PubSubClientExtension
 from distributed.security import Security
@@ -106,6 +99,7 @@
 from distributed.threadpoolexecutor import rejoin
 from distributed.utils import (
     CancelledError,
+    Deadline,
     LoopRunner,
     NoOpAwaitable,
     SyncMethodMixin,
@@ -117,6 +111,7 @@
     nbytes,
     sync,
     thread_state,
+    wait_for,
 )
 from distributed.utils_comm import (
     gather_from_workers,
@@ -834,51 +829,32 @@ def _is_nested(iterable):
     return False
 
 
-class _MapLayer(Layer):
+class _MapExpr(Expr):
     func: Callable
-    iterables: Iterable[Any]
-    key: str | Iterable[str] | None
+    iterables: Iterable
+    key: Key
     pure: bool
-    annotations: dict[str, Any] | None
-
-    def __init__(
-        self,
-        func: Callable,
-        iterables: Iterable[Any],
-        key: str | Iterable[str] | None = None,
-        pure: bool = True,
-        annotations: dict[str, Any] | None = None,
-        **kwargs,
-    ):
-        self.func: Callable = func
-        self.iterables = [tuple(map(parse_input, iterable)) for iterable in iterables]
-        self.key: str | Iterable[str] | None = key
-        self.pure: bool = pure
-        self.kwargs = {k: parse_input(v) for k, v in kwargs.items()}
-        super().__init__(annotations=annotations)
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__} <func='{funcname(self.func)}'>"
+    annotations: dict
+    kwargs: dict
+    _cached_keys: Iterable[Key] | None
+    _parameters = [
+        "func",
+        "iterables",
+        "key",
+        "pure",
+        "annotations",
+        "kwargs",
+        "_cached_keys",
+    ]
+    _defaults = {"_cached_keys": None}
 
     @property
-    def _dict(self) -> _T_LowLevelGraph:
-        self._cached_dict: _T_LowLevelGraph
-        dsk: _T_LowLevelGraph
-
-        if hasattr(self, "_cached_dict"):
-            return self._cached_dict
-        else:
-            dsk = self._construct_graph()
-            self._cached_dict = dsk
-            return self._cached_dict
-
-    @property
-    def _keys(self) -> Iterable[Key]:
-        if hasattr(self, "_cached_keys"):
+    def keys(self) -> Iterable[Key]:
+        if self._cached_keys is not None:
             return self._cached_keys
         else:
             if isinstance(self.key, Iterable) and not isinstance(self.key, str):
-                self._cached_keys: Iterable[Key] = self.key
+                self.operands[-1] = self.key
                 return self.key
 
             else:
@@ -898,34 +874,19 @@ def _keys(self) -> Iterable[Key]:
                     if self.iterables
                     else []
                 )
-                self._cached_keys = keys
+                self.operands[-1] = keys
                 return keys
 
-    def get_output_keys(self) -> set[Key]:
-        return set(self._keys)
-
-    def get_ordered_keys(self):
-        return list(self._keys)
-
-    def is_materialized(self) -> bool:
-        return hasattr(self, "_cached_dict")
-
-    def __getitem__(self, key: Key) -> GraphNode:
-        return self._dict[key]
+    def _meta(self):
+        return []
 
-    def __iter__(self) -> Iterator[Key]:
-        return iter(self._dict)
-
-    def __len__(self) -> int:
-        return len(self._dict)
-
-    def _construct_graph(self) -> _T_LowLevelGraph:
+    def _layer(self):
         dsk: _T_LowLevelGraph = {}
 
         if not self.kwargs:
             dsk = {
                 key: Task(key, self.func, *args)
-                for key, args in zip(self._keys, zip(*self.iterables))
+                for key, args in zip(self.keys, zip(*self.iterables))
             }
 
         else:
@@ -937,12 +898,12 @@ def _construct_graph(self) -> _T_LowLevelGraph:
                     kwargs2[k] = vv.ref()
                     dsk[vv.key] = vv
                 else:
-                    kwargs2[k] = v
+                    kwargs2[k] = parse_input(v)
 
             dsk.update(
                 {
                     key: Task(key, self.func, *args, **kwargs2)
-                    for key, args in zip(self._keys, zip(*self.iterables))
+                    for key, args in zip(self.keys, zip(*self.iterables))
                 }
             )
         return dsk
@@ -2162,16 +2123,19 @@ def submit(
 
         if isinstance(workers, (str, Number)):
            workers = [workers]
-        dsk = {
-            key: Task(
-                key,
-                func,
-                *(parse_input(a) for a in args),
-                **{k: parse_input(v) for k, v in kwargs.items()},
-            )
-        }
+
+        expr = LLGExpr(
+            {
+                key: Task(
+                    key,
+                    func,
+                    *(parse_input(a) for a in args),
+                    **{k: parse_input(v) for k, v in kwargs.items()},
+                )
+            }
+        )
         futures = self._graph_to_futures(
-            dsk,
+            expr,
             [key],
             workers=workers,
             allow_other_workers=allow_other_workers,
@@ -2331,14 +2295,16 @@ def map(
         if allow_other_workers and workers is None:
             raise ValueError("Only use allow_other_workers= if using workers=")
 
-        dsk = _MapLayer(
+        expr = _MapExpr(
             func,
             iterables,
             key=key,
             pure=pure,
-            **kwargs,
+            # FIXME: this doesn't look right
+            annotations={},
+            kwargs=kwargs,
         )
-        keys = dsk.get_ordered_keys()
+        keys = list(expr.keys)
         if isinstance(workers, (str, Number)):
             workers = [workers]
         if workers is not None and not isinstance(workers, (list, set)):
@@ -2347,7 +2313,7 @@ def map(
         internal_priority = dict(zip(keys, range(len(keys))))
 
         futures = self._graph_to_futures(
-            dsk,
+            expr,
             keys,
             workers=workers,
             allow_other_workers=allow_other_workers,
@@ -2361,7 +2327,6 @@ def map(
         )
 
         # make sure the graph is not materialized
-        assert not dsk.is_materialized(), "Graph must be non-materialized"
         logger.debug("map(%s, ...)", funcname(func))
         return [futures[k] for k in keys]
 
@@ -3464,8 +3429,12 @@ def get(
         --------
         Client.compute : Compute asynchronous collections
         """
+        if isinstance(dsk, dict):
+            dsk = LLGExpr(dsk)
+        elif isinstance(dsk, HighLevelGraph):
+            dsk = HLGExpr(dsk)
         futures = self._graph_to_futures(
-            dsk,
+            expr=dsk,
             keys=set(flatten([keys])),
             workers=workers,
             allow_other_workers=allow_other_workers,
@@ -3667,7 +3636,6 @@ def compute(
         expr = FinalizeCompute(expr)
 
         expr = expr.optimize()
-        # FIXME: Is this actually required?
         names = list(flatten(expr.__dask_keys__()))
 
         futures_dict = self._graph_to_futures(
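
Note on the client.py changes above: every input graph is now routed through a dask expression object before it reaches _graph_to_futures. A minimal sketch of that wrapping, based only on the Client.get branch in this diff (the helper name _wrap_graph is hypothetical; the real code inlines these checks):

from dask._expr import HLGExpr, LLGExpr
from dask.highlevelgraph import HighLevelGraph

def _wrap_graph(dsk):
    # Hypothetical helper mirroring the branch added to Client.get:
    # a plain dict becomes a low-level graph expression, a
    # HighLevelGraph becomes a high-level graph expression, and
    # anything else is assumed to already be an expression.
    if isinstance(dsk, dict):
        return LLGExpr(dsk)
    if isinstance(dsk, HighLevelGraph):
        return HLGExpr(dsk)
    return dsk

Client.submit builds an LLGExpr directly and Client.map builds a _MapExpr, so the scheduler-facing path only ever sees expressions.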

distributed/scheduler.py

+2 -1
@@ -9389,7 +9389,8 @@ def _materialize_graph(
         annotations_by_type[annotations_type].update(
             {k: (value(k) if callable(value) else value) for k in dsk}
         )
-    annotations_by_type.update(expr.__dask_annotations__())
+    for annotations_type, value in expr.__dask_annotations__().items():
+        annotations_by_type[annotations_type].update(value)
 
     dsk2 = convert_legacy_graph(dsk)
     # FIXME: There should be no need to fully materialize and copy this but some
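
Note on the scheduler.py change above: __dask_annotations__() returns a mapping of annotation type to a per-key dict, so calling dict.update on the outer mapping replaces any annotations of that type already collected from the graph, while the new loop merges them per type. A small illustration with made-up annotation values (the annotation name and keys below are hypothetical, not taken from the real scheduler code):

collected = {"retries": {"x-0": 1}}   # annotations already gathered from the graph
from_expr = {"retries": {"y-0": 3}}   # what expr.__dask_annotations__() might return

# Old behaviour: the outer update drops the existing "retries" entry.
replaced = dict(collected)
replaced.update(from_expr)
assert replaced == {"retries": {"y-0": 3}}

# New behaviour: merge per annotation type, keeping both keys.
merged = {k: dict(v) for k, v in collected.items()}
for annotations_type, value in from_expr.items():
    merged.setdefault(annotations_type, {}).update(value)
assert merged == {"retries": {"x-0": 1, "y-0": 3}}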

distributed/tests/test_client.py

+23 -7
@@ -42,7 +42,7 @@
 import dask
 import dask.bag as db
 from dask import delayed
-from dask.tokenize import tokenize
+from dask.tokenize import TokenizationError, tokenize
 from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile
 
 from distributed import (
@@ -4639,9 +4639,9 @@ def test_recreate_error_sync(c):
     y0 = c.submit(dec, 1)
     x = c.submit(div, 1, x0)
     y = c.submit(div, 1, y0)
-    tot = c.submit(sum, x, y)
-    f = c.compute(tot)
-
+    f = c.submit(sum, x, y)
+    wait(f)
+    assert f.status == "error"
     with pytest.raises(ZeroDivisionError):
         c.recreate_error_locally(f)
     assert f.status == "error"
@@ -4658,8 +4658,8 @@ def test_recreate_task_sync(c):
     y0 = c.submit(dec, 2)
     x = c.submit(div, 1, x0)
     y = c.submit(div, 1, y0)
-    tot = c.submit(sum, [x, y])
-    f = c.compute(tot)
+    f = c.submit(sum, [x, y])
+    wait(f)
 
     assert c.recreate_task_locally(f) == 2
 
@@ -4814,7 +4814,17 @@ class Foo:
         def __getstate__(self):
             raise MyException()
 
-    with pytest.raises(TypeError, match="Could not serialize"):
+    with pytest.raises((TypeError, TokenizationError), match="serialize"):
+        future = c.submit(identity, Foo())
+
+    class Foo:
+        def __dask_tokenize__(self):
+            return 1
+
+        def __getstate__(self):
+            raise MyException()
+
+    with pytest.raises((TypeError, TokenizationError), match="serialize"):
         future = c.submit(identity, Foo())
 
     futures = c.map(inc, range(10))
@@ -4830,6 +4840,9 @@ class Foo:
         def __getstate__(self):
             return 1
 
+        def __dask_tokenize__(self):
+            return 1
+
         def __setstate__(self, state):
             raise MyException("hello")
 
@@ -4855,6 +4868,9 @@ def __getstate__(self):
         def __setstate__(self, state):
             raise MyException("hello")
 
+        def __dask_tokenize__(self):
+            return 1
+
         def __call__(self, *args):
             return 1
 