
Commit a990d13

Authored Jan 13, 2025
Merge branch 'main' into upload-directory-uploads-to-scheduler
2 parents 6d05589 + bcdbabe commit a990d13

File tree: 11 files changed (+32, -120 lines)
 

distributed/client.py (+1, -4)
@@ -45,7 +45,6 @@
 from dask.core import flatten, validate_key
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import Layer
-from dask.optimization import SubgraphCallable
 from dask.tokenize import tokenize
 from dask.typing import Key, NestedKeys, NoDefault, no_default
 from dask.utils import (
@@ -1147,7 +1146,7 @@ def __init__(
         if security is None and isinstance(address, str):
             security = _maybe_call_security_loader(address)
 
-        if security is None:
+        if security is None or security is False:
             security = Security()
         elif isinstance(security, dict):
             security = Security(**security)
@@ -6120,8 +6119,6 @@ def futures_of(o, client=None):
             stack.extend(x)
         elif type(x) is dict:
            stack.extend(x.values())
-        elif type(x) is SubgraphCallable:
-            stack.extend(x.dsk.values())
        elif isinstance(x, TaskRef):
            if x not in seen:
                seen.add(x)
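
With this change, passing `security=False` to the Client is normalized the same way as `security=None`. A minimal standalone sketch of that normalization (not part of the diff), using only `distributed.security.Security`:

    from distributed.security import Security

    # Sketch of the check added in Client.__init__: both None and False now
    # fall back to a default Security object.
    for security in (None, False):
        if security is None or security is False:
            security = Security()
        assert isinstance(security, Security)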

distributed/comm/tcp.py (+10, -9)
@@ -129,17 +129,18 @@ def get_stream_address(comm):
 
 def convert_stream_closed_error(obj, exc):
     """
-    Re-raise StreamClosedError as CommClosedError.
+    Re-raise StreamClosedError or SSLError as CommClosedError.
     """
-    if exc.real_error is not None:
+    if hasattr(exc, "real_error"):
         # The stream was closed because of an underlying OS error
+        if exc.real_error is None:
+            raise CommClosedError(f"in {obj}: {exc}") from exc
         exc = exc.real_error
-        if isinstance(exc, ssl.SSLError):
-            if exc.reason and "UNKNOWN_CA" in exc.reason:
-                raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}")
-        raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
-    else:
-        raise CommClosedError(f"in {obj}: {exc}") from exc
+
+    if isinstance(exc, ssl.SSLError):
+        if exc.reason and "UNKNOWN_CA" in exc.reason:
+            raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}")
+    raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
 
 
 def _close_comm(ref):
@@ -230,7 +231,7 @@ async def read(self, deserializers=None):
                 buffer = await read_bytes_rw(stream, buffer_nbytes)
                 frames.append(buffer)
 
-        except StreamClosedError as e:
+        except (StreamClosedError, SSLError) as e:
            self.stream = None
            self._closed = True
            convert_stream_closed_error(self, e)
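
For context on the `hasattr` guard above: `read()` now also catches `ssl.SSLError`, which, unlike tornado's `StreamClosedError`, carries no `real_error` attribute, so the old `exc.real_error` access would fail on it. A small standalone check (not part of the diff) illustrating the difference:

    import ssl
    from tornado.iostream import StreamClosedError

    # StreamClosedError exposes a ``real_error`` attribute (possibly None);
    # a bare SSLError does not.
    assert hasattr(StreamClosedError(), "real_error")
    assert not hasattr(ssl.SSLError(), "real_error")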

distributed/diagnostics/nvml.py (+5, -4)
@@ -32,7 +32,8 @@ class NVMLState(IntEnum):
 
 
 class CudaDeviceInfo(NamedTuple):
-    uuid: bytes | None = None
+    # Older versions of pynvml returned bytes, newer versions return str.
+    uuid: str | bytes | None = None
     device_index: int | None = None
     mig_index: int | None = None
 
@@ -278,13 +279,13 @@ def get_device_index_and_uuid(device):
     Examples
     --------
     >>> get_device_index_and_uuid(0)  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+    {'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
 
     >>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe')  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+    {'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
 
     >>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237')  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
+    {'device-index': 0, 'uuid': 'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
     """
     init_once()
     try:

distributed/diagnostics/tests/test_nvml.py (+4, -3)
@@ -11,6 +11,7 @@
 pynvml = pytest.importorskip("pynvml")
 
 import dask
+from dask.utils import ensure_unicode
 
 from distributed.diagnostics import nvml
 from distributed.utils_test import gen_cluster
@@ -66,7 +67,7 @@ def run_has_cuda_context(queue):
         assert (
             ctx.has_context
             and ctx.device_info.device_index == 0
-            and isinstance(ctx.device_info.uuid, bytes)
+            and isinstance(ctx.device_info.uuid, str)
         )
 
         queue.put(None)
@@ -127,7 +128,7 @@ def test_visible_devices_uuid():
     assert info.uuid
 
     with mock.patch.dict(
-        os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
+        os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
    ):
        h = nvml._pynvml_handles()
        h_expected = pynvml.nvmlDeviceGetHandleByIndex(0)
@@ -147,7 +148,7 @@ def test_visible_devices_uuid_2(index):
     assert info.uuid
 
     with mock.patch.dict(
-        os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
+        os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
    ):
        h = nvml._pynvml_handles()
        h_expected = pynvml.nvmlDeviceGetHandleByIndex(index)
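
The tests switch from `info.uuid.decode("utf-8")` to `dask.utils.ensure_unicode` so they work whether pynvml returns `bytes` or `str`. A quick standalone illustration (not part of the diff; the UUID string is shortened for brevity):

    from dask.utils import ensure_unicode

    # ensure_unicode decodes bytes and passes str through unchanged, covering
    # both old and new pynvml return types for the device UUID.
    assert ensure_unicode(b"GPU-e1006a74") == "GPU-e1006a74"
    assert ensure_unicode("GPU-e1006a74") == "GPU-e1006a74"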

distributed/protocol/tests/test_cupy.py (+1, -1)
@@ -74,7 +74,7 @@ def test_serialize_cupy_from_rmm(size):
 )
 @pytest.mark.parametrize(
     "dtype",
-    [numpy.dtype("<f4"), numpy.dtype(">f4"), numpy.dtype("<f8"), numpy.dtype(">f8")],
+    [numpy.dtype("<f4"), numpy.dtype("<f8")],
 )
 @pytest.mark.parametrize("serializer", ["cuda", "dask", "pickle"])
 def test_serialize_cupy_sparse(sparse_name, dtype, serializer):

distributed/protocol/tests/test_scipy.py (-12)
@@ -30,19 +30,7 @@
     "dtype",
     [
         numpy.dtype("<f4"),
-        pytest.param(
-            numpy.dtype(">f4"),
-            marks=pytest.mark.skipif(
-                SCIPY_GE_1_15_0, reason="https://github.com/scipy/scipy/issues/22258"
-            ),
-        ),
         numpy.dtype("<f8"),
-        pytest.param(
-            numpy.dtype(">f8"),
-            marks=pytest.mark.skipif(
-                SCIPY_GE_1_15_0, reason="https://github.com/scipy/scipy/issues/22258"
-            ),
-        ),
     ],
 )
 def test_serialize_scipy_sparse(sparse_type, dtype):
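
The cupy and scipy parametrizations above both drop the big-endian (`>f4`/`>f8`) dtypes; the scipy ones were already skipped on scipy >= 1.15 (scipy/scipy#22258). If it helps to see what the byte-order prefixes mean, here is a small standalone numpy example (not part of the diff):

    import numpy

    # "<f4" and ">f4" describe the same 4-byte float with opposite byte order;
    # the values round-trip, but the dtypes themselves differ.
    little = numpy.arange(4, dtype="<f4")
    big = little.astype(">f4")
    assert (little == big).all()
    assert little.dtype != big.dtype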

distributed/tests/test_client.py (-45)
@@ -42,7 +42,6 @@
 import dask
 import dask.bag as db
 from dask import delayed
-from dask.optimization import SubgraphCallable
 from dask.tokenize import tokenize
 from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile
 
@@ -2627,13 +2626,6 @@ async def test_futures_of_get(c, s, a, b):
     b = db.Bag({("b", i): f for i, f in enumerate([x, y, z])}, "b", 3)
     assert set(futures_of(b)) == {x, y, z}
 
-    sg = SubgraphCallable(
-        {"x": x, "y": y, "z": z, "out": (add, (add, (add, x, y), z), "in")},
-        "out",
-        ("in",),
-    )
-    assert set(futures_of(sg)) == {x, y, z}
-
 
 def test_futures_of_class():
     pytest.importorskip("numpy")
@@ -6192,43 +6184,6 @@ async def test_profile_bokeh(c, s, a, b):
     assert os.path.exists(fn)
 
 
-@gen_cluster(client=True, nthreads=[("", 1)])
-async def test_get_mix_futures_and_SubgraphCallable(c, s, a):
-    future = c.submit(add, 1, 2)
-
-    subgraph = SubgraphCallable(
-        {"_2": (add, "_0", "_1"), "_3": (add, future, "_2")},
-        "_3",
-        ("_0", "_1"),
-    )
-    dsk = {
-        "a": 1,
-        "b": 2,
-        "c": (subgraph, "a", "b"),
-        "d": (subgraph, "c", "b"),
-    }
-
-    future2 = c.get(dsk, "d", sync=False)
-    result = await future2
-    assert result == 11
-
-    # Nested subgraphs
-    subgraph2 = SubgraphCallable(
-        {
-            "_2": (subgraph, "_0", "_1"),
-            "_3": (subgraph, "_2", "_1"),
-            "_4": (add, "_3", future2),
-        },
-        "_4",
-        ("_0", "_1"),
-    )
-
-    dsk2 = {"e": 1, "f": 2, "g": (subgraph2, "e", "f")}
-
-    result = await c.get(dsk2, "g", sync=False)
-    assert result == 22
-
-
 @gen_cluster(client=True)
 async def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b):
     pd = pytest.importorskip("pandas")

distributed/tests/test_tls_functional.py (+10)
@@ -209,6 +209,16 @@ async def test_security_dict_input_no_security():
     assert result == 2
 
 
+@gen_test()
+async def test_security_bool_input_disabled_security():
+    async with Scheduler(dashboard_address=":0", security=False) as s:
+        async with Worker(s.address, security=False):
+            async with Client(s.address, security=False, asynchronous=True) as c:
+                result = await c.submit(inc, 1)
+                assert c.security.require_encryption is False
+                assert result == 2
+
+
 @gen_test()
 async def test_security_dict_input():
     conf = tls_config()

distributed/tests/test_utils_comm.py (-16)
@@ -7,7 +7,6 @@
 import pytest
 
 from dask._task_spec import TaskRef
-from dask.optimization import SubgraphCallable
 
 from distributed import wait
 from distributed.compatibility import asyncio_run
@@ -246,18 +245,3 @@ def assert_eq(keys1: set[TaskRef], keys2: set[TaskRef]) -> None:
     res, keys = unpack_remotedata(TaskRef("mykey"))
     assert res == "mykey"
     assert_eq(keys, {TaskRef("mykey")})
-
-    # Check unpack of SC that contains a wrapped key
-    sc = SubgraphCallable({"key": (TaskRef("data"),)}, outkey="key", inkeys=["arg1"])
-    dsk = (sc, "arg1")
-    res, keys = unpack_remotedata(dsk)
-    assert res[0] != sc  # Notice, the first item (the SC) has been changed
-    assert res[1:] == ("arg1", "data")
-    assert_eq(keys, {TaskRef("data")})
-
-    # Check unpack of SC when it takes a wrapped key as argument
-    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[TaskRef("arg1")])
-    dsk = (sc, "arg1")
-    res, keys = unpack_remotedata(dsk)
-    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
-    assert_eq(keys, set())

distributed/utils.py (+1, -1)
@@ -1283,7 +1283,7 @@ def command_has_keyword(cmd, k):
 
 @toolz.memoize
 def color_of(x, palette=palette):
-    h = md5(str(x).encode())
+    h = md5(str(x).encode(), usedforsecurity=False)
     n = int(h.hexdigest()[:8], 16)
     return palette[n % len(palette)]
 
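
`color_of` only uses MD5 to pick a stable palette color, so the digest is marked as non-cryptographic. A standalone sketch of the same pattern (the 16-entry palette size here is just an assumption for illustration):

    from hashlib import md5

    # usedforsecurity=False (Python 3.9+) declares the digest non-cryptographic,
    # which among other things keeps MD5 usable on FIPS-restricted builds.
    h = md5(b"my-task-key", usedforsecurity=False)
    index = int(h.hexdigest()[:8], 16) % 16  # stable index into a 16-color palette
    print(index)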

distributed/utils_comm.py (-25)
@@ -13,7 +13,6 @@
 
 import dask.config
 from dask._task_spec import TaskRef
-from dask.optimization import SubgraphCallable
 from dask.typing import Key
 from dask.utils import is_namedtuple_instance, parse_timedelta
 
@@ -197,30 +196,6 @@ def _unpack_remotedata_inner(
     if typ is tuple:
         if not o:
             return o
-        if type(o[0]) is SubgraphCallable:
-            # Unpack futures within the arguments of the subgraph callable
-            futures: set[TaskRef] = set()
-            args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:])
-            found_futures.update(futures)
-
-            # Unpack futures within the subgraph callable itself
-            sc: SubgraphCallable = o[0]
-            futures = set()
-            dsk = {
-                k: _unpack_remotedata_inner(v, byte_keys, futures)
-                for k, v in sc.dsk.items()
-            }
-            future_keys: tuple = ()
-            if futures:  # If no futures is in the subgraph, we just use `sc` as-is
-                found_futures.update(futures)
-                future_keys = (
-                    tuple(f.key for f in futures)
-                    if byte_keys
-                    else tuple(f.key for f in futures)
-                )
-            inkeys = tuple(sc.inkeys) + future_keys
-            sc = SubgraphCallable(dsk, sc.outkey, inkeys, sc.name)
-            return (sc,) + args + future_keys
         else:
             return tuple(
                 _unpack_remotedata_inner(item, byte_keys, found_futures) for item in o
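
With the SubgraphCallable branch removed, `unpack_remotedata` only needs to unwrap `TaskRef` objects nested in plain containers, as the remaining test in `test_utils_comm.py` exercises. A small usage sketch (standalone, not part of the diff):

    from dask._task_spec import TaskRef
    from distributed.utils_comm import unpack_remotedata

    # TaskRef wrappers inside a tuple are replaced by their keys, and the
    # wrappers themselves are collected into the returned set of futures.
    res, futures = unpack_remotedata((TaskRef("mykey"), 1, 2))
    assert res == ("mykey", 1, 2)
    assert {f.key for f in futures} == {"mykey"}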
