Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keyword to allow disabling config forwarding in SSHCluster #8994

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions distributed/deploy/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__( # type: ignore[no-untyped-def]
worker_module="deprecated",
worker_class="distributed.Nanny",
remote_python=None,
forward_config=True,
loop=None,
name=None,
):
Expand All @@ -92,6 +93,7 @@ def __init__( # type: ignore[no-untyped-def]
self.kwargs = copy.copy(kwargs)
self.name = name
self.remote_python = remote_python
self.forward_config = forward_config
if kwargs.get("nprocs") is not None and kwargs.get("n_workers") is not None:
raise ValueError(
"Both nprocs and n_workers were specified. Use n_workers only."
Expand Down Expand Up @@ -135,21 +137,24 @@ async def start(self):

self.connection = await asyncssh.connect(self.address, **self.connect_options)

result = await self.connection.run("uname")
if result.exit_status == 0:
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
dask.config.serialize(dask.config.global_config)
)
else:
result = await self.connection.run("cmd /c ver")
if self.forward_config:
result = await self.connection.run("uname")
if result.exit_status == 0:
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
dask.config.serialize(dask.config.global_config)
)
else:
raise Exception(
"Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
)
result = await self.connection.run("cmd /c ver")
if result.exit_status == 0:
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
dask.config.serialize(dask.config.global_config)
)
else:
raise Exception(
"Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
)
else:
set_env = ""

if not self.remote_python:
self.remote_python = sys.executable
Expand All @@ -175,7 +180,7 @@ async def start(self):
}
),
]
)
).strip()

self.proc = await self.connection.create_process(cmd)

Expand Down Expand Up @@ -214,13 +219,15 @@ def __init__(
connect_options: dict,
kwargs: dict,
remote_python: str | None = None,
forward_config: bool = True,
):
super().__init__()

self.address = address
self.kwargs = kwargs
self.connect_options = connect_options
self.remote_python = remote_python or sys.executable
self.forward_config = forward_config

async def start(self):
try:
Expand All @@ -235,21 +242,24 @@ async def start(self):

self.connection = await asyncssh.connect(self.address, **self.connect_options)

result = await self.connection.run("uname")
if result.exit_status == 0:
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
dask.config.serialize(dask.config.global_config)
)
else:
result = await self.connection.run("cmd /c ver")
if self.forward_config:
result = await self.connection.run("uname")
if result.exit_status == 0:
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
dask.config.serialize(dask.config.global_config)
)
else:
raise Exception(
"Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
)
result = await self.connection.run("cmd /c ver")
if result.exit_status == 0:
set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format(
dask.config.serialize(dask.config.global_config)
)
else:
raise Exception(
"Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable "
)
else:
set_env = ""

cmd = " ".join(
[
Expand All @@ -260,7 +270,7 @@ async def start(self):
"--spec",
"'%s'" % dumps({"cls": "distributed.Scheduler", "opts": self.kwargs}),
]
)
).strip()
self.proc = await self.connection.create_process(cmd)

# We watch stderr in order to get the address, then we return
Expand Down Expand Up @@ -304,6 +314,7 @@ def SSHCluster(
worker_module: str = "deprecated",
worker_class: str = "distributed.Nanny",
remote_python: str | list[str] | None = None,
forward_config: bool = True,
**kwargs: Any,
) -> SpecCluster:
"""Deploy a Dask cluster using SSH
Expand Down Expand Up @@ -344,6 +355,8 @@ def SSHCluster(
The python class to use to create the worker(s).
remote_python
Path to Python on remote nodes.
forward_config
Forward the local Dask configuration to the remote nodes.

Examples
--------
Expand Down Expand Up @@ -443,6 +456,7 @@ def SSHCluster(
"remote_python": (
remote_python[0] if isinstance(remote_python, list) else remote_python
),
"forward_config": forward_config,
},
}
workers = {
Expand All @@ -462,6 +476,7 @@ def SSHCluster(
if isinstance(remote_python, list)
else remote_python
),
"forward_config": forward_config,
},
}
for i, host in enumerate(hosts[1:])
Expand Down
Loading