 from packaging.version import parse as parse_version

 from tlz import first, groupby, merge, partition_all, valmap
+from tornado import gen
+from tornado.ioloop import IOLoop

 import dask
+from dask._expr import Expr, HLGExpr, LLGExpr
+from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input
 from dask.base import collections_to_dsk
 from dask.core import flatten, validate_key
-from dask.layers import Layer
+from dask.highlevelgraph import HighLevelGraph
 from dask.tokenize import tokenize
 from dask.typing import Key, NestedKeys, NoDefault, no_default
 from dask.utils import (
...
 )
 from dask.widgets import get_template

-from distributed.core import OKMessage
-from distributed.protocol.serialize import _is_dumpable
-from distributed.utils import Deadline, wait_for
-
-try:
-    from dask.delayed import single_key
-except ImportError:
-    single_key = first
-from tornado import gen
-from tornado.ioloop import IOLoop
-
-from dask._task_spec import DataNode, GraphNode, List, Task, TaskRef, parse_input
-
 import distributed.utils
 from distributed import cluster_dump, preloading
 from distributed import versions as version_module
...
 from distributed.core import (
     CommClosedError,
     ConnectionPool,
+    OKMessage,
     PooledRPCCall,
     Status,
     clean_exception,
...
 from distributed.objects import HasWhat, SchedulerInfo, WhoHas
 from distributed.protocol import to_serialize
 from distributed.protocol.pickle import dumps, loads
+from distributed.protocol.serialize import _is_dumpable
 from distributed.publish import Datasets
 from distributed.pubsub import PubSubClientExtension
 from distributed.security import Security
...
 from distributed.threadpoolexecutor import rejoin
 from distributed.utils import (
     CancelledError,
+    Deadline,
     LoopRunner,
     NoOpAwaitable,
     SyncMethodMixin,
...
     nbytes,
     sync,
     thread_state,
+    wait_for,
 )
 from distributed.utils_comm import (
     gather_from_workers,
@@ -834,51 +829,32 @@ def _is_nested(iterable):
     return False


-class _MapLayer(Layer):
+class _MapExpr(Expr):
     func: Callable
-    iterables: Iterable[Any]
-    key: str | Iterable[str] | None
+    iterables: Iterable
+    key: Key
     pure: bool
-    annotations: dict[str, Any] | None
-
-    def __init__(
-        self,
-        func: Callable,
-        iterables: Iterable[Any],
-        key: str | Iterable[str] | None = None,
-        pure: bool = True,
-        annotations: dict[str, Any] | None = None,
-        **kwargs,
-    ):
-        self.func: Callable = func
-        self.iterables = [tuple(map(parse_input, iterable)) for iterable in iterables]
-        self.key: str | Iterable[str] | None = key
-        self.pure: bool = pure
-        self.kwargs = {k: parse_input(v) for k, v in kwargs.items()}
-        super().__init__(annotations=annotations)
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}<func='{funcname(self.func)}'>"
+    annotations: dict
+    kwargs: dict
+    _cached_keys: Iterable[Key] | None
+    _parameters = [
+        "func",
+        "iterables",
+        "key",
+        "pure",
+        "annotations",
+        "kwargs",
+        "_cached_keys",
+    ]
+    _defaults = {"_cached_keys": None}

     @property
-    def _dict(self) -> _T_LowLevelGraph:
-        self._cached_dict: _T_LowLevelGraph
-        dsk: _T_LowLevelGraph
-
-        if hasattr(self, "_cached_dict"):
-            return self._cached_dict
-        else:
-            dsk = self._construct_graph()
-            self._cached_dict = dsk
-            return self._cached_dict
-
-    @property
-    def _keys(self) -> Iterable[Key]:
-        if hasattr(self, "_cached_keys"):
+    def keys(self) -> Iterable[Key]:
+        if self._cached_keys is not None:
             return self._cached_keys
         else:
             if isinstance(self.key, Iterable) and not isinstance(self.key, str):
-                self._cached_keys: Iterable[Key] = self.key
+                self.operands[-1] = self.key
                 return self.key

             else:
@@ -898,34 +874,19 @@ def _keys(self) -> Iterable[Key]:
                     if self.iterables
                     else []
                 )
-                self._cached_keys = keys
+                self.operands[-1] = keys
                 return keys

-    def get_output_keys(self) -> set[Key]:
-        return set(self._keys)
-
-    def get_ordered_keys(self):
-        return list(self._keys)
-
-    def is_materialized(self) -> bool:
-        return hasattr(self, "_cached_dict")
-
-    def __getitem__(self, key: Key) -> GraphNode:
-        return self._dict[key]
+    def _meta(self):
+        return []

-    def __iter__(self) -> Iterator[Key]:
-        return iter(self._dict)
-
-    def __len__(self) -> int:
-        return len(self._dict)
-
-    def _construct_graph(self) -> _T_LowLevelGraph:
+    def _layer(self):
         dsk: _T_LowLevelGraph = {}

         if not self.kwargs:
             dsk = {
                 key: Task(key, self.func, *args)
-                for key, args in zip(self._keys, zip(*self.iterables))
+                for key, args in zip(self.keys, zip(*self.iterables))
             }

         else:
@@ -937,12 +898,12 @@ def _construct_graph(self) -> _T_LowLevelGraph:
                     kwargs2[k] = vv.ref()
                     dsk[vv.key] = vv
                 else:
-                    kwargs2[k] = v
+                    kwargs2[k] = parse_input(v)

             dsk.update(
                 {
                     key: Task(key, self.func, *args, **kwargs2)
-                    for key, args in zip(self._keys, zip(*self.iterables))
+                    for key, args in zip(self.keys, zip(*self.iterables))
                 }
             )
         return dsk
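
The new _MapExpr relies on dask's expression protocol rather than the removed Layer interface: an expression keeps its state in a positional operands list whose slots are named by the class-level _parameters and back-filled from _defaults, which is why keys caches its result by assigning to self.operands[-1]. The sketch below is a simplified stand-in for that pattern, assuming only what the hunk above shows; MiniExpr and MiniMapExpr are illustrative names, not dask's actual Expr base class.

class MiniExpr:
    # Illustrative stand-in for the operand/parameter pattern used by Expr
    # subclasses such as _MapExpr above; this is NOT dask's implementation.
    _parameters: list = []
    _defaults: dict = {}

    def __init__(self, *operands):
        self.operands = list(operands)
        # Pad any trailing parameters that were not passed with their defaults.
        for name in self._parameters[len(self.operands):]:
            self.operands.append(self._defaults.get(name))

    def __getattr__(self, name):
        # Each named parameter reads through to its slot in operands, so
        # mutating operands[-1] (as _MapExpr.keys does) updates the
        # corresponding attribute.
        try:
            return self.operands[self._parameters.index(name)]
        except ValueError:
            raise AttributeError(name) from None


class MiniMapExpr(MiniExpr):
    _parameters = ["func", "iterables", "key", "pure", "_cached_keys"]
    _defaults = {"_cached_keys": None}


e = MiniMapExpr(str.upper, [["a", "b"]], "mykey", True)
assert e.pure is True and e._cached_keys is None
e.operands[-1] = ["mykey-0", "mykey-1"]  # mimics the caching in _MapExpr.keys
assert e._cached_keys == ["mykey-0", "mykey-1"]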
@@ -2162,16 +2123,19 @@ def submit(

         if isinstance(workers, (str, Number)):
             workers = [workers]
-        dsk = {
-            key: Task(
-                key,
-                func,
-                *(parse_input(a) for a in args),
-                **{k: parse_input(v) for k, v in kwargs.items()},
-            )
-        }
+
+        expr = LLGExpr(
+            {
+                key: Task(
+                    key,
+                    func,
+                    *(parse_input(a) for a in args),
+                    **{k: parse_input(v) for k, v in kwargs.items()},
+                )
+            }
+        )
         futures = self._graph_to_futures(
-            dsk,
+            expr,
             [key],
             workers=workers,
             allow_other_workers=allow_other_workers,
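
For context, submit() no longer hands _graph_to_futures a bare task dict; the single key-to-Task mapping is wrapped in an LLGExpr first. The snippet below isolates that construction as a standalone helper, reusing LLGExpr, Task, and parse_input exactly as in the hunk above; the helper name build_submit_expr and the example key are hypothetical.

from dask._expr import LLGExpr
from dask._task_spec import Task, parse_input


def build_submit_expr(key, func, *args, **kwargs):
    # Wrap a single key -> Task entry in a low-level-graph expression,
    # matching what Client.submit now builds before calling _graph_to_futures.
    return LLGExpr(
        {
            key: Task(
                key,
                func,
                *(parse_input(a) for a in args),
                **{k: parse_input(v) for k, v in kwargs.items()},
            )
        }
    )


expr = build_submit_expr("inc-1a2b3c", lambda x: x + 1, 41)  # key is illustrative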
@@ -2331,14 +2295,16 @@ def map(
         if allow_other_workers and workers is None:
             raise ValueError("Only use allow_other_workers= if using workers=")

-        dsk = _MapLayer(
+        expr = _MapExpr(
             func,
             iterables,
             key=key,
             pure=pure,
-            **kwargs,
+            # FIXME: this doesn't look right
+            annotations={},
+            kwargs=kwargs,
         )
-        keys = dsk.get_ordered_keys()
+        keys = list(expr.keys)
         if isinstance(workers, (str, Number)):
             workers = [workers]
         if workers is not None and not isinstance(workers, (list, set)):
@@ -2347,7 +2313,7 @@ def map(
         internal_priority = dict(zip(keys, range(len(keys))))

         futures = self._graph_to_futures(
-            dsk,
+            expr,
             keys,
             workers=workers,
             allow_other_workers=allow_other_workers,
@@ -2361,7 +2327,6 @@ def map(
         )

         # make sure the graph is not materialized
-        assert not dsk.is_materialized(), "Graph must be non-materialized"
         logger.debug("map(%s, ...)", funcname(func))
         return [futures[k] for k in keys]

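
The intent of the map() change is that only the internal graph representation moves from _MapLayer to _MapExpr; the public Client.map signature and the returned futures are unchanged. A minimal smoke test of that assumption against an in-process cluster might look like the sketch below (the cluster configuration is illustrative):

from distributed import Client


def inc(x):
    return x + 1


if __name__ == "__main__":
    # processes=False starts an in-process cluster, enough to exercise map().
    with Client(processes=False) as client:
        futures = client.map(inc, range(4))
        assert client.gather(futures) == [1, 2, 3, 4]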
@@ -3464,8 +3429,12 @@ def get(
         --------
         Client.compute : Compute asynchronous collections
         """
+        if isinstance(dsk, dict):
+            dsk = LLGExpr(dsk)
+        elif isinstance(dsk, HighLevelGraph):
+            dsk = HLGExpr(dsk)
         futures = self._graph_to_futures(
-            dsk,
+            expr=dsk,
             keys=set(flatten([keys])),
             workers=workers,
             allow_other_workers=allow_other_workers,
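
Client.get now normalizes its input before submission: a plain dict becomes an LLGExpr, a HighLevelGraph becomes an HLGExpr, and the result is passed to _graph_to_futures as expr=. The helper below restates that branch in isolation; normalize_graph is a hypothetical name used only for this sketch.

from dask._expr import HLGExpr, LLGExpr
from dask.highlevelgraph import HighLevelGraph


def normalize_graph(dsk):
    # Mirrors the new branch at the top of Client.get: low-level dicts and
    # HighLevelGraphs both become expression objects; anything else is
    # assumed to already be expression-like and passes through unchanged.
    if isinstance(dsk, dict):
        return LLGExpr(dsk)
    if isinstance(dsk, HighLevelGraph):
        return HLGExpr(dsk)
    return dsk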
@@ -3667,7 +3636,6 @@ def compute(
             expr = FinalizeCompute(expr)

             expr = expr.optimize()
-        # FIXME: Is this actually required?
         names = list(flatten(expr.__dask_keys__()))

         futures_dict = self._graph_to_futures(
0 commit comments