Skip to content

Commit 0657de2

Browse files
pynvml string/bytes compatibility (#8981)
1 parent c2ba834 commit 0657de2

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

distributed/diagnostics/nvml.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class NVMLState(IntEnum):
3232

3333

3434
class CudaDeviceInfo(NamedTuple):
35-
uuid: bytes | None = None
35+
# Older versions of pynvml returned bytes, newer versions return str.
36+
uuid: str | bytes | None = None
3637
device_index: int | None = None
3738
mig_index: int | None = None
3839

@@ -278,13 +279,13 @@ def get_device_index_and_uuid(device):
278279
Examples
279280
--------
280281
>>> get_device_index_and_uuid(0) # doctest: +SKIP
281-
{'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
282+
{'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
282283
283284
>>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe') # doctest: +SKIP
284-
{'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
285+
{'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
285286
286287
>>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237') # doctest: +SKIP
287-
{'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
288+
{'device-index': 0, 'uuid': 'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
288289
"""
289290
init_once()
290291
try:

distributed/diagnostics/tests/test_nvml.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
pynvml = pytest.importorskip("pynvml")
1212

1313
import dask
14+
from dask.utils import ensure_unicode
1415

1516
from distributed.diagnostics import nvml
1617
from distributed.utils_test import gen_cluster
@@ -66,7 +67,7 @@ def run_has_cuda_context(queue):
6667
assert (
6768
ctx.has_context
6869
and ctx.device_info.device_index == 0
69-
and isinstance(ctx.device_info.uuid, bytes)
70+
and isinstance(ctx.device_info.uuid, str)
7071
)
7172

7273
queue.put(None)
@@ -127,7 +128,7 @@ def test_visible_devices_uuid():
127128
assert info.uuid
128129

129130
with mock.patch.dict(
130-
os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
131+
os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
131132
):
132133
h = nvml._pynvml_handles()
133134
h_expected = pynvml.nvmlDeviceGetHandleByIndex(0)
@@ -147,7 +148,7 @@ def test_visible_devices_uuid_2(index):
147148
assert info.uuid
148149

149150
with mock.patch.dict(
150-
os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
151+
os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
151152
):
152153
h = nvml._pynvml_handles()
153154
h_expected = pynvml.nvmlDeviceGetHandleByIndex(index)

0 commit comments

Comments
 (0)