Skip to content

Commit 02854ea

Browse files
committed
feat: MambaInstall plugin
feat: MambaInstall plugin feat: MambaInstall plugin
1 parent b03efee commit 02854ea

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

distributed/diagnostics/plugin.py

+72
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,78 @@ def __call__(self) -> None:
634634
raise RuntimeError(msg)
635635

636636

637+
class MambaInstall(InstallPlugin):
638+
"""A plugin to install a set of packages with Mamba or Micromamba.
639+
640+
This accepts a set of packages to install on the scheduler and all workers,
641+
as well as options to use when installing. You can specify the use of Micromamba
642+
and whether to restart the workers after installation.
643+
644+
Parameters
645+
----------
646+
packages : list[str]
647+
A list of packages (with optional versions) to install using Mamba or Micromamba.
648+
channels : list[str], optional
649+
A list of channels to include for package resolution.
650+
mamba_options : list[str], optional
651+
Additional command-line options to pass to Mamba.
652+
use_micromamba : bool, optional
653+
Whether to use Micromamba instead of Mamba for installation. Defaults to False.
654+
restart_workers : bool, optional
655+
Whether to restart the worker after installing the packages. Defaults to False.
656+
657+
Examples:
658+
--------
659+
>>> from dask.distributed import MambaInstall
660+
>>> plugin = MambaInstall(packages=["numpy"], channels=["conda-forge"], mamba_options=["--strict-channel-priority"])
661+
662+
>>> client.register_plugin(plugin)
663+
"""
664+
665+
def __init__(
666+
self,
667+
packages: list[str],
668+
channels: list[str] | None = None,
669+
mamba_options: list[str] | None = None,
670+
use_micromamba: bool = False,
671+
restart_workers: bool = False,
672+
):
673+
installer = _MambaInstaller(packages, channels, mamba_options, use_micromamba)
674+
super().__init__(installer, restart_workers=restart_workers)
675+
676+
677+
class _MambaInstaller:
678+
def __init__(
679+
self,
680+
packages: list[str],
681+
channels: list[str] | None,
682+
mamba_options: list[str] | None,
683+
use_micromamba: bool
684+
):
685+
self.packages = packages
686+
self.channels = channels or ["conda-forge"]
687+
self.mamba_options = mamba_options or []
688+
self.use_micromamba = use_micromamba
689+
690+
def __call__(self) -> None:
691+
installer = "micromamba" if self.use_micromamba else "mamba"
692+
logger.info(
693+
"%s installing the following packages: %s",
694+
installer,
695+
", ".join(self.packages),
696+
)
697+
channels_str = " ".join([f"-c {channel}" for channel in self.channels])
698+
options_str = " ".join(self.mamba_options)
699+
packages_str = " ".join(self.packages)
700+
command = f"{installer} install -y {channels_str} {options_str} {packages_str}"
701+
try:
702+
subprocess.run(command, shell=True, check=True, capture_output=True)
703+
except subprocess.CalledProcessError as e:
704+
msg = f"{installer} install failed with '{e.stderr.decode().strip()}'"
705+
logger.error(msg)
706+
raise RuntimeError(msg)
707+
708+
637709
class PipInstall(InstallPlugin):
638710
"""A plugin to pip install a set of packages
639711

0 commit comments

Comments
 (0)