From 40e5c88910567b65a5d82c03dba9870491cdd53b Mon Sep 17 00:00:00 2001 From: wietzesuijker Date: Tue, 20 Feb 2024 14:01:43 +0000 Subject: [PATCH] feat: MambaInstall plugin feat: MambaInstall plugin feat: MambaInstall plugin feat: MambaInstall plugin feat: MambaInstall plugin --- distributed/diagnostics/plugin.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 3a866facc36..057a90f68b8 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -634,6 +634,78 @@ def __call__(self) -> None: raise RuntimeError(msg) +class MambaInstall(InstallPlugin): + """A plugin to install a set of packages with Mamba or Micromamba. + + This accepts a set of packages to install on the scheduler and all workers, + as well as options to use when installing. You can specify the use of Micromamba + and whether to restart the workers after installation. + + Parameters + ---------- + packages : list[str] + A list of packages (with optional versions) to install using Mamba or Micromamba. + channels : list[str], optional + A list of channels to include for package resolution. + mamba_options : list[str], optional + Additional command-line options to pass to Mamba. + use_micromamba : bool, optional + Whether to use Micromamba instead of Mamba for installation. Defaults to False. + restart_workers : bool, optional + Whether to restart the worker after installing the packages. Defaults to False. + + Examples: + -------- + >>> from dask.distributed import MambaInstall + >>> plugin = MambaInstall(packages=["numpy"], channels=["conda-forge"], mamba_options=["--strict-channel-priority"]) + + >>> client.register_plugin(plugin) + """ + + def __init__( + self, + packages: list[str], + channels: list[str] | None = None, + mamba_options: list[str] | None = None, + use_micromamba: bool = False, + restart_workers: bool = False, + ): + installer = _MambaInstaller(packages, channels, mamba_options, use_micromamba) + super().__init__(installer, restart_workers=restart_workers) + + +class _MambaInstaller: + def __init__( + self, + packages: list[str], + channels: list[str] | None, + mamba_options: list[str] | None, + use_micromamba: bool, + ): + self.packages = packages + self.channels = channels or ["conda-forge"] + self.mamba_options = mamba_options or [] + self.use_micromamba = use_micromamba + + def __call__(self) -> None: + installer = "micromamba" if self.use_micromamba else "mamba" + logger.info( + "%s installing the following packages: %s", + installer, + ", ".join(self.packages), + ) + channels_str = " ".join([f"-c {channel}" for channel in self.channels]) + options_str = " ".join(self.mamba_options) + packages_str = " ".join(self.packages) + command = f"{installer} install -y {channels_str} {options_str} {packages_str}" + try: + subprocess.run(command, shell=True, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + msg = f"{installer} install failed with '{e.stderr.decode().strip()}'" + logger.error(msg) + raise RuntimeError(msg) + + class PipInstall(InstallPlugin): """A plugin to pip install a set of packages