From 6d05589484ebb8af1e156b84c74437eca01e363a Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 13 Jan 2025 12:31:22 +0100
Subject: [PATCH] Upload files to scheduler

---
 distributed/diagnostics/plugin.py | 93 ++++++++++++++++++++++++-------
 distributed/tests/test_client.py  | 47 +++++++++++++++-
 2 files changed, 116 insertions(+), 24 deletions(-)

diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py
index b3f280d15ef..2c74e494e6a 100644
--- a/distributed/diagnostics/plugin.py
+++ b/distributed/diagnostics/plugin.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
 from dask.typing import Key
-from dask.utils import funcname, tmpfile
+from dask.utils import _deprecated_kwarg, funcname, tmpfile
 
 from distributed.protocol.pickle import dumps
 
@@ -896,36 +896,46 @@ async def setup(self, nanny):
         nanny.env.update(self.environ)
 
 
-class UploadDirectory(NannyPlugin):
-    """A NannyPlugin to upload a local file to workers.
+UPLOAD_DIRECTORY_MODES = ["all", "scheduler", "workers"]
+
+
+class UploadDirectory(SchedulerPlugin):
+    """A SchedulerPlugin to upload a local directory to the cluster.
 
     Parameters
     ----------
-    path: str
-        A path to the directory to upload
+    path:
+        Path to the directory to upload
+    mode:
+        Where to upload the directory: "workers" (default), "scheduler", or "all"
 
     Examples
     --------
     >>> from distributed.diagnostics.plugin import UploadDirectory
-    >>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True)  # doctest: +SKIP
+    >>> client.register_plugin(UploadDirectory("/path/to/directory"))  # doctest: +SKIP
     """
 
+    @_deprecated_kwarg("restart", "restart_workers")
     def __init__(
         self,
         path,
-        restart=False,
+        restart_workers=False,
         update_path=False,
         skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
         skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
+        mode="workers",
     ):
-        """
-        Initialize the plugin by reading in the data from the given file.
- """ path = os.path.expanduser(path) self.path = os.path.split(path)[-1] - self.restart = restart + self.restart_workers = restart_workers self.update_path = update_path + if mode not in UPLOAD_DIRECTORY_MODES: + raise ValueError( + f"{mode=} not supported, expected one of {UPLOAD_DIRECTORY_MODES}" + ) + self.mode = mode + self.name = "upload-directory-" + os.path.split(path)[-1] with tmpfile(extension="zip") as fn: @@ -944,26 +954,67 @@ def __init__( ) z.write(filename, archive_name) - with open(fn, "rb") as f: + with open(fn, mode="rb") as f: self.data = f.read() - async def setup(self, nanny): - fn = os.path.join(nanny.local_directory, f"tmp-{uuid.uuid4()}.zip") - with open(fn, "wb") as f: - f.write(self.data) + async def start(self, scheduler): + from distributed.core import clean_exception + from distributed.protocol.serialize import Serialized, deserialize + + if self.mode in ("all", "scheduler"): + _extract_data( + scheduler.local_directory, self.path, self.data, self.update_graph + ) + + if self.mode in ("all", "workers"): + nanny_plugin = _UploadDirectoryNannyPlugin( + self.path, self.data, self.restart_workers, self.update_path, self.name + ) + responses = await scheduler.register_nanny_plugin( + comm=None, + plugin=dumps(nanny_plugin), + name=self.name, + idempotent=False, + ) + + for response in responses.values(): + if response["status"] == "error": + response = { + k: deserialize(v.header, v.frames) + for k, v in response.items() + if isinstance(v, Serialized) + } + _, exc, tb = clean_exception(**response) + raise exc.with_traceback(tb) + + +class _UploadDirectoryNannyPlugin(NannyPlugin): + def __init__(self, path, data, restart, update_path, name): + self.path = path + self.data = data + self.name = name + self.restart = restart + self.update_path = update_path + + def setup(self, nanny): + _extract_data(nanny.local_directory, self.path, self.data, self.update_path) + + +def _extract_data(base_path, path, data, update_path): + with tmpfile(extension="zip") as fn: + with open(fn, mode="wb") as f: + f.write(data) import zipfile with zipfile.ZipFile(fn) as z: - z.extractall(path=nanny.local_directory) + z.extractall(path=base_path) - if self.update_path: - path = os.path.join(nanny.local_directory, self.path) + if update_path: + path = os.path.join(base_path, path) if path not in sys.path: sys.path.insert(0, path) - os.remove(fn) - class forward_stream: def __init__(self, stream, worker): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 141d9aebac3..71546cb207e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -80,7 +80,8 @@ from distributed.comm import CommClosedError from distributed.compatibility import LINUX, MACOS, WINDOWS from distributed.core import Status -from distributed.diagnostics.plugin import WorkerPlugin +from distributed.deploy.subprocess import SubprocessCluster +from distributed.diagnostics.plugin import UploadDirectory, WorkerPlugin from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler from distributed.shuffle import check_minimal_arrow_version @@ -7408,7 +7409,6 @@ async def test_computation_object_code_client_compute(c, s, a, b): assert comp.code[0][-1].code == test_function_code -@pytest.mark.slow @gen_cluster(client=True, Worker=Nanny) async def test_upload_directory(c, s, a, b, tmp_path): from dask.distributed import UploadDirectory @@ -7421,7 +7421,7 @@ async def test_upload_directory(c, s, a, b, tmp_path): 
with open(tmp_path / "bar.py", "w") as f: f.write("from foo import x") - plugin = UploadDirectory(tmp_path, restart=True, update_path=True) + plugin = UploadDirectory(tmp_path, restart_workers=True, update_path=True) await c.register_plugin(plugin) [name] = a.plugins @@ -7444,6 +7444,47 @@ def f(): assert files_start == files_end # no change +def test_upload_directory_invalid_mode(): + with pytest.raises(ValueError, match="mode"): + UploadDirectory(".", mode="invalid") + + +@pytest.mark.skipif(WINDOWS, reason="distributed#7434") +@pytest.mark.parametrize("mode", ["all", "scheduler"]) +@gen_test() +async def test_upload_directory_to_scheduler(mode, tmp_path): + from dask.distributed import UploadDirectory + + # Be sure to exclude code coverage reports + files_start = {f for f in os.listdir() if not f.startswith(".coverage")} + + with open(tmp_path / "foo.py", "w") as f: + f.write("x = 123") + with open(tmp_path / "bar.py", "w") as f: + f.write("from foo import x") + + def f(): + import bar + + return bar.x + + async with SubprocessCluster( + asynchronous=True, + dashboard_address=":0", + scheduler_kwargs={"idle_timeout": "5s"}, + worker_kwargs={"death_timeout": "5s"}, + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + with pytest.raises(ModuleNotFoundError, match="'bar'"): + res = await client.run_on_scheduler(f) + + plugin = UploadDirectory( + tmp_path, mode=mode, restart_workers=True, update_path=True + ) + await client.register_plugin(plugin) + assert await client.run_on_scheduler(f) == 123 + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_duck_typed_register_plugin_raises(c, s, a): class DuckPlugin: