Skip to content

Commit 09cec1f

Browse files
Enable UploadDirectory plugin to upload to scheduler (#8986)
1 parent bcdbabe commit 09cec1f

File tree

2 files changed

+116
-24
lines changed

2 files changed

+116
-24
lines changed

distributed/diagnostics/plugin.py

+72-21
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import TYPE_CHECKING, Any, Callable, ClassVar
1616

1717
from dask.typing import Key
18-
from dask.utils import funcname, tmpfile
18+
from dask.utils import _deprecated_kwarg, funcname, tmpfile
1919

2020
from distributed.protocol.pickle import dumps
2121

@@ -896,36 +896,46 @@ async def setup(self, nanny):
896896
nanny.env.update(self.environ)
897897

898898

899-
class UploadDirectory(NannyPlugin):
900-
"""A NannyPlugin to upload a local file to workers.
899+
UPLOAD_DIRECTORY_MODES = ["all", "scheduler", "workers"]
900+
901+
902+
class UploadDirectory(SchedulerPlugin):
903+
"""Scheduler to upload a local directory to the cluster.
901904
902905
Parameters
903906
----------
904-
path: str
905-
A path to the directory to upload
907+
path:
908+
Path to the directory to upload
909+
scheduler:
910+
Whether to upload the directory to the scheduler
906911
907912
Examples
908913
--------
909914
>>> from distributed.diagnostics.plugin import UploadDirectory
910-
>>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP
915+
>>> client.register_plugin(UploadDirectory("/path/to/directory")) # doctest: +SKIP
911916
"""
912917

918+
@_deprecated_kwarg("restart", "restart_workers")
913919
def __init__(
914920
self,
915921
path,
916-
restart=False,
922+
restart_workers=False,
917923
update_path=False,
918924
skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
919925
skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
926+
mode="workers",
920927
):
921-
"""
922-
Initialize the plugin by reading in the data from the given file.
923-
"""
924928
path = os.path.expanduser(path)
925929
self.path = os.path.split(path)[-1]
926-
self.restart = restart
930+
self.restart_workers = restart_workers
927931
self.update_path = update_path
928932

933+
if mode not in UPLOAD_DIRECTORY_MODES:
934+
raise ValueError(
935+
f"{mode=} not supported, expected one of {UPLOAD_DIRECTORY_MODES}"
936+
)
937+
self.mode = mode
938+
929939
self.name = "upload-directory-" + os.path.split(path)[-1]
930940

931941
with tmpfile(extension="zip") as fn:
@@ -944,26 +954,67 @@ def __init__(
944954
)
945955
z.write(filename, archive_name)
946956

947-
with open(fn, "rb") as f:
957+
with open(fn, mode="rb") as f:
948958
self.data = f.read()
949959

950-
async def setup(self, nanny):
951-
fn = os.path.join(nanny.local_directory, f"tmp-{uuid.uuid4()}.zip")
952-
with open(fn, "wb") as f:
953-
f.write(self.data)
960+
async def start(self, scheduler):
961+
from distributed.core import clean_exception
962+
from distributed.protocol.serialize import Serialized, deserialize
963+
964+
if self.mode in ("all", "scheduler"):
965+
_extract_data(
966+
scheduler.local_directory, self.path, self.data, self.update_graph
967+
)
968+
969+
if self.mode in ("all", "workers"):
970+
nanny_plugin = _UploadDirectoryNannyPlugin(
971+
self.path, self.data, self.restart_workers, self.update_path, self.name
972+
)
973+
responses = await scheduler.register_nanny_plugin(
974+
comm=None,
975+
plugin=dumps(nanny_plugin),
976+
name=self.name,
977+
idempotent=False,
978+
)
979+
980+
for response in responses.values():
981+
if response["status"] == "error":
982+
response = {
983+
k: deserialize(v.header, v.frames)
984+
for k, v in response.items()
985+
if isinstance(v, Serialized)
986+
}
987+
_, exc, tb = clean_exception(**response)
988+
raise exc.with_traceback(tb)
989+
990+
991+
class _UploadDirectoryNannyPlugin(NannyPlugin):
992+
def __init__(self, path, data, restart, update_path, name):
993+
self.path = path
994+
self.data = data
995+
self.name = name
996+
self.restart = restart
997+
self.update_path = update_path
998+
999+
def setup(self, nanny):
1000+
_extract_data(nanny.local_directory, self.path, self.data, self.update_path)
1001+
1002+
1003+
def _extract_data(base_path, path, data, update_path):
1004+
with tmpfile(extension="zip") as fn:
1005+
with open(fn, mode="wb") as f:
1006+
f.write(data)
9541007

9551008
import zipfile
9561009

9571010
with zipfile.ZipFile(fn) as z:
958-
z.extractall(path=nanny.local_directory)
1011+
z.extractall(path=base_path)
9591012

960-
if self.update_path:
961-
path = os.path.join(nanny.local_directory, self.path)
1013+
if update_path:
1014+
path = os.path.join(base_path, path)
9621015
if path not in sys.path:
9631016
sys.path.insert(0, path)
9641017

965-
os.remove(fn)
966-
9671018

9681019
class forward_stream:
9691020
def __init__(self, stream, worker):

distributed/tests/test_client.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@
7979
from distributed.comm import CommClosedError
8080
from distributed.compatibility import LINUX, MACOS, WINDOWS
8181
from distributed.core import Status
82-
from distributed.diagnostics.plugin import WorkerPlugin
82+
from distributed.deploy.subprocess import SubprocessCluster
83+
from distributed.diagnostics.plugin import UploadDirectory, WorkerPlugin
8384
from distributed.metrics import time
8485
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
8586
from distributed.shuffle import check_minimal_arrow_version
@@ -7363,7 +7364,6 @@ async def test_computation_object_code_client_compute(c, s, a, b):
73637364
assert comp.code[0][-1].code == test_function_code
73647365

73657366

7366-
@pytest.mark.slow
73677367
@gen_cluster(client=True, Worker=Nanny)
73687368
async def test_upload_directory(c, s, a, b, tmp_path):
73697369
from dask.distributed import UploadDirectory
@@ -7376,7 +7376,7 @@ async def test_upload_directory(c, s, a, b, tmp_path):
73767376
with open(tmp_path / "bar.py", "w") as f:
73777377
f.write("from foo import x")
73787378

7379-
plugin = UploadDirectory(tmp_path, restart=True, update_path=True)
7379+
plugin = UploadDirectory(tmp_path, restart_workers=True, update_path=True)
73807380
await c.register_plugin(plugin)
73817381

73827382
[name] = a.plugins
@@ -7399,6 +7399,47 @@ def f():
73997399
assert files_start == files_end # no change
74007400

74017401

7402+
def test_upload_directory_invalid_mode():
7403+
with pytest.raises(ValueError, match="mode"):
7404+
UploadDirectory(".", mode="invalid")
7405+
7406+
7407+
@pytest.mark.skipif(WINDOWS, reason="distributed#7434")
7408+
@pytest.mark.parametrize("mode", ["all", "scheduler"])
7409+
@gen_test()
7410+
async def test_upload_directory_to_scheduler(mode, tmp_path):
7411+
from dask.distributed import UploadDirectory
7412+
7413+
# Be sure to exclude code coverage reports
7414+
files_start = {f for f in os.listdir() if not f.startswith(".coverage")}
7415+
7416+
with open(tmp_path / "foo.py", "w") as f:
7417+
f.write("x = 123")
7418+
with open(tmp_path / "bar.py", "w") as f:
7419+
f.write("from foo import x")
7420+
7421+
def f():
7422+
import bar
7423+
7424+
return bar.x
7425+
7426+
async with SubprocessCluster(
7427+
asynchronous=True,
7428+
dashboard_address=":0",
7429+
scheduler_kwargs={"idle_timeout": "5s"},
7430+
worker_kwargs={"death_timeout": "5s"},
7431+
) as cluster:
7432+
async with Client(cluster, asynchronous=True) as client:
7433+
with pytest.raises(ModuleNotFoundError, match="'bar'"):
7434+
res = await client.run_on_scheduler(f)
7435+
7436+
plugin = UploadDirectory(
7437+
tmp_path, mode=mode, restart_workers=True, update_path=True
7438+
)
7439+
await client.register_plugin(plugin)
7440+
assert await client.run_on_scheduler(f) == 123
7441+
7442+
74027443
@gen_cluster(client=True, nthreads=[("", 1)])
74037444
async def test_duck_typed_register_plugin_raises(c, s, a):
74047445
class DuckPlugin:

0 commit comments

Comments
 (0)