[PoC] Remote debugger based on debugpy
#14609
1 errors, 278 fail, 91 skipped, 3 705 pass in 9h 44m 18s
Annotations
Check warning on line 0 in distributed.dashboard.tests.test_scheduler_bokeh
github-actions / Unit Test Results
All 8 runs failed: test_FinePerformanceMetrics_shuffle (distributed.dashboard.tests.test_scheduler_bokeh)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 31s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 31s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 32s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 31s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 31s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 32s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
# Names below the TYPE_CHECKING guard exist only for static type checkers.
# NOTE(review): indentation is lost in this view, so where the guarded section ends
# (before or after the ``P``/``T`` definitions) cannot be confirmed from here.
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
# Type parameters used by the ``fail_hard`` decorator's signature.
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the ``distributed.admin.pdb-on-err`` config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension name -> extension class.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Default values of the ``metrics`` and ``startup_information`` arguments of
# ``Worker.__init__``; the constructor copies them with ``dict(...)`` before storing.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Set of ``Status`` values grouped under the "any running" name.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
# Typed message whose ``op`` is "task-finished"; one of the two variants in the
# ``RunTaskSuccess | RunTaskFailure`` return annotation of ``_run_task``.
class RunTaskSuccess(OKMessage):
op: Literal["task-finished"]
result: object
# presumably the size in bytes of ``result`` -- TODO confirm against the producer
nbytes: int
# presumably the type of ``result`` -- TODO confirm against the producer
type: type
# presumably the start/stop timestamps of the task run -- TODO confirm
start: float
stop: float
# presumably the identifier of the thread that ran the task -- TODO confirm
thread: int
# Typed message whose ``op`` is "task-erred"; one of the two variants in the
# ``RunTaskSuccess | RunTaskFailure`` return annotation of ``_run_task``.
class RunTaskFailure(ErrorMessage):
op: Literal["task-erred"]
result: object
# presumably the size in bytes of ``result`` -- TODO confirm against the producer
nbytes: int
# presumably the type of ``result`` -- TODO confirm against the producer
type: type
# presumably the start/stop timestamps of the task run -- TODO confirm
start: float
stop: float
# presumably the identifier of the thread that ran the task -- TODO confirm
thread: int
# NOTE(review): ``Exception`` is a subclass of ``BaseException``, so this union
# admits the same values as ``BaseException`` alone.
actual_exception: BaseException | Exception
# Response with status "busy"; one of the two variants in the return annotation
# of ``get_data_from_worker``. It carries no data.
class GetDataBusy(TypedDict):
status: Literal["busy"]
# Response with status "OK"; one of the two variants in the return annotation
# of ``get_data_from_worker``.
class GetDataSuccess(TypedDict):
status: Literal["OK"]
# Mapping of task key to the object returned for that key.
data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorate a worker method so that an unhandled ``Exception`` closes the worker.

    The wrapped method behaves like ``method`` on success. When ``method`` raises
    an ``Exception``:

    * unless the worker status is already ``closed`` or ``closing``, the error is
      recorded with ``self.log_event("worker-fail-hard", ...)`` and logged;
    * ``_force_close`` is invoked with the reason ``worker-<method name>-fail-hard``
      (awaited directly for coroutine functions, scheduled on ``self.loop`` for
      plain functions);
    * the original exception is re-raised to the caller.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    if not iscoroutinefunction(method):

        @wraps(method)
        def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                # Cannot await here, so hand the shutdown over to the event loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

        return sync_wrapper  # type: ignore

    @wraps(method)
    async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return await method(self, *args, **kwargs)  # type: ignore
        except Exception as exc:
            if self.status not in (Status.closed, Status.closing):
                self.log_event("worker-fail-hard", error_message(exc))
                logger.exception(exc)
            await _force_close(self, reason)
            raise

    return async_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Shutdown path used by the ``fail_hard`` decorator.

    1. Ask the worker to close and wait for it, bounded by a timeout.
    2. If closing itself fails, log a critical message and kill the process.

    ``KeyboardInterrupt`` and ``SystemExit`` are always propagated. Any other
    failure is re-raised instead of killing the process when a ``Scheduler``
    instance exists in this interpreter.
    """
    close_timeout = 30  # passed to ``wait_for``; presumably seconds -- TODO confirm
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, close_timeout)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # A worker that cannot even close is in a very broken state. Shut down
        # immediately so that it cannot make things worse or deadlock the cluster.
        from distributed import Scheduler

        scheduler_in_process = bool(Scheduler._instances)
        if scheduler_in_process:
            # Most likely a unit test: do not take the whole test suite down.
            raise

        failure_notice = (
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW"
        )
        logger.critical(failure_notice, exc_info=True)
        # `os._exit` rather than `sys.exit`: it is uncertain how `SystemExit`
        # propagates out of asyncio callbacks.
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
… DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_total
@property
def outgoing_current_count(self):
    """Deprecated alias of ``transfer_outgoing_count``; warns on every access."""
    message = (
        "The `Worker.outgoing_current_count` attribute has been renamed to "
        "`Worker.transfer_outgoing_count`"
    )
    # stacklevel=2 attributes the warning to the code that read the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count
@property
def outgoing_transfer_log(self):
    """Deprecated alias of ``transfer_outgoing_log``; warns on every access."""
    message = (
        "The `Worker.outgoing_transfer_log` attribute has been renamed to "
        "`Worker.transfer_outgoing_log`"
    )
    # stacklevel=2 attributes the warning to the code that read the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_log
@property
def total_in_connections(self):
    """Deprecated alias of ``transfer_outgoing_count_limit``; warns on every access."""
    message = (
        "The `Worker.total_in_connections` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_limit`"
    )
    # stacklevel=2 attributes the warning to the code that read the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_limit
# Holds the worker bound to the current execution context; ``_run_task`` sets it
# around the task call and ``get_worker`` reads it.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker on which the calling task is running

    Raises
    ------
    ValueError
        If no worker is bound to the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # The lookup failure is an implementation detail; hide it from the traceback.
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    Candidates are tried in this order: the client of the worker running the
    current task, then the current global client, then a brand new client for
    ``address``. A candidate is used only if ``address`` is not given or equals
    the address of the scheduler that candidate is connected to.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a global client is usable and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task.
    try:
        worker = get_worker()
    except ValueError:  # not running inside a worker task
        use_worker_client = False
    else:
        use_worker_client = not address or worker.scheduler.address == address
    if use_worker_client:
        return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the current global client.
    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None
    if client and (not address or client.scheduler.address == address):
        return client

    # Last resort: open a new client, which requires an explicit address.
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The time elapsed since the task started is reported to the worker in a
    ``SecedeEvent`` delivered through the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Deliver the event via the worker's loop instead of calling the handler
    # directly from this thread.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: it sends a ``get_data``
    request and, when the response status is ``"OK"``, writes ``"OK"`` back on
    the same comm.

    ``serializers`` and ``deserializers`` default to those of ``rpc``. The comm
    is handed back to ``rpc`` with ``rpc.reuse`` whether or not the exchange
    succeeded.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` key.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        request = dict(op="get_data", keys=keys, who=who)
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            **request,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache used by ``dumps_function``: maps a function to its pickled bytes and
# holds at most 100 entries.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# ``dumps_function`` holds this lock around every read and write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func`` to bytes, reusing a cached result when there is one

    Results shorter than 100000 bytes are stored in ``cache_dumps``. Functions
    that cannot be used as a cache key (unhashable) are pickled on every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: pickle outside the lock, then store only small payloads.
        payload = pickle.dumps(func)
        if len(payload) < 100_000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a task on the current thread and collect information about the run

    While the task runs, the current thread is registered in ``active_threads``
    (thread ident -> ``key``), the thread state carries ``start_time``,
    ``execution_state`` and ``key``, and ``_worker_cvar`` is bound to
    ``execution_state["worker"]`` so that ``get_worker()`` works inside the task.

    Parameters
    ----------
    task
        The task to execute; handed to ``_run_task_simple``.
    data
        Passed through to ``_run_task_simple`` together with ``task``.
    execution_state
        Stored in the thread state; must contain the key ``"worker"``.
    key
        Key of the task being run.
    active_threads, active_threads_lock
        Shared thread-ident -> key mapping and the lock that guards it.
    time_delay
        Passed through unchanged to ``_run_task_simple``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                return _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread, even when an exception propagates out of
        # the task run. Thread idents can be reused, so a leftover entry would
        # later attribute ``key`` to an unrelated thread.
        with active_threads_lock:
            del active_threads[ident]
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:203: in _execute_subgraph
res = execute_graph(final, keys=[outkey])
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:987: in execute_graph
cache[key] = node(cache)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_rechunk.py:181: in rechunk_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38563', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:38459', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:33515', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_FinePerformanceMetrics_shuffle(c, s, a, b):
da = pytest.importorskip("dask.array")
x = da.random.random((4, 4), chunks=(1, -1))
x = x.rechunk((-1, 1), method="p2p")
x = x.sum()
> await c.compute(x)
distributed/dashboard/tests/test_scheduler_bokeh.py:462:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
# The stdin in notification is not acted upon in debugpy, so, disable it.
kwargs.setdefault("notify_stdin", False)
try:
pydevd.settrace(*args, **kwargs)
except Exception:
raise
def ensure_logging():
"""Starts logging to log.log_dir, if it hasn't already been done."""
if ensure_logging.ensured:
return
ensure_logging.ensured = True
log.to_file(prefix="debugpy.server")
log.describe_environment("Initial environment:")
if log.log_dir is not None:
pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")
ensure_logging.ensured = False
def log_to(path):
if ensure_logging.ensured:
raise RuntimeError("logging has already begun")
log.debug("log_to{0!r}", (path,))
if path is sys.stderr:
log.stderr.levels |= set(log.LEVELS)
else:
log.log_dir = path
def configure(properties=None, **kwargs):
ensure_logging()
log.debug("configure{0!r}", (properties, kwargs))
if properties is None:
properties = kwargs
else:
properties = dict(properties)
properties.update(kwargs)
for k, v in properties.items():
if k not in _config:
raise ValueError("Unknown property {0!r}".format(k))
expected_type = type(_config[k])
if type(v) is not expected_type:
raise ValueError("{0!r} must be a {1}".format(k, expected_type.__name__))
valid_values = _config_valid_values.get(k)
if (valid_values is not None) and (v not in valid_values):
raise ValueError("{0!r} must be one of: {1!r}".format(k, valid_values))
_config[k] = v
def _starts_debugging(func):
def debug(address, **kwargs):
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.deploy.tests.test_local
github-actions / Unit Test Results
All 8 runs failed: test_silent_startup (distributed.deploy.tests.test_local)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
subprocess.CalledProcessError: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/python', '-Wi', '-c', 'if 1:\n from time import sleep\n from distributed import LocalCluster\n\n if __name__ == "__main__":\n with LocalCluster(n_workers=1, dashboard_address=":0"):\n sleep(.1)\n ']' returned non-zero exit status 1.
def test_silent_startup():
code = """if 1:
from time import sleep
from distributed import LocalCluster
if __name__ == "__main__":
with LocalCluster(n_workers=1, dashboard_address=":0"):
sleep(.1)
"""
> out = subprocess.check_output(
[sys.executable, "-Wi", "-c", code], stderr=subprocess.STDOUT
)
distributed/deploy/tests/test_local.py:535:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/subprocess.py:472: in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
input = None, capture_output = False, timeout = None, check = True
popenargs = (['/home/runner/miniconda3/envs/dask-distributed/bin/python', '-Wi', '-c', 'if 1:\n from time import sleep\n ..._main__":\n with LocalCluster(n_workers=1, dashboard_address=":0"):\n sleep(.1)\n '],)
kwargs = {'stderr': -2, 'stdout': -1}
process = <Popen: returncode: 1 args: ['/home/runner/miniconda3/envs/dask-distributed/...>
stdout = b' File "<string>", line 1\n import sys; sys.path.insert(0, r\'/home/runner/miniconda3/envs/dask-distributed/lib/p... ^^\nSyntaxError: invalid syntax\n'
stderr = None, retcode = 1
def run(*popenargs,
input=None, capture_output=False, timeout=None, check=False, **kwargs):
"""Run command with arguments and return a CompletedProcess instance.
The returned instance will have attributes args, returncode, stdout and
stderr. By default, stdout and stderr are not captured, and those attributes
will be None. Pass stdout=PIPE and/or stderr=PIPE in order to capture them,
or pass capture_output=True to capture both.
If check is True and the exit code was non-zero, it raises a
CalledProcessError. The CalledProcessError object will have the return code
in the returncode attribute, and output & stderr attributes if those streams
were captured.
If timeout is given, and the process takes too long, a TimeoutExpired
exception will be raised.
There is an optional argument "input", allowing you to
pass bytes or a string to the subprocess's stdin. If you use this argument
you may not also use the Popen constructor's "stdin" argument, as
it will be used internally.
By default, all communication is in bytes, and therefore any "input" should
be bytes, and the stdout and stderr will be bytes. If in text mode, any
"input" should be a string, and stdout and stderr will be strings decoded
according to locale encoding, or by "encoding" if set. Text mode is
triggered by setting any of text, encoding, errors or universal_newlines.
The other arguments are the same as for the Popen constructor.
"""
if input is not None:
if kwargs.get('stdin') is not None:
raise ValueError('stdin and input arguments may not both be used.')
kwargs['stdin'] = PIPE
if capture_output:
if kwargs.get('stdout') is not None or kwargs.get('stderr') is not None:
raise ValueError('stdout and stderr arguments may not be used '
'with capture_output.')
kwargs['stdout'] = PIPE
kwargs['stderr'] = PIPE
with Popen(*popenargs, **kwargs) as process:
try:
stdout, stderr = process.communicate(input, timeout=timeout)
except TimeoutExpired as exc:
process.kill()
if _mswindows:
# Windows accumulates the output in a single blocking
# read() call run on child threads, with the timeout
# being done in a join() on those threads. communicate()
# _after_ kill() is required to collect that and add it
# to the exception.
exc.stdout, exc.stderr = process.communicate()
else:
# POSIX _communicate already populated the output so
# far into the TimeoutExpired exception.
process.wait()
raise
except: # Including KeyboardInterrupt, communicate handled that.
process.kill()
# We don't call process.wait() as .__exit__ does that for us.
raise
retcode = process.poll()
if check and retcode:
> raise CalledProcessError(retcode, process.args,
output=stdout, stderr=stderr)
E subprocess.CalledProcessError: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/python', '-Wi', '-c', 'if 1:\n from time import sleep\n from distributed import LocalCluster\n\n if __name__ == "__main__":\n with LocalCluster(n_workers=1, dashboard_address=":0"):\n sleep(.1)\n ']' returned non-zero exit status 1.
../../../miniconda3/envs/dask-distributed/lib/python3.13/subprocess.py:577: CalledProcessError
Check warning on line 0 in distributed.diagnostics.tests.test_eventstream
github-actions / Unit Test Results
All 8 runs failed: test_eventstream (distributed.diagnostics.tests.test_eventstream)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
assert False
+ where False = any(<generator object test_eventstream.<locals>.<genexpr> at 0x7f87406df370>)
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36721', workers: 0, cores: 0, tasks: 0>
workers = {'tcp://127.0.0.1:33231-140210778146496': 0.0, 'tcp://127.0.0.1:38845-140210815895232': 0.5, 'tcp://127.0.0.1:46607-140210797020864': 1.0}
es = <distributed.diagnostics.eventstream.EventStream object at 0x7f87407d7e00>
@py_assert1 = <generator object test_eventstream.<locals>.<genexpr> at 0x7f87406df370>
@py_assert4 = None
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_eventstream(c, s, *workers):
pytest.importorskip("bokeh")
es = EventStream()
s.add_plugin(es)
assert es.buffer == []
futures = c.map(div, [1] * 10, range(10))
total = c.submit(sum, futures[1:])
await wait(total)
await wait(futures)
assert len(es.buffer) == 11
from distributed.diagnostics.progress_stream import task_stream_append
lists = {
name: collections.deque(maxlen=100)
for name in "start duration key name color worker worker_thread y alpha".split()
}
workers = {}
for msg in es.buffer:
task_stream_append(lists, msg, workers)
assert sum(n == "transfer-sum" for n in lists["name"]) == 2
for name, color in zip(lists["name"], lists["color"]):
assert (name == "transfer-sum") == (color == "red")
> assert any(c == "black" for c in lists["color"])
E assert False
E + where False = any(<generator object test_eventstream.<locals>.<genexpr> at 0x7f87406df370>)
distributed/diagnostics/tests/test_eventstream.py:43: AssertionError
Check warning on line 0 in distributed.diagnostics.tests.test_progress_widgets
github-actions / Unit Test Results
All 8 runs failed: test_multi_progressbar_widget (distributed.diagnostics.tests.test_progress_widgets)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
assert "RuntimeError('hello!')" in '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>'
+ where "RuntimeError('hello!')" = repr(RuntimeError('hello!'))
+ and '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>' = HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>').value
+ where HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>') = <distributed.diagnostics.progressbar.MultiProgressWidget object at 0x7f8742ddef90>.elapsed_time
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43105', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:43191', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:45933', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_multi_progressbar_widget(c, s, a, b):
x1 = c.submit(inc, 1)
x2 = c.submit(inc, x1)
x3 = c.submit(inc, x2)
y1 = c.submit(dec, x3)
y2 = c.submit(dec, y1)
e = c.submit(throws, y2)
other = c.submit(inc, 123)
await wait([other, e])
p = MultiProgressWidget([e.key], scheduler=s.address, complete=True)
await p.listen()
assert p.bars["inc"].value == 1.0
assert p.bars["dec"].value == 1.0
assert p.bars["throws"].value == 0.0
assert "3 / 3" in p.bar_texts["inc"].value
assert "2 / 2" in p.bar_texts["dec"].value
assert "0 / 1" in p.bar_texts["throws"].value
assert p.bars["inc"].bar_style == "success"
assert p.bars["dec"].bar_style == "success"
assert p.bars["throws"].bar_style == "danger"
assert p.status == "error"
assert "Exception" in p.elapsed_time.value
try:
> throws(1)
distributed/diagnostics/tests/test_progress_widgets.py:69:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x = 1
def throws(x):
> raise RuntimeError("hello!")
E RuntimeError: hello!
distributed/utils_test.py:230: RuntimeError
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43105', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:43191', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:45933', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_multi_progressbar_widget(c, s, a, b):
x1 = c.submit(inc, 1)
x2 = c.submit(inc, x1)
x3 = c.submit(inc, x2)
y1 = c.submit(dec, x3)
y2 = c.submit(dec, y1)
e = c.submit(throws, y2)
other = c.submit(inc, 123)
await wait([other, e])
p = MultiProgressWidget([e.key], scheduler=s.address, complete=True)
await p.listen()
assert p.bars["inc"].value == 1.0
assert p.bars["dec"].value == 1.0
assert p.bars["throws"].value == 0.0
assert "3 / 3" in p.bar_texts["inc"].value
assert "2 / 2" in p.bar_texts["dec"].value
assert "0 / 1" in p.bar_texts["throws"].value
assert p.bars["inc"].bar_style == "success"
assert p.bars["dec"].bar_style == "success"
assert p.bars["throws"].bar_style == "danger"
assert p.status == "error"
assert "Exception" in p.elapsed_time.value
try:
throws(1)
except Exception as e:
> assert repr(e) in p.elapsed_time.value
E assert "RuntimeError('hello!')" in '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>'
E + where "RuntimeError('hello!')" = repr(RuntimeError('hello!'))
E + and '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>' = HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>').value
E + where HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError(\'debugpy.listen() has already been called on this process\')</tt>: 0.0s </div>') = <distributed.diagnostics.progressbar.MultiProgressWidget object at 0x7f8742ddef90>.elapsed_time
distributed/diagnostics/tests/test_progress_widgets.py:71: AssertionError
Check warning on line 0 in distributed.diagnostics.tests.test_progress_widgets
github-actions / Unit Test Results
All 8 runs failed: test_progressbar_done (distributed.diagnostics.tests.test_progress_widgets)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 4s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 4s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 4s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 4s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 5s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 4s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
assert "RuntimeError('hello!')" in '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>'
+ where "RuntimeError('hello!')" = repr(RuntimeError('hello!'))
+ and '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>' = HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>').value
+ where HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>') = <distributed.diagnostics.progressbar.ProgressWidget object at 0x7f874c13da70>.elapsed_time
client = <Client: 'tcp://127.0.0.1:41333' processes=2 threads=2, memory=31.23 GiB>
def test_progressbar_done(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
p = ProgressWidget(L)
client.sync(p.listen)
assert p.status == "finished"
assert p.bar.value == 1.0
assert p.bar.bar_style == "success"
assert "Finished" in p.elapsed_time.value
f = client.submit(throws, L)
wait([f])
p = ProgressWidget([f])
client.sync(p.listen)
assert p.status == "error"
assert p.bar.value == 0.0
assert p.bar.bar_style == "danger"
assert "Exception" in p.elapsed_time.value
try:
> throws(1)
distributed/diagnostics/tests/test_progress_widgets.py:118:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x = 1
def throws(x):
> raise RuntimeError("hello!")
E RuntimeError: hello!
distributed/utils_test.py:230: RuntimeError
During handling of the above exception, another exception occurred:
client = <Client: 'tcp://127.0.0.1:41333' processes=2 threads=2, memory=31.23 GiB>
def test_progressbar_done(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
p = ProgressWidget(L)
client.sync(p.listen)
assert p.status == "finished"
assert p.bar.value == 1.0
assert p.bar.bar_style == "success"
assert "Finished" in p.elapsed_time.value
f = client.submit(throws, L)
wait([f])
p = ProgressWidget([f])
client.sync(p.listen)
assert p.status == "error"
assert p.bar.value == 0.0
assert p.bar.bar_style == "danger"
assert "Exception" in p.elapsed_time.value
try:
throws(1)
except Exception as e:
> assert repr(e) in p.elapsed_time.value
E assert "RuntimeError('hello!')" in '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>'
E + where "RuntimeError('hello!')" = repr(RuntimeError('hello!'))
E + and '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>' = HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>').value
E + where HTML(value='<div style="padding: 0px 10px 5px 10px"><b>Exception</b> <tt>RuntimeError("Can\'t listen for client connections: [Errno 98] Address already in use")</tt>: 0.0s </div>') = <distributed.diagnostics.progressbar.ProgressWidget object at 0x7f874c13da70>.elapsed_time
distributed/diagnostics/tests/test_progress_widgets.py:120: AssertionError
Check warning on line 0 in distributed.diagnostics.tests.test_task_stream
github-actions / Unit Test Results
All 8 runs failed: test_TaskStreamPlugin (distributed.diagnostics.tests.test_task_stream)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
AssertionError: assert 10 == 11
+ where 10 = len(deque([{'key': 'div-9a97fac1bd848b99cd568e04a02df597', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2019506, 'stop': 1739560901.2020333},), ...}, {'key': 'div-9fd3e18f48985cb02e8df45d56a23299', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.203511, 'stop': 1739560901.2035694},), ...}, {'key': 'div-c0bbf95caf3549477ae5c0e2884ea717', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2051136, 'stop': 1739560901.2051601},), ...}, {'key': 'div-8a10eb4a48159ddeb78590ccf2ff82d7', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2062058, 'stop': 1739560901.2062764},), ...}, {'key': 'div-283befb1dd1613063516b5d50c66964d', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.218074, 'stop': 1739560901.2181196},), ...}, {'key': 'div-aa09f07c71718ba19d770d5c5a0e44fb', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.217989, 'stop': 1739560901.218053},), ...}, ...]))
+ where deque([{'key': 'div-9a97fac1bd848b99cd568e04a02df597', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2019506, 'stop': 1739560901.2020333},), ...}, {'key': 'div-9fd3e18f48985cb02e8df45d56a23299', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.203511, 'stop': 1739560901.2035694},), ...}, {'key': 'div-c0bbf95caf3549477ae5c0e2884ea717', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2051136, 'stop': 1739560901.2051601},), ...}, {'key': 'div-8a10eb4a48159ddeb78590ccf2ff82d7', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2062058, 'stop': 1739560901.2062764},), ...}, {'key': 'div-283befb1dd1613063516b5d50c66964d', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.218074, 'stop': 1739560901.2181196},), ...}, {'key': 'div-aa09f07c71718ba19d770d5c5a0e44fb', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.217989, 'stop': 1739560901.218053},), ...}, ...]) = <distributed.diagnostics.task_stream.TaskStreamPlugin object at 0x7f87340195a0>.buffer
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44067', workers: 0, cores: 0, tasks: 0>
workers = (<Worker 'tcp://127.0.0.1:42623', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <W... 0>, <Worker 'tcp://127.0.0.1:34617', name: 2, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>)
es = <distributed.diagnostics.task_stream.TaskStreamPlugin object at 0x7f87340195a0>
@py_assert1 = None, @py_assert3 = None
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_TaskStreamPlugin(c, s, *workers):
es = TaskStreamPlugin(s)
s.add_plugin(es)
assert not es.buffer
futures = c.map(div, [1] * 10, range(10))
total = c.submit(sum, futures[1:])
await wait(total)
> assert len(es.buffer) == 11
E AssertionError: assert 10 == 11
E + where 10 = len(deque([{'key': 'div-9a97fac1bd848b99cd568e04a02df597', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2019506, 'stop': 1739560901.2020333},), ...}, {'key': 'div-9fd3e18f48985cb02e8df45d56a23299', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.203511, 'stop': 1739560901.2035694},), ...}, {'key': 'div-c0bbf95caf3549477ae5c0e2884ea717', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2051136, 'stop': 1739560901.2051601},), ...}, {'key': 'div-8a10eb4a48159ddeb78590ccf2ff82d7', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2062058, 'stop': 1739560901.2062764},), ...}, {'key': 'div-283befb1dd1613063516b5d50c66964d', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.218074, 'stop': 1739560901.2181196},), ...}, {'key': 'div-aa09f07c71718ba19d770d5c5a0e44fb', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.217989, 'stop': 1739560901.218053},), ...}, ...]))
E + where deque([{'key': 'div-9a97fac1bd848b99cd568e04a02df597', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2019506, 'stop': 1739560901.2020333},), ...}, {'key': 'div-9fd3e18f48985cb02e8df45d56a23299', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.203511, 'stop': 1739560901.2035694},), ...}, {'key': 'div-c0bbf95caf3549477ae5c0e2884ea717', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2051136, 'stop': 1739560901.2051601},), ...}, {'key': 'div-8a10eb4a48159ddeb78590ccf2ff82d7', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.2062058, 'stop': 1739560901.2062764},), ...}, {'key': 'div-283befb1dd1613063516b5d50c66964d', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.218074, 'stop': 1739560901.2181196},), ...}, {'key': 'div-aa09f07c71718ba19d770d5c5a0e44fb', 'metadata': {}, 'nbytes': 24, 'startstops': ({'action': 'compute', 'start': 1739560901.217989, 'stop': 1739560901.218053},), ...}, ...]) = <distributed.diagnostics.task_stream.TaskStreamPlugin object at 0x7f87340195a0>.buffer
distributed/diagnostics/tests/test_task_stream.py:26: AssertionError
Check warning on line 0 in distributed.diagnostics.tests.test_worker_plugin
github-actions / Unit Test Results
All 8 runs failed: test_failing_task_transitions_called (distributed.diagnostics.tests.test_worker_plugin)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import logging
import warnings
import pytest
from distributed import Worker, WorkerPlugin
from distributed.protocol.pickle import dumps
from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc
class MyPlugin(WorkerPlugin):
name = "MyPlugin"
def __init__(self, data, expected_notifications=None):
self.data = data
self.expected_notifications = expected_notifications
def setup(self, worker):
assert isinstance(worker, Worker)
self.worker = worker
self.worker._my_plugin_status = "setup"
self.worker._my_plugin_data = self.data
self.observed_notifications = []
def teardown(self, worker):
self.worker._my_plugin_status = "teardown"
if self.expected_notifications is not None:
assert len(self.observed_notifications) == len(self.expected_notifications)
for expected, real in zip(
self.expected_notifications, self.observed_notifications
):
assert expected == real
def transition(self, key, start, finish, **kwargs):
self.observed_notifications.append(
{"key": key, "start": start, "finish": finish}
)
@gen_cluster(client=True, nthreads=[])
async def test_create_with_client(c, s):
await c.register_plugin(MyPlugin(123))
async with Worker(s.address) as worker:
assert worker._my_plugin_status == "setup"
assert worker._my_plugin_data == 123
assert worker._my_plugin_status == "teardown"
@gen_cluster(client=True, nthreads=[])
async def test_remove_with_client(c, s):
existing_plugins = s.worker_plugins.copy()
n_existing_plugins = len(existing_plugins)
await c.register_plugin(MyPlugin(123), name="foo")
await c.register_plugin(MyPlugin(546), name="bar")
async with Worker(s.address) as worker:
# remove the 'foo' plugin
await c.unregister_worker_plugin("foo")
assert worker._my_plugin_status == "teardown"
# check that on the scheduler registered worker plugins we only have 'bar'
assert len(s.worker_plugins) == n_existing_plugins + 1
assert "bar" in s.worker_plugins
# check on the worker plugins that we only have 'bar'
assert len(worker.plugins) == n_existing_plugins + 1
assert "bar" in worker.plugins
# let's remove 'bar' and we should have none worker plugins
await c.unregister_worker_plugin("bar")
assert worker._my_plugin_status == "teardown"
assert s.worker_plugins == existing_plugins
assert len(worker.plugins) == n_existing_plugins
@gen_cluster(client=True, nthreads=[])
async def test_remove_with_client_raises(c, s):
await c.register_plugin(MyPlugin(123), name="foo")
async with Worker(s.address):
with pytest.raises(ValueError, match="bar"):
await c.unregister_worker_plugin("bar")
@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
async def test_create_on_construction(c, s, a, b):
assert len(a.plugins) == len(b.plugins)
assert any(isinstance(plugin, MyPlugin) for plugin in a.plugins.values())
assert any(isinstance(plugin, MyPlugin) for plugin in b.plugins.values())
assert a._my_plugin_status == "setup"
assert a._my_plugin_data == 5
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]
plugin = MyPlugin(1, expected_notifications=expected_notifications)
await c.register_plugin(plugin)
await c.submit(lambda x: x, 1, key="task")
await async_poll_for(lambda: not w.state.tasks, timeout=10)
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_failing_task_transitions_called(c, s, w):
class CustomError(Exception):
pass
def failing(x):
> raise CustomError()
E test_worker_plugin.test_failing_task_transitions_called.<locals>.CustomError
distributed/diagnostics/tests/test_worker_plugin.py:125: CustomError
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40997', workers: 0, cores: 0, tasks: 0>
w = <Worker 'tcp://127.0.0.1:45503', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_failing_task_transitions_called(c, s, w):
class CustomError(Exception):
pass
def failing(x):
raise CustomError()
expected_notifications = [
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]
plugin = MyPlugin(1, expected_notifications=expected_notifications)
await c.register_plugin(plugin)
with pytest.raises(CustomError):
> await c.submit(failing, 1, key="task")
distributed/diagnostics/tests/test_worker_plugin.py:141:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
# The stdin in notification is not acted upon in debugpy, so, disable it.
kwargs.setdefault("notify_stdin", False)
try:
pydevd.settrace(*args, **kwargs)
except Exception:
raise
def ensure_logging():
"""Starts logging to log.log_dir, if it hasn't already been done."""
if ensure_logging.ensured:
return
ensure_logging.ensured = True
log.to_file(prefix="debugpy.server")
log.describe_environment("Initial environment:")
if log.log_dir is not None:
pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")
ensure_logging.ensured = False
def log_to(path):
if ensure_logging.ensured:
raise RuntimeError("logging has already begun")
log.debug("log_to{0!r}", (path,))
if path is sys.stderr:
log.stderr.levels |= set(log.LEVELS)
else:
log.log_dir = path
def configure(properties=None, **kwargs):
ensure_logging()
log.debug("configure{0!r}", (properties, kwargs))
if properties is None:
properties = kwargs
else:
properties = dict(properties)
properties.update(kwargs)
for k, v in properties.items():
if k not in _config:
raise ValueError("Unknown property {0!r}".format(k))
expected_type = type(_config[k])
if type(v) is not expected_type:
raise ValueError("{0!r} must be a {1}".format(k, expected_type.__name__))
valid_values = _config_valid_values.get(k)
if (valid_values is not None) and (v not in valid_values):
raise ValueError("{0!r} must be one of: {1!r}".format(k, valid_values))
_config[k] = v
def _starts_debugging(func):
def debug(address, **kwargs):
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.protocol.tests.test_highlevelgraph
github-actions / Unit Test Results
All 8 runs failed: test_combo_of_layer_types (distributed.protocol.tests.test_highlevelgraph)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
op: Literal["task-finished"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
class RunTaskFailure(ErrorMessage):
op: Literal["task-erred"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
status: Literal["busy"]
class GetDataSuccess(TypedDict):
status: Literal["OK"]
data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to close the worker if this method encounters an exception.
"""
reason = f"worker-{method.__name__}-fail-hard"
if iscoroutinefunction(method):
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
try:
return await method(self, *args, **kwargs) # type: ignore
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
await _force_close(self, reason)
raise
else:
@wraps(method)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
try:
return method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
self.loop.add_callback(_force_close, self, reason)
raise
return wrapper # type: ignore
async def _force_close(self, reason: str):
"""
Used with the fail_hard decorator defined above
1. Wait for a worker to close
2. If it doesn't, log and kill the process
"""
try:
await wait_for(
self.close(nanny=False, executor_wait=False, reason=reason),
30,
)
except (KeyboardInterrupt, SystemExit): # pragma: nocover
raise
except BaseException: # pragma: nocover
# Worker is in a very broken state if closing fails. We need to shut down
# immediately, to ensure things don't get even worse and this worker potentially
# deadlocks the cluster.
from distributed import Scheduler
if Scheduler._instances:
# We're likely in a unit test. Don't kill the whole test suite!
raise
logger.critical(
"Error trying close worker in response to broken internal state. "
"Forcibly exiting worker NOW",
exc_info=True,
)
# use `os._exit` instead of `sys.exit` because of uncertainty
# around propagating `SystemExit` from asyncio callbacks
os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
… def outgoing_current_count(self):
warnings.warn(
"The `Worker.outgoing_current_count` attribute has been renamed to "
"`Worker.transfer_outgoing_count`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count
@property
def outgoing_transfer_log(self):
    """Deprecated alias of ``transfer_outgoing_log``.

    Every access emits a ``DeprecationWarning`` pointing at the caller and
    then returns ``self.transfer_outgoing_log``.
    """
    message = (
        "The `Worker.outgoing_transfer_log` attribute has been renamed to "
        "`Worker.transfer_outgoing_log`"
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_log
@property
def total_in_connections(self):
    """Deprecated alias of ``transfer_outgoing_count_limit``.

    Every access emits a ``DeprecationWarning`` pointing at the caller and
    then returns ``self.transfer_outgoing_count_limit``.
    """
    message = (
        "The `Worker.total_in_connections` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_limit`"
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker that is running the current task

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the internal LookupError; callers only deal with ValueError
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    Preference order: the client of the worker running the current task, then
    ``Client.current()``, then a brand new ``Client(address)``. The first two
    are only used when no ``address`` was requested or when their scheduler
    address equals the requested one.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If no suitable client exists and no ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        address_matches = not address or worker.scheduler.address == address
        if address_matches:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the current global client, if any
    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and (not address or client.scheduler.address == address):
        return client
    if address:
        # Last resort: open a new connection to the requested scheduler
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The time spent in the task so far is reported to the worker through a
    ``SecedeEvent`` scheduled on the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event over to the worker's event loop
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: after a response with
    status ``"OK"`` it writes ``"OK"`` back on the same comm.

    ``serializers`` / ``deserializers`` default to the ones configured on
    ``rpc``. The comm is always handed back to the pool via ``rpc.reuse``.

    Raises
    ------
    ValueError
        If the response carries no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Serialize ``func`` to bytes, memoizing the result in ``cache_dumps``

    Payloads of 100000 bytes or more are returned without being cached, and
    unhashable functions bypass the cache entirely.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
    except KeyError:
        # Cache miss: serialize now and remember small payloads only
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    The calling thread is registered in ``active_threads`` (thread id -> key)
    for the duration of the call, the thread state is populated, and
    ``_worker_cvar`` is pointed at ``execution_state["worker"]`` so that
    ``get_worker()`` works from inside the task.

    Parameters
    ----------
    task
        The task to execute; forwarded to ``_run_task_simple``.
    data
        Mapping handed to the task; forwarded to ``_run_task_simple``.
    execution_state
        Must contain a ``"worker"`` entry; also stored in the thread state.
    key
        Key of the task being run.
    active_threads, active_threads_lock
        Shared registry of running threads and the lock protecting it.
    time_delay
        Forwarded unchanged to ``_run_task_simple``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Unregister the thread even when _run_task_simple (or the thread-state
        # setup) raises; otherwise the entry would linger in active_threads and
        # keep reporting this key as running on this thread.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_shuffle.py:68: in shuffle_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers (a ``str`` underneath)
ShuffleId = NewType("ShuffleId", str)
# Index expressed as a tuple of ints of arbitrary length
NDIndex: TypeAlias = tuple[int, ...]

# Type parameters of ``ShuffleRun``
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, e.g. of the callable passed to ``ShuffleRun.offload``
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run spec"""

    # Either the spec itself or the spec wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers

    Parameters
    ----------
    id, run_id, span_id, local_address
        Stored unchanged on the instance.
    directory
        Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
    executor
        Executor used by ``offload``.
    rpc, scheduler
        Stored unchanged; used to talk to peers and to the scheduler.
    digest_metric
        Callable receiving the metrics captured by ``_capture_metrics``.
    memory_limiter_disk, memory_limiter_comms
        Limiters handed to the disk and comm buffers respectively.
    disk
        If true, buffer received shards in a ``DiskShardsBuffer``, otherwise
        in a ``MemoryShardsBuffer``.
    loop
        Stored as ``self._loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                # self.read is not defined in the visible part of the class;
                # presumably supplied by a subclass -- TODO confirm
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                # self.deserialize: same remark as for self.read above
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Bookkeeping for the lifetime of the run
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by send(), read from the dask config
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``"""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the block runs

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}
    and a leading ``"p2p-"`` is stripped from the first label element.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalize scalar labels to a 1-tuple
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to the peer at ``address`` in a single attempt

    Calls ``shuffle_receive`` on the peer, tagging the payload with this
    run's shuffle id and run id.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards), shuffle_id=self.id, run_id=self.run_id
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob before sending; ``receive`` unpickles it on the other side.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if small else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result

    The time spent is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer

    Raises via ``raise_if_closed`` when the run has already been closed.
    """
    self.raise_if_closed()
    await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush the comm buffer

    If the comm buffer reports an exception afterwards, it is stored on
    ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as e:
        # Remember the failure so that raise_if_closed can surface it later
        self._exception = e
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises via ``raise_if_closed`` if the run is closed"""
    self.raise_if_closed()
    await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises via ``raise_if_closed`` if the run is closed"""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers and mark this run as closed

    If the run is already flagged as closed, wait until ``_closed_event`` is
    set (i.e. until the first ``close`` call has finished) and return.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    # Flag first so that concurrent callers take the waiting branch above
    self.closed = True
    await self._comm_buffer.close()
    await self._disk_buffer.close()
    # NOTE(review): if either close() above raises, the event is never set and
    # callers waiting in the branch above would block -- confirm this is intended
    self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``_receive``

    ``bytes`` input is the opaque blob produced by ``send`` and is unpickled
    first. A ``P2PConsistencyError`` is converted into an error message
    instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE: pickle.loads trusts the sender of the blob
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35649', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:39435', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:38507', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_combo_of_layer_types(c, s, a, b):
"""Check pack/unpack of a HLG that has everything!"""
def add(x, y, z, extra_arg):
return x + y + z + extra_arg
y = c.submit(lambda x: x, 2)
z = c.submit(lambda x: x, 3)
xx = await c.submit(lambda x: x + 1, y)
x = da.blockwise(
add,
"x",
da.zeros((3,), chunks=(1,)),
"x",
da.ones((3,), chunks=(1,)),
"x",
y,
None,
concatenate=False,
dtype=int,
extra_arg=z,
)
df = dd.from_pandas(pd.DataFrame({"a": np.arange(3)}), npartitions=3)
df = df.shuffle("a")
df = df["a"].to_dask_array()
res = x.sum() + df.sum()
> res = await c.compute(res, optimize_graph=False)
distributed/protocol/tests/test_highlevelgraph.py:47:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage; its users are not visible in this excerpt.
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration. configure() only accepts keys that already
# exist here, and only values whose type matches the existing value's type.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Per-property whitelist of accepted values, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace()``.

    ``notify_stdin`` defaults to ``False`` unless the caller passes it
    explicitly. Exceptions raised by ``pydevd.settrace()`` propagate to the
    caller unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Start logging on the first call; later calls do nothing.

    The first call opens the ``debugpy.server`` log file, records the
    initial environment and, when ``log.log_dir`` is set, points pydevd's
    log at ``debugpy.pydevd.log`` inside that directory.
    """
    if not ensure_logging.ensured:
        # Mark as done up front so the setup below runs at most once.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Function attribute used as the "already started" flag.
ensure_logging.ensured = False
def log_to(path):
    """Choose where debugpy logs go; must be called before logging starts.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on the
    stderr log; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging()`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate debug configuration properties and store them in ``_config``.

    ``properties`` (anything ``dict()`` accepts) and ``kwargs`` are merged,
    with keyword arguments winning on conflict. Entries are validated and
    stored one at a time, so entries preceding an invalid one have already
    been applied when the error is raised.

    Raises
    ------
    ValueError
        For a key that is not in ``_config``, a value whose exact type
        differs from the current value's type, or a value outside the list
        in ``_config_valid_values`` for that key.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        merged = kwargs
    else:
        merged = dict(properties, **kwargs)

    for name, value in merged.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match on purpose: subclasses are rejected as well.
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Wrap ``func`` so it receives a validated address and settrace kwargs.

    The returned ``debug(address, **kwargs)`` accepts a bare port or a
    ``(host, port)`` pair, validates the port, makes sure logging has
    started, applies the Qt setting from ``_config`` and then calls
    ``func(address, settrace_kwargs, **kwargs)``.

    Raises
    ------
    ValueError
        If no int-like port can be extracted, or it is outside
        ``0 <= port < 2**16``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a pair: treat the argument as a bare port on localhost.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")

        if port < 0 or port >= 2**16:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {}
        settrace_kwargs["suspend"] = False
        settrace_kwargs["patch_multiprocessing"] = _config.get("subProcess", True)

        if hide_debugpy_internals():
            # Keep debugpy's own frames out of the traced code.
            debugpy_dir = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_dir,),
                dont_trace_end_patterns=("debugpy_launcher.py",),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.protocol.tests.test_highlevelgraph
github-actions / Unit Test Results
All 8 runs failed: test_shuffle (distributed.protocol.tests.test_highlevelgraph)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
# Type variables used in the signature of the ``fail_hard`` decorator.
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the ``distributed.admin.pdb-on-err`` config key, read at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by extension name.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Default values of the ``metrics`` and ``startup_information`` parameters of
# ``Worker.__init__``; both start out empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# NOTE(review): presumably the statuses in which a worker counts as running -- confirm at use sites.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
"""``OKMessage`` variant whose ``op`` is the literal ``"task-finished"``."""
# Message discriminator; always "task-finished".
op: Literal["task-finished"]
# NOTE(review): presumably the value produced by the task -- confirm against the producer.
result: object
# NOTE(review): presumably the size of ``result`` in bytes -- confirm.
nbytes: int
# NOTE(review): presumably the Python type of ``result`` -- confirm.
type: type
# NOTE(review): presumably start/stop timestamps of the execution -- confirm clock and units.
start: float
stop: float
# NOTE(review): presumably the identifier of the thread that ran the task -- confirm.
thread: int
class RunTaskFailure(ErrorMessage):
"""``ErrorMessage`` variant whose ``op`` is the literal ``"task-erred"``."""
# Message discriminator; always "task-erred".
op: Literal["task-erred"]
# The next six fields mirror ``RunTaskSuccess``.
result: object
nbytes: int
type: type
# NOTE(review): presumably start/stop timestamps of the execution -- confirm clock and units.
start: float
stop: float
# NOTE(review): presumably the identifier of the thread that ran the task -- confirm.
thread: int
# ``Exception`` is a subclass of ``BaseException``, so for a type checker
# this union is equivalent to plain ``BaseException``.
actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
"""Typed dict whose only key, ``status``, is the literal ``"busy"``."""
status: Literal["busy"]
class GetDataSuccess(TypedDict):
"""Typed dict with ``status`` fixed to ``"OK"`` plus a ``data`` mapping."""
status: Literal["OK"]
# Mapping from task key to the object stored under that key.
data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator that closes the worker when the wrapped method raises.

    Plain and coroutine methods are both supported. When the method raises an
    ``Exception``, the failure is recorded as a ``worker-fail-hard`` event and
    logged (unless the worker is already closed or closing), ``_force_close``
    is triggered, and the original exception is re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    if not iscoroutinefunction(method):

        @wraps(method)
        def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                # Not in a coroutine: schedule the shutdown on the event loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

        return sync_wrapper  # type: ignore

    @wraps(method)
    async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return await method(self, *args, **kwargs)  # type: ignore
        except Exception as e:
            if self.status not in (Status.closed, Status.closing):
                self.log_event("worker-fail-hard", error_message(e))
                logger.exception(e)
            await _force_close(self, reason)
            raise

    return async_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Companion of the ``fail_hard`` decorator.

    1. Close the worker, bounded by a ``wait_for`` timeout of 30
    2. If closing fails, log the failure and kill the process, unless a
       ``Scheduler`` instance lives in this process, in which case the
       failure is re-raised instead
    """
    try:
        shutdown = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(shutdown, 30)
    except BaseException as exc:  # pragma: nocover
        if isinstance(exc, (KeyboardInterrupt, SystemExit)):
            raise

        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of threads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's :attr:`local_directory` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
… )
return self.transfer_incoming_log
@property
def outgoing_count(self):
    """Deprecated alias of ``transfer_outgoing_count_total``.

    Emits a ``DeprecationWarning`` on every access.
    """
    message = (
        "The `Worker.outgoing_count` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_total`"
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_total
@property
def outgoing_current_count(self):
    """Deprecated alias of ``transfer_outgoing_count``.

    Emits a ``DeprecationWarning`` on every access.
    """
    message = (
        "The `Worker.outgoing_current_count` attribute has been renamed to "
        "`Worker.transfer_outgoing_count`"
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count
@property
def outgoing_transfer_log(self):
    """Deprecated alias of ``transfer_outgoing_log``.

    Emits a ``DeprecationWarning`` on every access.
    """
    message = (
        "The `Worker.outgoing_transfer_log` attribute has been renamed to "
        "`Worker.transfer_outgoing_log`"
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_log
@property
def total_in_connections(self):
    """Deprecated alias of ``transfer_outgoing_count_limit``.

    Emits a ``DeprecationWarning`` on every access.
    """
    message = (
        "The `Worker.total_in_connections` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_limit`"
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_limit
# Context variable read by ``get_worker()``; it has no default value.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker that is running the current task

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker is set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError: callers only need to know no worker is set.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    Lookup order: the client of the worker running the current task, then
    the current global client, and finally a brand-new ``Client`` connected
    to ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If no suitable worker or global client exists and no ``address`` was
        given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker running this task
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        same_scheduler = not address or worker.scheduler.address == address
        if same_scheduler:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the current global client
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None

    if current and (not address or current.scheduler.address == address):
        return current
    if address:
        # Last resort: open a new connection to the requested scheduler
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    # Time elapsed since the start time recorded in the thread state
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its event loop
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: a ``get_data`` request
    is sent and, when the response status is ``"OK"``, an ``"OK"`` message is
    written back on the same comm. The comm is always handed back to the pool.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    # Fall back to the (de)serializers configured on the connection pool
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)

        # Second step of the handshake: acknowledge receipt
        if status == "OK":
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Serialized functions, keyed by the function object; holds at most 100 entries
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)

# Guards every read and write of ``cache_dumps``
_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Dump a function to bytes, cache functions

    Payloads of 100000 bytes or more are returned without being cached, and
    unhashable functions bypass the cache altogether.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: serialize now and remember the result if it is small
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    The calling thread is registered in ``active_threads`` under ``key`` for
    the duration of the call, and ``execution_state["worker"]`` is published
    through ``_worker_cvar`` while the task runs.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    thread_ctx = set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    )
    with thread_ctx:
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            outcome = _run_task_simple(task, data, time_delay)
        finally:
            # Always restore the previous context variable value
            _worker_cvar.reset(cvar_token)

    with active_threads_lock:
        del active_threads[thread_id]
    return outcome
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_shuffle.py:68: in shuffle_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing_extensions import ParamSpec, TypeAlias

    _P = ParamSpec("_P")

    # circular dependencies: imported for type checking only
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

# Distinct static type for shuffle identifiers; a plain ``str`` at runtime
ShuffleId = NewType("ShuffleId", str)

# Index made of any number of integers
NDIndex: TypeAlias = tuple[int, ...]

# Type variables used by the generic classes and helpers below
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` entry."""

    # Either the spec itself or the spec wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` when true, a ``MemoryShardsBuffer`` otherwise. The comm
    buffer limits and the retry settings are read from the
    ``distributed.p2p.comm.*`` configuration keys.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        with self._capture_metrics("background-comms"):
            message_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Bookkeeping for the lifetime of the run
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by ``send``
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return ``<ClassName: field=value, ...>`` for the identifying fields."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    details = ", ".join(f"{name}={getattr(self, name)!r}" for name in shown)
    return f"<{self.__class__.__name__}: {details}>"
def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    return "{}<{}[{}]> on {}".format(
        self.__class__.__name__, self.id, self.run_id, self.local_address
    )
def __hash__(self) -> int:
    """Use ``run_id`` as the hash value."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def _publish(label: Hashable, value: float, unit: str) -> None:
        # Normalize scalar labels to a 1-tuple so they can be spliced below
        if not isinstance(label, tuple):
            label = (label,)
        first = label[0]
        # Drop a leading "p2p-": the metric name already starts with "p2p"
        if isinstance(first, str) and first.startswith("p2p-"):
            label = (first.removeprefix("p2p-"), *label[1:])
        self.digest_metric(("p2p", self.span_id, where, *label, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(_publish, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` match this run.

    Returns this run's ``run_id`` once the scheduler call has completed.
    Fails first if the run is already closed.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = all(other == own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make one ``shuffle_receive`` call to ``address`` and return its response.

    Fails first if the run is already closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    response = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return response
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured policy.

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob first; otherwise the list is sent as it is.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is required for every retry
        return self._send(address, payload)

    return await retry(
        _attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is timed under the ``"offload"`` meter. Fails first if the run is
    already closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
        return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot with ``disk``, ``comm`` and ``start`` entries.

    The comm buffer's heartbeat dict gains a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_status = self._comm_buffer.heartbeat()
    comm_status["read"] = self.total_recvd

    snapshot: dict[str, Any] = {}
    snapshot["disk"] = self._disk_buffer.heartbeat()
    snapshot["comm"] = comm_status
    snapshot["start"] = self.start_time
    return snapshot
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; fails if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer under string keys.

    Each index key becomes its elements joined by ``"_"``, e.g. ``(1, 2)``
    turns into ``"1_2"``. Fails first if the run is already closed.
    """
    self.raise_if_closed()
    renamed = {}
    for index, shard in data.items():
        renamed["_".join(map(str, index))] = shard
    await self._disk_buffer.write(renamed)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    A stored exception is re-raised when there is one; otherwise a
    ``ShuffleClosedError`` is raised.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception reported by the comm buffer afterwards is stored on the run
    and re-raised. Fails first if the run is already closed.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Remember the failure so that raise_if_closed() can surface it later
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; fails if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; fails if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer and then the disk buffer.

    The first call flags the run as closed right away and sets the closed
    event once both buffers are closed. Later calls only wait for that event.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named after ``id``.

    The name is the elements of ``id`` joined by ``"_"``. Fails first if the
    run is already closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Process incoming shards and report the outcome as a message.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised; any other
    exception propagates.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from another worker
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as consistency_error:
        return error_message(consistency_error)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:34545', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:41539', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:46617', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_shuffle(c, s, a, b):
"""Check pack/unpack of a shuffled dataframe"""
df = dd.from_pandas(
pd.DataFrame(
{"a": np.arange(10, dtype=int), "b": np.arange(10, 0, -1, dtype=float)}
),
npartitions=5,
)
df = df.shuffle("a", max_branch=2)
df = df["a"] + df["b"]
> res = await c.compute(df, optimize_graph=False)
distributed/protocol/tests/test_highlevelgraph.py:89:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()

# Current debug configuration; the values given here are the defaults.
# TODO: "gevent", if possible.
_config = dict(
    qt="none",
    subProcess=True,
    python=sys.executable,
    pythonEnv={},
)

# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
_config_valid_values = dict(
    qt=["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
)

# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call, then forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to False unless the caller supplied a value.
    Exceptions raised by ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)

    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False

    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done."""
    if not ensure_logging.ensured:
        # Flip the flag first so that later calls return immediately
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Function attribute recording whether logging has been started
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every log level on the stderr log; any
    other value is stored as the log directory.

    Raises ``RuntimeError`` if logging has already begun.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    ``properties`` (an optional mapping) and ``kwargs`` are merged, with
    ``kwargs`` taking precedence. Each property must already exist in
    ``_config``, have exactly the type of its current value, and be one of
    the allowed values when a list of them is registered.

    Raises ``ValueError`` on the first property that fails validation;
    properties processed before it have already been applied.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        merged = kwargs
    else:
        merged = dict(properties)
        merged.update(kwargs)

    for name, value in merged.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match, so e.g. a bool is not accepted where an int is expected
        wanted_type = type(_config[name])
        if type(value) is not wanted_type:
            raise ValueError("{0!r} must be a {1}".format(name, wanted_type.__name__))

        allowed = _config_valid_values.get(name)
        if (allowed is not None) and (value not in allowed):
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Wrap ``func`` with address validation and debugger setup.

    The returned ``debug(address, **kwargs)`` accepts a port or a
    ``(host, port)`` pair (a bare port is paired with ``"127.0.0.1"``),
    validates the port, makes sure logging is started, applies the Qt setting
    from ``_config`` and then calls ``func(address, settrace_kwargs, **kwargs)``.

    Raises ``ValueError`` for a port that is not int-like or is outside
    ``0 <= port < 2**16``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a pair: treat the argument as a bare port on localhost
            port = address
            address = ("127.0.0.1", port)
        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if port < 0 or port >= 2**16:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )

        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_path,),
                dont_trace_end_patterns=(str("debugpy_launcher.py"),),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.protocol.tests.test_highlevelgraph
github-actions / Unit Test Results
All 8 runs failed: test_dataframe_annotations (distributed.protocol.tests.test_highlevelgraph)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
    # FIXME import from typing (needs Python >=3.10)
    from typing_extensions import ParamSpec

    # Circular imports: needed for annotations only
    from distributed.client import Client
    from distributed.nanny import Nanny

    P = ParamSpec("P")
    T = TypeVar("T")

logger = logging.getLogger(__name__)

# Read from the dask config at this point in the module
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

DEFAULT_EXTENSIONS: dict[str, type] = dict(
    pubsub=PubSubWorkerExtension,
    spans=SpansWorkerExtension,
)

DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message shape for a task that finished (``op == "task-finished"``)."""

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    # Timing and thread fields
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message shape for a task that erred (``op == "task-erred"``)."""

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    # Timing and thread fields
    start: float
    stop: float
    thread: int
    # The exception object carried alongside the message fields
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Response shape whose ``status`` is ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Response shape whose ``status`` is ``"OK"``, with the data attached."""

    status: Literal["OK"]
    # Mapping from task key to its value
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Coroutine functions get an async wrapper that awaits the forced close;
    other callables get a synchronous wrapper that schedules the forced close
    on the worker's event loop. In both cases the original exception is
    re-raised, and the failure is logged unless the worker is already closed
    or closing.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    shutting_down = (Status.closed, Status.closing)

    if not iscoroutinefunction(method):

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status not in shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                # Cannot await here: let the event loop run the forced close
                self.loop.add_callback(_force_close, self, reason)
                raise

    else:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status not in shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                await _force_close(self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process
    """
    close_timeout = 30
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            close_timeout,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        in_unit_test = bool(Scheduler._instances)
        if in_unit_test:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…oing_count_total
@property
def outgoing_current_count(self):
    """Deprecated alias of ``transfer_outgoing_count``.

    Emits a ``DeprecationWarning`` attributed to the caller and returns the
    value of ``self.transfer_outgoing_count``.
    """
    message = (
        "The `Worker.outgoing_current_count` attribute has been renamed to "
        "`Worker.transfer_outgoing_count`"
    )
    # stacklevel=2 attributes the warning to the code reading the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count
@property
def outgoing_transfer_log(self):
    """Deprecated alias of ``transfer_outgoing_log``.

    Emits a ``DeprecationWarning`` attributed to the caller and returns
    ``self.transfer_outgoing_log``.
    """
    message = (
        "The `Worker.outgoing_transfer_log` attribute has been renamed to "
        "`Worker.transfer_outgoing_log`"
    )
    # stacklevel=2 attributes the warning to the code reading the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_log
@property
def total_in_connections(self):
    """Deprecated alias of ``transfer_outgoing_count_limit``.

    Emits a ``DeprecationWarning`` attributed to the caller and returns
    ``self.transfer_outgoing_count_limit``.
    """
    message = (
        "The `Worker.total_in_connections` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_limit`"
    )
    # stacklevel=2 attributes the warning to the code reading the attribute.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker on which the current task is running.

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    # ContextVar.get() raises LookupError when the variable was never set;
    # swallow it here and report the condition as a plain ValueError below.
    with suppress(LookupError):
        return _worker_cvar.get()
    raise ValueError("No worker found")
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Return a client usable from within a task.

    Candidates are tried in order: the client of the worker running the
    current task, then the current global client, then a freshly created
    ``Client`` for ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a matching worker/global client exists nor an address
        was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task.
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        if not address or worker.scheduler.address == address:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the current global client, if it targets the same scheduler.
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None
    if current and (not address or current.scheduler.address == address):
        return current

    # Last resort: open a new connection, which requires an explicit address.
    if not address:
        raise ValueError("No global client found and no address provided")
    return Client(address, timeout=timeout)
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The calling thread leaves the pool, then a ``SecedeEvent`` carrying the
    task key and the time elapsed since ``thread_state.start_time`` is
    scheduled on the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its loop rather than calling
    # handle_stimulus directly from this thread.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``.

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: it sends a ``get_data``
    request and, when the reply status is ``"OK"``, writes ``"OK"`` back on the
    same comm. The comm is returned to the pool in all cases.

    ``serializers`` / ``deserializers`` default to those of ``rpc``.

    Raises
    ------
    ValueError
        If the reply has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    conn = await rpc.connect(worker)
    conn.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            conn,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = reply["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt.
            await conn.write("OK")
        return reply
    finally:
        rpc.reuse(worker, conn)
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Serialize ``func`` to bytes, memoizing the result in ``cache_dumps``.

    Cache reads and writes happen under ``_cache_lock``. Only payloads
    shorter than 100000 bytes are stored; unhashable functions bypass the
    cache entirely.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
    except KeyError:
        # Cache miss: serialize now and remember the result if it is small.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
# Run ``task`` through ``_run_task_simple`` on the calling thread. While the
# task runs, the thread is listed in ``active_threads`` under ``key`` and the
# worker taken from ``execution_state["worker"]`` is exposed via ``_worker_cvar``.
def _run_task(
task: GraphNode,
data: dict,
execution_state: dict,
key: Key,
active_threads: dict,
active_threads_lock: threading.Lock,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# Register the current thread id -> key under the lock guarding the mapping.
ident = threading.get_ident()
with active_threads_lock:
active_threads[ident] = key
with set_thread_state(
start_time=time(),
execution_state=execution_state,
key=key,
):
# Publish the worker for the duration of the call; the token returned by
# ``set`` is what ``reset`` needs to restore the previous value.
token = _worker_cvar.set(execution_state["worker"])
try:
msg = _run_task_simple(task, data, time_delay)
finally:
_worker_cvar.reset(token)
# NOTE(review): indentation was lost in this listing; presumably this
# deregistration runs after the ``set_thread_state`` block exits -- confirm.
with active_threads_lock:
del active_threads[ident]
return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_shuffle.py:68: in shuffle_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct ``str``-based type used to annotate shuffle identifiers.
ShuffleId = NewType("ShuffleId", str)
# Index expressed as a variable-length tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# Type variables; ``_T_partition_id`` and ``_T_partition_type`` parametrize
# ``ShuffleRun``, ``_T`` is the return type of functions passed to ``offload``.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
# Message type extending ``OKMessage`` with a ``run_spec`` entry that holds a
# ``ShuffleRunSpec``, either directly or wrapped in ``ToPickle``.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
# Store the run's identifiers and collaborators, create the shard buffers
# (disk or in-memory, plus the comm buffer), initialise bookkeeping state and
# read the retry settings from the dask config.
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# ``disk`` selects a DiskShardsBuffer rooted at ``directory``; otherwise
# shards are kept in a MemoryShardsBuffer.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Comm buffer limits come from the ``distributed.p2p.comm`` config section.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
# Wall-clock creation time; reported by ``heartbeat`` under "start".
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry policy passed to ``retry`` by ``send``.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
# Hash by ``run_id`` only, so two runs with the same ``run_id`` hash alike.
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise scalar labels to a 1-tuple so they can be spliced below.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            # Drop the redundant "p2p-" prefix; the key already starts with "p2p".
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Notify the scheduler that this run reached the shuffle barrier.

    ``consistent`` is reported as True only when every entry of ``run_ids``
    equals this run's ``run_id``. Returns ``self.run_id``.

    Raises whatever ``raise_if_closed`` raises when the run is closed.
    """
    self.raise_if_closed()
    is_consistent = all(peer_run_id == self.run_id for peer_run_id in run_ids)
    barrier_kwargs = {
        "id": self.id,
        "run_id": self.run_id,
        "consistent": is_consistent,
    }
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(**barrier_kwargs)
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to the worker at ``address`` in a single attempt.

    Calls the peer's ``shuffle_receive`` handler with the serialized shards
    together with this run's shuffle id and run id, and returns its reply.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the run's retry settings.

    When the mean shard size is below 65536, the whole list is pickled into a
    single bytes blob before sending; otherwise the shards are sent as-is.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    use_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if use_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is timed under the ``"offload"`` context meter. Raises whatever
    ``raise_if_closed`` raises when the run is closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
# Forward ``data`` to the comm buffer, after checking that the run is still open.
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by string-encoded indices.

    Each index tuple is turned into a key by joining its elements with
    ``"_"`` (e.g. ``(1, 2)`` becomes ``"1_2"``).
    """
    self.raise_if_closed()
    keyed: dict[str, Any] = {}
    for index, shards in data.items():
        keyed["_".join(map(str, index))] = shards
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
# Set ``transferred``, flush the comm buffer, and re-raise any error the comm
# buffer reports after recording it on the run.
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
# Keep the failure so ``raise_if_closed`` can re-raise it once the run is closed.
self._exception = e
raise
# Flush the comm buffer, after checking that the run is still open.
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
# Flush the disk buffer, after checking that the run is still open.
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
# Close the comm buffer and then the disk buffer, and set ``_closed_event``
# when done. A call made after ``closed`` is already set only waits for that
# event instead of closing again.
async def close(self) -> None:
if self.closed: # pragma: no cover
# Another call already started closing: wait until it sets the event below.
await self._closed_event.wait()
return
# Flip the flag before the first await so later callers take the branch above.
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored for index ``id``.

    The lookup key is the elements of ``id`` joined with ``"_"``, matching
    the encoding used by ``_write_to_disk``.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``_receive``.

    ``data`` is either the list of shards or the pickled blob produced by
    ``send`` for small shards. Returns ``{"status": "OK"}`` on success; a
    ``P2PConsistencyError`` is converted into an error message instead of
    being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm;
            # presumably peers are trusted cluster members -- confirm.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:46151', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:43399', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:46731', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_dataframe_annotations(c, s, a, b):
retries = 5
plugin = ExampleAnnotationPlugin(retries=retries)
s.add_plugin(plugin)
assert plugin in s.plugins.values()
df = dd.from_pandas(
pd.DataFrame(
{"a": np.arange(10, dtype=int), "b": np.arange(10, 0, -1, dtype=float)}
),
npartitions=5,
)
df = df.shuffle("a", max_branch=2)
acol = df["a"]
bcol = df["b"]
ctx = pytest.warns(
UserWarning, match="Annotations will be ignored when using query-planning"
)
with dask.annotate(retries=retries), ctx:
df = acol + bcol
with dask.config.set(optimization__fuse__active=False):
> rdf = await c.compute(df)
distributed/protocol/tests/test_highlevelgraph.py:184:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local state holder for this module.
_tls = threading.local()

# TODO: "gevent", if possible.
# Current debug configuration. configure() validates new values against the
# type of the defaults below before overwriting them; the entries are read
# again when debugging is started.
_config = {
    "qt": "none",
    "subProcess": True,
    "python": sys.executable,
    "pythonEnv": {},
}

_config_valid_values = {
    # If property is not listed here, any value is considered valid, so long as
    # its type matches that of the default value in _config.
    "qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}

# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it unchanged to ``pydevd.settrace``.

    ``notify_stdin`` is set to ``False`` unless the caller supplied a value.
    Exceptions raised by ``pydevd.settrace`` propagate to the caller.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)

    # debugpy does not act on the stdin notification, so turn it off by default.
    kwargs.setdefault("notify_stdin", False)

    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Start logging to ``log.log_dir`` on the first call; later calls are no-ops.

    The "already started" flag is kept on the function object itself as
    ``ensure_logging.ensured``.
    """
    if not ensure_logging.ensured:
        # Flip the flag before doing any work so the setup runs at most once.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        target_dir = log.log_dir
        if target_dir is not None:
            pydevd.log_to(target_dir + "/debugpy.pydevd.log")


ensure_logging.ensured = False
def log_to(path):
    """Select where log output goes; only allowed before logging has begun.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on
    ``log.stderr``; any other value is stored as ``log.log_dir``.

    Raises ``RuntimeError`` if ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply configuration properties to ``_config``.

    ``properties`` (if given) is copied into a dict and then updated with
    ``kwargs``, so keyword arguments win on conflict. Each property must
    already exist in ``_config``, must have exactly the same type as the
    current value, and, when listed in ``_config_valid_values``, must be one
    of the allowed values. Properties are applied one at a time, so entries
    processed before an invalid one remain applied.

    Raises ``ValueError`` for an unknown name, a wrong type or a value that
    is not allowed.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate ``func`` with address validation and common debug setup.

    The returned ``debug(address, **kwargs)`` wrapper normalizes ``address``,
    makes sure logging is started, builds the ``settrace_kwargs`` dict from
    ``_config`` and then calls ``func(address, settrace_kwargs, **kwargs)``.

    The wrapper raises ``ValueError`` when ``address`` is neither a port nor
    a ``(host, port)`` pair, or when the port is outside ``0 <= port < 2**16``.
    Exceptions raised by ``func`` are handed to ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        # Accept either a (host, port) pair or a bare port; a bare port is
        # paired with the loopback host.
        try:
            _, port = address
        except Exception:
            port = address
            address = ("127.0.0.1", port)
        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        # Qt support is only enabled when configured to something other than "none".
        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }

        if hide_debugpy_internals():
            # Exclude debugpy's own files from tracing.
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
            settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            # Logged at "info" level before being re-raised (per the helper's name).
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_graph
github-actions / Unit Test Results
All 8 runs failed: test_basic_state (distributed.shuffle.tests.test_graph)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Value of the "distributed.admin.pdb-on-err" config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

# Extension classes keyed by extension name.
# NOTE(review): presumably the default for the ``extensions`` argument of
# ``Worker.__init__`` -- confirm where this mapping is consumed.
DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}

# Default value of the ``metrics`` argument of ``Worker.__init__``.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

# Default value of the ``startup_information`` argument of ``Worker.__init__``.
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

# Statuses grouped under "any running": running, paused and closing gracefully.
WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Typed message, extending ``OKMessage``, whose ``op`` is ``"task-finished"``."""

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    # presumably timestamps bracketing the task run -- TODO confirm
    start: float
    stop: float
    # presumably the id of the thread that ran the task -- TODO confirm
    thread: int
class RunTaskFailure(ErrorMessage):
    """Typed message, extending ``ErrorMessage``, whose ``op`` is ``"task-erred"``."""

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    # presumably timestamps bracketing the task run -- TODO confirm
    start: float
    stop: float
    # presumably the id of the thread that ran the task -- TODO confirm
    thread: int
    # NOTE(review): ``Exception`` is a subclass of ``BaseException``, so this
    # union is equivalent to plain ``BaseException``.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed message carrying only ``status == "busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed message with ``status == "OK"`` and a ``data`` mapping keyed by ``Key``."""

    status: Literal["OK"]
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both plain and coroutine methods. When the wrapped method
    raises an ``Exception`` and the worker's status is neither ``closed`` nor
    ``closing``, the error is recorded with ``log_event`` under
    ``"worker-fail-hard"``, logged, and ``_force_close`` is triggered:
    awaited directly for coroutine methods, scheduled on ``self.loop`` for
    plain ones. The original exception is always re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    if not iscoroutinefunction(method):

        @wraps(method)
        def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                still_up = self.status not in (Status.closed, Status.closing)
                if still_up:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                    self.loop.add_callback(_force_close, self, reason)
                raise

        return sync_wrapper  # type: ignore

    @wraps(method)
    async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return await method(self, *args, **kwargs)  # type: ignore
        except Exception as exc:
            still_up = self.status not in (Status.closed, Status.closing)
            if still_up:
                self.log_event("worker-fail-hard", error_message(exc))
                logger.exception(exc)
                await _force_close(self, reason)
            raise

    return async_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    Parameters
    ----------
    self
        The worker to close; only its ``close`` method is used here.
    reason
        Passed through to ``self.close`` as ``reason``.
    """
    try:
        # Bound the close attempt: 30 is the timeout handed to ``wait_for``
        # (presumably seconds -- confirm against ``wait_for``).
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            30,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        # Imported here rather than at module level
        # (presumably to avoid a circular import -- confirm).
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
… )
return self.transfer_outgoing_log
@property
def total_in_connections(self):
    """Deprecated alias of ``transfer_outgoing_count_limit``.

    Emits a ``DeprecationWarning`` attributed to the caller and returns
    ``self.transfer_outgoing_count_limit``.
    """
    message = (
        "The `Worker.total_in_connections` attribute has been renamed to "
        "`Worker.transfer_outgoing_count_limit`"
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker stored in the current context.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the internal LookupError from the caller's traceback.
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Obtain a client from inside a task.

    Lookup order: the client of the worker running this task, then the
    current global client, then a brand-new ``Client(address)``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a matching worker/global client exists nor an address
        was provided.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task.
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        if not address or worker.scheduler.address == address:
            return worker._get_client(timeout=timeout)

    # Imported here rather than at module level (circular import).
    from distributed.client import Client

    # Second choice: the current global client, if it targets the same scheduler.
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None

    if current and (not address or current.scheduler.address == address):
        return current
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The time elapsed since ``thread_state.start_time`` is reported to the
    worker as the compute duration of a ``SecedeEvent``.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its event loop.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: a ``get_data``
    request is sent and, when the reply's status is ``"OK"``, an ``"OK"``
    acknowledgement is written back on the same comm.

    ``serializers`` and ``deserializers`` default to those of ``rpc``.
    The comm is always handed back to the pool, even on failure.

    Raises
    ------
    ValueError
        If the reply has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    conn = await rpc.connect(worker)
    conn.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            conn,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = reply["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if status == "OK":
            # Second half of the handshake: confirm receipt to the sender.
            await conn.write("OK")
        return reply
    finally:
        rpc.reuse(worker, conn)
# Cache (capacity 100) mapping a function to its pickled bytes; filled and read
# by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Lock held by ``dumps_function`` around every read/write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when one exists.

    Payloads shorter than 100000 bytes are stored in ``cache_dumps``.
    Functions that cannot be used as a cache key are pickled every time.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: serialize now and remember the payload if it is small.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a task on the current thread and collect information about it.

    While the task runs, the current thread id is registered in
    ``active_threads`` (mapped to ``key``), the thread state is populated and
    ``_worker_cvar`` points at ``execution_state["worker"]``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    thread_ctx = set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    )
    with thread_ctx:
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            outcome = _run_task_simple(task, data, time_delay)
        finally:
            # Always restore the previous worker context.
            _worker_cvar.reset(cvar_token)

    with active_threads_lock:
        del active_threads[thread_id]
    return outcome
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_shuffle.py:68: in shuffle_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
# Typed message extending OKMessage with a ``run_spec`` entry, which is either
# a ShuffleRunSpec or a ShuffleRunSpec wrapped in ToPickle.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
# Constructor. Stores the run's identity (id, run_id, span_id, local_address)
# and its collaborators (executor, rpc, digest_metric, scheduler), builds the
# disk-or-memory shard buffer and the comm shard buffer, and reads the retry
# settings from the dask config.
# ``directory`` and ``memory_limiter_disk`` are only passed to DiskShardsBuffer
# (when ``disk`` is true); ``memory_limiter_comms`` goes to CommShardsBuffer.
# NOTE(review): indentation is not visible in this listing, so the exact extent
# of the ``with`` blocks below should be confirmed against the real file.
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
# Metrics recorded while building the disk/memory buffer are labelled
# "background-disk".
with self._capture_metrics("background-disk"):
# ``disk`` selects an on-disk buffer or an in-memory one.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
# Metrics recorded while building the comm buffer are labelled
# "background-comms".
with self._capture_metrics("background-comms"):
# Message size limit and concurrency come from the p2p comm config.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Progress/bookkeeping state, all starting empty.
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry policy used by ``send``; delays are parsed with a default unit of seconds.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debugging representation listing the run's key attributes."""
    details = ", ".join(
        (
            f"id={self.id!r}",
            f"run_id={self.run_id!r}",
            f"local_address={self.local_address!r}",
            f"closed={self.closed!r}",
            f"transferred={self.transferred!r}",
        )
    )
    return f"<{self.__class__.__name__}: {details}>"

def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    return "{}<{}[{}]> on {}".format(
        self.__class__.__name__, self.id, self.run_id, self.local_address
    )

def __hash__(self) -> int:
    """Hash the run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` is stripped from the first label element before the
    metric is forwarded to ``self.digest_metric``.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise scalar labels to a 1-tuple so they can be unpacked below.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Notify the scheduler that the shuffle barrier was reached.

    ``consistent`` is reported as true only if every entry of ``run_ids``
    equals this run's ``run_id``.  Returns ``self.run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    mismatch = any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the worker at ``address`` in a single attempt.

    Raises if the run is already closed; otherwise returns the peer's reply
    to ``shuffle_receive``.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob first; ``receive`` unpickles it on the other side.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small = _mean_shard_size(shards) < 65536
    payload: list | bytes = pickle.dumps(shards) if small else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, payload)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    Raises if the run is already closed.  The call is metered under the
    ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot of this run.

    The result has three entries: ``"disk"`` (disk buffer heartbeat),
    ``"comm"`` (comm buffer heartbeat plus a ``"read"`` entry holding
    ``self.total_recvd``) and ``"start"`` (``self.start_time``).
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer; raises first if the run is closed.

    Each index key is flattened into a string by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    The stored ``_exception`` is re-raised when there is one, else a
    ``ShuffleClosedError`` is raised.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush the comm buffer.

    Any exception reported by the comm buffer afterwards is stored in
    ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    comm_buffer = self._comm_buffer
    try:
        comm_buffer.raise_on_exception()
    except Exception as exc:
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the run: mark it closed, then close the comm and disk buffers.

    If the run is already marked closed, wait until ``_closed_event`` is set
    instead of closing again.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
    else:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        # Wake up any concurrent callers waiting in the branch above.
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the entry stored for index ``id`` from the disk buffer.

    The index is flattened into a string by joining its elements with ``"_"``.
    Raises if the run is already closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Handle shards sent by a peer and report the outcome as a message.

    Returns ``{"status": "OK"}`` on success.  A ``P2PConsistencyError`` is
    converted into an error message instead of propagating.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the network;
            # it is only safe if peers are trusted.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45093', workers: 0, cores: 0, tasks: 0>
workers = (<Worker 'tcp://127.0.0.1:43297', name: 0, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>, <W... 0>, <Worker 'tcp://127.0.0.1:43647', name: 3, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
df = Dask DataFrame Structure:
name id x y
npartitions=12 ... ... ...
Dask Name: assign, 5 expressions
Expr=Assign(frame=ArrowStringConversion(frame=Timeseries(8830e11)))
shuffled = Dask DataFrame Structure:
name id x y
npartitions=12 ... ...
... ... ... ...
Dask Name: rearrangebycolumn, 6 expressions
Expr=Shuffle(39a0f25)
@gen_cluster([("", 2)] * 4, client=True)
async def test_basic_state(c, s, *workers):
df = dd.demo.make_timeseries(freq="15D", partition_freq="30D")
df["name"] = df["name"].astype("string[python]")
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = df.shuffle("id")
exts = [w.extensions["shuffle"] for w in workers]
for ext in exts:
assert not ext.shuffle_runs._active_runs
f = c.compute(shuffled)
# TODO this is a bad/pointless test. the `f.done()` is necessary in case the shuffle is really fast.
# To test state more thoroughly, we'd need a way to 'stop the world' at various stages. Like have the
# scheduler pause everything when the barrier is reached. Not sure yet how to implement that.
while (
not all(len(ext.shuffle_runs._active_runs) == 1 for ext in exts)
and not f.done()
):
await asyncio.sleep(0.1)
> await f
distributed/shuffle/tests/test_graph.py:96:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to False unless the caller supplies it.
    Exceptions from ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensured`` attribute on this function records whether setup already
    ran, so later calls return immediately.
    """
    if not ensure_logging.ensured:
        # Flip the flag first, before any of the setup calls below run.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        log_dir = log.log_dir
        if log_dir is not None:
            pydevd.log_to(log_dir + "/debugpy.pydevd.log")


ensure_logging.ensured = False
def log_to(path):
    """Select where logs go: ``sys.stderr`` itself, or a log directory path.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
    else:
        # Enable every log level on the stderr sink.
        log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Update entries of the module-level ``_config``.

    ``properties`` (optional) and ``kwargs`` are merged, with ``kwargs``
    taking precedence.  Each entry is validated and applied in turn, so
    entries preceding an invalid one are already applied when the error
    is raised.

    Raises
    ------
    ValueError
        For an unknown property name, a value whose type differs from the
        current value's type, or a value outside the allowed list.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    requested = kwargs if properties is None else dict(properties, **kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # The type of the current value defines the accepted type.
        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, expected_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
# Decorator for the public entry points that start debugging.
# The returned ``debug(address, **kwargs)`` wrapper normalises and validates the
# address, makes sure logging is set up, builds the ``settrace_kwargs`` dict and
# then calls ``func(address, settrace_kwargs, **kwargs)``.
# NOTE(review): indentation is not visible in this listing; confirm the exact
# scope of the fallback/validation statements against the real file.
def _starts_debugging(func):
def debug(address, **kwargs):
# Accept either a (host, port) pair or a bare port; a bare port is paired
# with the host "127.0.0.1".
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
# Valid ports are 0 <= port < 2**16.
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
# Qt support is enabled only when the "qt" setting is not "none".
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
# When debugpy internals are hidden, files under the debugpy package directory
# and files ending in "debugpy_launcher.py" are excluded from tracing.
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
# Failures in ``func`` are passed to log.reraise_exception at "info" level.
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_graph
github-actions / Unit Test Results
All 8 runs failed: test_multiple_linear (distributed.shuffle.tests.test_graph)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 5s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 6s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 6s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 5s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 6s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 5s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: Can't listen for client connections: [Errno 98] Address already in use
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by name; contains the pubsub and spans extensions.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Name -> callable(Worker) registries; both start out empty here.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Statuses grouped as "any running": running, paused and closing_gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
# Typed message extending OKMessage whose ``op`` is the literal "task-finished".
# Declares the result/nbytes/type, start/stop and thread fields listed below.
class RunTaskSuccess(OKMessage):
op: Literal["task-finished"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
# Typed message extending ErrorMessage whose ``op`` is the literal "task-erred".
# Declares the same fields as RunTaskSuccess plus ``actual_exception``.
class RunTaskFailure(ErrorMessage):
op: Literal["task-erred"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
actual_exception: BaseException | Exception
# Typed dict for a reply whose only entry is ``status`` with the literal "busy".
class GetDataBusy(TypedDict):
status: Literal["busy"]
# Typed dict for a reply with ``status`` "OK" and a ``data`` mapping from task
# key to object.
class GetDataSuccess(TypedDict):
status: Literal["OK"]
data: dict[Key, object]
# Decorator factory: wraps ``method`` (sync or async) so that an Exception
# raised by it is logged, leads to ``_force_close`` being invoked with a reason
# derived from the method name, and is then re-raised to the caller.
# NOTE(review): indentation is not visible in this listing, so which statements
# sit inside the ``if self.status ...`` check must be confirmed in the real file.
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to close the worker if this method encounters an exception.
"""
# Reason string handed to _force_close.
reason = f"worker-{method.__name__}-fail-hard"
# Coroutine functions get an async wrapper that awaits _force_close directly.
if iscoroutinefunction(method):
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
try:
return await method(self, *args, **kwargs) # type: ignore
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
await _force_close(self, reason)
raise
# Plain functions cannot await, so _force_close is scheduled on the worker's loop.
else:
@wraps(method)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
try:
return method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
self.loop.add_callback(_force_close, self, reason)
raise
return wrapper # type: ignore
# Helper for ``fail_hard``: ``self`` is the object being closed (it must provide
# ``close``) and ``reason`` is forwarded to ``self.close``.
# NOTE(review): indentation is not visible in this listing; confirm that the
# bare ``raise`` after the Scheduler check belongs to that ``if``.
async def _force_close(self, reason: str):
"""
Used with the fail_hard decorator defined above
1. Wait for a worker to close
2. If it doesn't, log and kill the process
"""
# Try to close, bounded by a timeout of 30 passed to wait_for.
try:
await wait_for(
self.close(nanny=False, executor_wait=False, reason=reason),
30,
)
# Interpreter-exit signals are never swallowed.
except (KeyboardInterrupt, SystemExit): # pragma: nocover
raise
except BaseException: # pragma: nocover
# Worker is in a very broken state if closing fails. We need to shut down
# immediately, to ensure things don't get even worse and this worker potentially
# deadlocks the cluster.
# Imported locally, inside the error path.
from distributed import Scheduler
if Scheduler._instances:
# We're likely in a unit test. Don't kill the whole test suite!
raise
logger.critical(
"Error trying close worker in response to broken internal state. "
"Forcibly exiting worker NOW",
exc_info=True,
)
# use `os._exit` instead of `sys.exit` because of uncertainty
# around propagating `SystemExit` from asyncio callbacks
os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.…client
get_worker
"""
worker = get_worker()
tpe_secede() # have this thread secede from the thread pool
duration = time() - thread_state.start_time
worker.loop.add_callback(
worker.handle_stimulus,
SecedeEvent(
key=thread_state.key,
compute_duration=duration,
stimulus_id=f"secede-{time()}",
),
)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: it sends a ``get_data``
    request and, when the response status is ``"OK"``, writes ``"OK"`` back on the
    same comm.

    ``serializers`` and ``deserializers`` default to the ones configured on ``rpc``.
    The comm is always returned to ``rpc`` through ``rpc.reuse``, also on error.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        # Second step of the handshake: acknowledge a successful delivery.
        if status == "OK":
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of serialized functions keyed by the function object, created with
# ``maxsize=100``. Filled and read by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Held by ``dumps_function`` around every read and write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Dump a function to bytes, cache functions

    A previously stored result is returned straight from ``cache_dumps``.
    On a cache miss the function is pickled and the result is stored only when it
    is shorter than 100000 bytes.  Functions that cannot be used as a cache key
    (lookup raises ``TypeError``) are pickled on every call and never cached.
    All cache accesses happen while holding ``_cache_lock``.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: serialize now and remember the result if it is small enough.
        pickled = pickle.dumps(func)
        if len(pickled) < 100000:
            with _cache_lock:
                cache_dumps[func] = pickled
        return pickled
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    Around the call to ``_run_task_simple`` this function

    * records ``key`` in ``active_threads`` under the current thread identifier
      and removes the entry again afterwards (both under ``active_threads_lock``);
    * installs the thread state (start time, ``execution_state``, ``key``);
    * sets ``_worker_cvar`` to ``execution_state["worker"]`` and always resets it.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    with set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    ):
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            outcome = _run_task_simple(task, data, time_delay)
        finally:
            _worker_cvar.reset(cvar_token)

    with active_threads_lock:
        del active_threads[thread_id]
    return outcome
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_shuffle.py:68: in shuffle_unpack
return get_worker_plugin().get_output_partition(
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Dedicated static type for shuffle identifiers, built on ``str``.
ShuffleId = NewType("ShuffleId", str)
# Alias for a variable-length tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# ``_T_partition_id`` and ``_T_partition_type`` parametrize ``ShuffleRun``;
# ``_T`` is the return type of the callable handed to ``ShuffleRun.offload``.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` entry."""

    # Either the run spec itself or the run spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its shard buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` writing below
    ``directory`` when true, otherwise a ``MemoryShardsBuffer``.  The outgoing
    ``CommShardsBuffer`` and the retry settings are configured from the
    ``distributed.p2p.comm.*`` dask config keys.
    """
    # Identity and collaborators
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
            if disk:
                disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
            self._disk_buffer = disk_buffer

        with self._capture_metrics("background-comms"):
            message_bytes_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=comm_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Progress bookkeeping
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by ``send``
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    min_delay = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(min_delay, default="s")
    max_delay = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(max_delay, default="s")
def __repr__(self) -> str:
    """Return a debug representation listing id, run id, address and state flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )
def __str__(self) -> str:
    """Return a short label of the form ``Class<id[run_id]> on address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash the run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` on the first label element is dropped before the metric
    name is assembled.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str):
            # No-op for labels that do not start with the prefix.
            parts = (head.removeprefix("p2p-"), *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    ``consistent`` is true when every entry of ``run_ids`` equals ``self.run_id``
    (also for an empty sequence).  Raises if the run is already closed.
    """
    self.raise_if_closed()
    consistent = True
    for other_run_id in run_ids:
        if other_run_id != self.run_id:
            consistent = False
            break
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make one ``shuffle_receive`` call on the peer at ``address``.

    Raises if the run is already closed.  Retrying is handled by ``send``.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the peer at ``address``, retrying per the retry settings.

    When ``_mean_shard_size(shards)`` is below 65536 the shards are pickled into
    a single bytes blob first; ``receive`` unpickles it on the other side.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    send_as_blob = _mean_shard_size(shards) < 65536
    payload: list | bytes = pickle.dumps(shards) if send_as_blob else shards

    def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry.
        return self._send(address, payload)

    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label.  Raises if the run is
    already closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        result = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return result
def heartbeat(self) -> dict[str, Any]:
    """Return the heartbeat of both buffers plus the run's start time.

    The comm buffer's heartbeat dict gets an extra ``"read"`` entry holding
    ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {
        "disk": disk_stats,
        "comm": comm_stats,
        "start": self.start_time,
    }
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer; raises if the run is closed.

    Each index tuple is turned into a string key by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    keyed = {}
    for index, shards in data.items():
        keyed["_".join(map(str, index))] = shards
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise when the run is closed; do nothing otherwise.

    Raises
    ------
    Exception
        The stored ``self._exception`` if one is set.
    ShuffleClosedError
        If the run is closed and no exception was stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    An exception raised by the comm buffer's ``raise_on_exception`` is stored in
    ``self._exception`` and re-raised.  Raises if the run is already closed.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so that later ``raise_if_closed`` calls can surface it.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer, then the disk buffer, and signal completion.

    A call on an already closed run only waits for ``self._closed_event``.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the entry for ``id`` from the disk buffer; raises if the run is closed.

    The lookup key is the elements of ``id`` joined with ``"_"``.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand received shards to ``self._receive`` and report the outcome.

    ``bytes`` input is unpickled first (see ``send``).  Returns
    ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is converted into
    an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from a peer; presumably
            # peers are trusted cluster members - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
client = <Client: 'tcp://127.0.0.1:38567' processes=2 threads=2, memory=31.23 GiB>
def test_multiple_linear(client):
df = dd.demo.make_timeseries(freq="15D", partition_freq="30D")
df["name"] = df["name"].astype("string[python]")
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
s1 = df.shuffle("id")
s1["x"] = s1["x"] + 1
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
s2 = s1.shuffle("x")
with dask.config.set({"dataframe.shuffle.method": "tasks"}):
expected = df.assign(x=lambda df: df.x + 1).shuffle("x")
# TODO eventually test for fusion between s1's unpacks, the `+1`, and s2's `transfer`s
> dd.utils.assert_eq(
s2,
expected,
scheduler=client,
)
distributed/shuffle/tests/test_graph.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/dataframe/utils.py:529: in assert_eq
a = _check_dask(
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/dataframe/utils.py:419: in _check_dask
result = dsk.compute(scheduler=scheduler)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/dataframe/dask_expr/_collection.py:489: in compute
return DaskMethodsMixin.compute(out, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/base.py:374: in compute
(result,) = compute(self, traverse=False, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/base.py:662: in compute
results = schedule(dsk, keys, **kwargs)
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage object for this module.
_tls = threading.local()

# TODO: "gevent", if possible.
# Current debug configuration. configure() checks new values against the types
# of these defaults and then overwrites the entries in place.
_config = {
    "qt": "none",
    "subProcess": True,
    "python": sys.executable,
    "pythonEnv": {},
}

# Allowed values per property, consulted by configure().
_config_valid_values = {
    # If property is not listed here, any value is considered valid, so long as
    # its type matches that of the default value in _config.
    "qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}

# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller supplies it.
    Exceptions raised by pydevd propagate to the caller unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    Only the first call does any work; later calls return immediately.
    """
    if not ensure_logging.ensured:
        # Flip the flag first so the setup below runs at most once
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Function attribute used as the run-once flag for ensure_logging().
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on the
    stderr log; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging()`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Update entries of the module-level ``_config``.

    ``properties`` (a mapping, optional) and ``kwargs`` are merged, with
    ``kwargs`` taking precedence. Each entry is validated and then stored
    immediately, so entries preceding an invalid one remain applied.

    Raises
    ------
    ValueError
        For an unknown property name, a value whose exact type differs from
        the current value's type, or a value outside the allowed list.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    requested = {} if properties is None else dict(properties)
    requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        wanted = type(_config[name])
        if type(value) is not wanted:
            raise ValueError("{0!r} must be a {1}".format(name, wanted.__name__))

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
def debug(address, **kwargs):
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
raise RuntimeError("debugpy.listen() has already been called on this process")
if in_process_debug_adapter:
host, port = address
log.info("Listening: pydevd without debugpy adapter: {0}:{1}", host, port)
settrace_kwargs["patch_multiprocessing"] = False
_settrace(
host=host,
port=port,
wait_for_ready_to_run=False,
block_until_connected=False,
**settrace_kwargs
)
return
import subprocess
server_access_token = codecs.encode(os.urandom(32), "hex").decode("ascii")
try:
endpoints_listener = sockets.create_server("127.0.0.1", 0, timeout=30)
except Exception as exc:
log.swallow_exception("Can't listen for adapter endpoints:")
raise RuntimeError("can't listen for adapter endpoints: " + str(exc))
try:
endpoints_host, endpoints_port = endpoints_listener.getsockname()
log.info(
"Waiting for adapter endpoints on {0}:{1}...",
endpoints_host,
endpoints_port,
)
host, port = address
adapter_args = [
_config.get("python", sys.executable),
os.path.dirname(adapter.__file__),
"--for-server",
str(endpoints_port),
"--host",
host,
"--port",
str(port),
"--server-access-token",
server_access_token,
]
if log.log_dir is not None:
adapter_args += ["--log-dir", log.log_dir]
log.info("debugpy.listen() spawning adapter: {0}", json.repr(adapter_args))
# On Windows, detach the adapter from our console, if any, so that it doesn't
# receive Ctrl+C from it, and doesn't keep it open once we exit.
creationflags = 0
if sys.platform == "win32":
creationflags |= 0x08000000 # CREATE_NO_WINDOW
creationflags |= 0x00000200 # CREATE_NEW_PROCESS_GROUP
# On embedded applications, environment variables might not contain
# Python environment settings.
python_env = _config.get("pythonEnv")
if not bool(python_env):
python_env = None
# Adapter will outlive this process, so we shouldn't wait for it. However, we
# need to ensure that the Popen instance for it doesn't get garbage-collected
# by holding a reference to it in a non-local variable, to avoid triggering
# https://bugs.python.org/issue37380.
try:
global _adapter_process
_adapter_process = subprocess.Popen(
adapter_args,
close_fds=True,
creationflags=creationflags,
env=python_env,
)
if os.name == "posix":
# It's going to fork again to daemonize, so we need to wait on it to
# clean it up properly.
_adapter_process.wait()
else:
# Suppress misleading warning about child process still being alive when
# this process exits (https://bugs.python.org/issue38890).
_adapter_process.returncode = 0
pydevd.add_dont_terminate_child_pid(_adapter_process.pid)
except Exception as exc:
log.swallow_exception("Error spawning debug adapter:", level="info")
raise RuntimeError("error spawning debug adapter: " + str(exc))
try:
sock, _ = endpoints_listener.accept()
try:
sock.settimeout(None)
sock_io = sock.makefile("rb", 0)
try:
endpoints = json.loads(sock_io.read().decode("utf-8"))
finally:
sock_io.close()
finally:
sockets.close_socket(sock)
except socket.timeout:
log.swallow_exception(
"Timed out waiting for adapter to connect:", level="info"
)
raise RuntimeError("timed out waiting for adapter to connect")
except Exception as exc:
log.swallow_exception("Error retrieving adapter endpoints:", level="info")
raise RuntimeError("error retrieving adapter endpoints: " + str(exc))
finally:
endpoints_listener.close()
log.info("Endpoints received from adapter: {0}", json.repr(endpoints))
if "error" in endpoints:
> raise RuntimeError(str(endpoints["error"]))
E RuntimeError: Can't listen for client connections: [Errno 98] Address already in use
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:258: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[idx-inner] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Value of the ``distributed.admin.pdb-on-err`` config key, read once when
# this module is imported.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

# Extension classes keyed by name.
# NOTE(review): presumably the fallback for ``Worker(extensions=None)`` -- confirm
DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}

# Default values for the ``metrics`` and ``startup_information`` parameters of
# ``Worker.__init__``; each maps a name to a callable taking the worker.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

# Statuses grouped under "any running": running, paused, closing gracefully.
WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Typed ``OKMessage`` whose ``op`` is ``"task-finished"``.

    NOTE(review): the meaning of the individual fields is inferred from their
    names only (result payload, size in bytes, result type, timing, thread
    id) -- confirm against the code that builds this message.
    """

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Typed ``ErrorMessage`` whose ``op`` is ``"task-erred"``.

    Carries the same fields as ``RunTaskSuccess`` plus ``actual_exception``.
    NOTE(review): field meanings are inferred from their names -- confirm
    against the code that builds this message.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed dict whose only key, ``status``, is the literal ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed dict with ``status == "OK"`` and a ``data`` mapping keyed by ``Key``."""

    status: Literal["OK"]
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both coroutine functions and plain functions. When the wrapped
    method raises an ``Exception``, the failure is logged (unless the worker
    status is already ``closed`` or ``closing``), ``_force_close`` is invoked
    with a reason derived from the method name, and the original exception is
    re-raised to the caller.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    if iscoroutinefunction(method):

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as e:
                # Skip the logging when the worker is already shutting down
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                # Coroutine variant: wait for the forced close before re-raising
                await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                # Skip the logging when the worker is already shutting down
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                # Synchronous variant: cannot await, so schedule the forced
                # close as a callback on the worker's loop
                self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    ``self`` is the worker to close; ``reason`` is forwarded to
    ``self.close``. The close is given 30 (passed to ``wait_for``) to finish.
    """
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            30,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        # Let interpreter-exit signals through untouched
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…tal_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the current Worker; read by get_worker(). It has no
# default, so reading it while unset raises LookupError.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
"""Get the worker currently running this task
Examples
--------
>>> def f():
... worker = get_worker() # The worker on which this task is running
... return worker.address
>>> future = client.submit(f) # doctest: +SKIP
>>> future.result() # doctest: +SKIP
'tcp://127.0.0.1:47373'
See Also
--------
get_client
worker_client
"""
try:
return _worker_cvar.get()
except LookupError:
raise ValueError("No worker found") from None
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    This client connects to the same scheduler to which the worker is connected.
    Lookup order: the client of the worker running the current task, then the
    current global client, then a brand new ``Client`` for ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a matching client nor an address is available.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task.
    try:
        current_worker = get_worker()
    except ValueError:  # could not find worker
        inside_task = False
    else:
        inside_task = True
    if inside_task and (
        not address or current_worker.scheduler.address == address
    ):
        return current_worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the process-wide current client.
    try:
        current_client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current_client = None

    if current_client and (
        not address or current_client.scheduler.address == address
    ):
        return current_client
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The time spent computing so far is reported to the worker through a
    ``SecedeEvent`` scheduled on the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    current_worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    current_worker.loop.add_callback(current_worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: after a reply whose
    status is ``"OK"`` it writes ``"OK"`` back on the same comm.

    The comm is always handed back to ``rpc`` for reuse, even on failure.

    Raises
    ------
    ValueError
        If the reply carries no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            delivered = reply["status"] == "OK"
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if delivered:
            # Second half of the handshake: acknowledge receipt.
            await comm.write("OK")
        return reply
    finally:
        rpc.reuse(worker, comm)
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)

_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Dump a function to bytes, cache functions

    Pickled payloads shorter than 100000 bytes are stored in ``cache_dumps``;
    access to the cache is guarded by ``_cache_lock``.  Unhashable callables
    are pickled on every call and never cached.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: serialize, and remember the result if it is small enough.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the current thread is registered in ``active_threads``
    under its thread ident, the thread state carries the start time, execution
    state and key, and ``_worker_cvar`` points at ``execution_state["worker"]``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Deregister the thread even if the task runner itself raises;
        # otherwise a stale ident -> key entry stays in ``active_threads``.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
ShuffleId = NewType("ShuffleId", str)
# An index expressed as a tuple of integers.
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of ``ShuffleRun``: the partition identifier and partition type.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, e.g. for functions passed to ``ShuffleRun.offload``.
_T = TypeVar("_T")
# ``OKMessage`` extended with a ``run_spec`` entry, which is either a
# ``ShuffleRunSpec`` or one wrapped in ``ToPickle``.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
# Set up one shuffle run: store identity and collaborators, create the disk
# (or in-memory) and comm shard buffers, and read the retry settings from the
# dask configuration.
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# ``disk`` selects where received shards are buffered: on disk under
# ``directory``, or in memory.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Outgoing comm limits come from the ``distributed.p2p.comm.*`` config.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry policy used by ``send``; the delays are parsed with seconds as the
# default unit.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debugging representation with identity and lifecycle flags."""
    cls_name = self.__class__.__name__
    details = (
        f"id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}"
    )
    return f"<{cls_name}: {details}>"
def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        # Drop a leading "p2p-" so the label is not prefixed twice.
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    ``consistent`` is true only when every entry of ``run_ids`` equals this
    run's own ``run_id``.
    """
    self.raise_if_closed()
    consistent = True
    for candidate in run_ids:
        if not (candidate == self.run_id):
            consistent = False
            break
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the worker at ``address``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    payload = to_serialize(shards)
    return await peer.shuffle_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings."""
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    send_as_blob = _mean_shard_size(shards) < 65536
    payload: list | bytes = pickle.dumps(shards) if send_as_blob else shards

    def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, payload)

    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot with ``disk``, ``comm`` and ``start`` entries.

    The comm buffer's heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by "_"-joined index strings."""
    self.raise_if_closed()
    by_name = {}
    for index, shard in data.items():
        by_name["_".join(map(str, index))] = shard
    await self._disk_buffer.write(by_name)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    A stored ``_exception`` takes precedence; without one a
    ``ShuffleClosedError`` is raised.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception surfaced by the comm buffer is stored on ``_exception``
    before being re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    comm_buffer = self._comm_buffer
    try:
        comm_buffer.raise_on_exception()
    except Exception as exc:
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers, then signal completion.

    A call made after ``closed`` is already set only waits for the first
    close to finish.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run, unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the entry for index ``id`` from the disk buffer.

    The buffer key is the index components joined with "_".
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and report the outcome as a message.

    Returns ``{"status": "OK"}`` on success.  A ``P2PConsistencyError`` is
    converted into an error message instead of propagating.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38725', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:45081', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:35653', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = 'idx', how = 'inner'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration. The defaults below also fix the expected type
# of each property; configure() validates against them and overwrites entries
# in place.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Per-property lists of accepted values, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Forward to ``pydevd.settrace`` after logging the call.

    ``notify_stdin`` defaults to ``False`` unless the caller passes it
    explicitly.  Exceptions raised by ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin notification is not acted upon in debugpy, so disable it.
    kwargs.setdefault("notify_stdin", False)
    # A bare ``except Exception: raise`` wrapper here would be a no-op, so the
    # call is made directly.
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensure_logging.ensured`` flag makes every call after the first a
    no-op.
    """
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd_log_path = log.log_dir + "/debugpy.pydevd.log"
            pydevd.log_to(pydevd_log_path)


ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has begun.

    Passing ``sys.stderr`` enables every log level on stderr; any other value
    is stored as the log directory.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    ``properties`` (a mapping, optional) and ``kwargs`` are merged, with
    ``kwargs`` taking precedence.  Each property must already exist in
    ``_config``, have exactly the type of its current value there, and, if
    listed in ``_config_valid_values``, be one of the accepted values.

    Properties are applied one at a time, so those processed before an
    invalid one remain applied when ``ValueError`` is raised.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )
        accepted = _config_valid_values.get(name)
        if accepted is not None and value not in accepted:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, accepted))
        _config[name] = value
# Decorator for the public entry points that start debugging. The returned
# ``debug(address, **kwargs)`` normalizes and validates ``address``, makes sure
# logging is set up, builds the keyword arguments for settrace, and then calls
# ``func(address, settrace_kwargs, **kwargs)``.
def _starts_debugging(func):
def debug(address, **kwargs):
# ``address`` may be a (host, port) pair or a bare port; a bare port is
# paired with the host "127.0.0.1".
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
# Qt support is enabled only when configured to something other than "none".
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
# NOTE(review): the two dont_trace_* entries presumably belong to this ``if``;
# the nesting is not visible in this flattened listing - confirm.
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
# Failures of ``func`` are handed to log.reraise_exception at level "info".
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[idx-left] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
# Generic return type, used in the signature of the ``fail_hard`` decorator.
T = TypeVar("T")
# Module-level logger.
logger = logging.getLogger(__name__)
# Read once, at import time, from the ``distributed.admin.pdb-on-err`` config key.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by name.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Name -> callable taking a Worker; both registries start out empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Statuses grouped together as "running" in the broad sense: running, paused
# and closing gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
# Message shape for a task run that succeeded; ``op`` is always "task-finished".
# One of the two return types of ``_run_task``.
class RunTaskSuccess(OKMessage):
op: Literal["task-finished"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
# Message shape for a task run that failed; ``op`` is always "task-erred".
# Has the same fields as RunTaskSuccess plus ``actual_exception``.
class RunTaskFailure(ErrorMessage):
op: Literal["task-erred"]
result: object
nbytes: int
type: type
start: float
stop: float
thread: int
actual_exception: BaseException | Exception
# Reply shape whose ``status`` is "busy"; one of the two return types of
# ``get_data_from_worker``.
class GetDataBusy(TypedDict):
status: Literal["busy"]
# Reply shape whose ``status`` is "OK", carrying the requested data keyed by
# task key.
class GetDataSuccess(TypedDict):
status: Literal["OK"]
data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to close the worker if this method encounters an exception.

Works for both coroutine functions and plain functions. When the wrapped
method raises an ``Exception`` and the worker's status is neither closed nor
closing, the error is logged as a "worker-fail-hard" event and the worker is
force-closed. The exception is re-raised to the caller.
"""
# Close reason passed to _force_close, derived from the wrapped method's name.
reason = f"worker-{method.__name__}-fail-hard"
if iscoroutinefunction(method):
# Async variant: the force-close is awaited before re-raising.
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
try:
return await method(self, *args, **kwargs) # type: ignore
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
await _force_close(self, reason)
raise
else:
# Sync variant: the force-close cannot be awaited here, so it is scheduled on
# the worker's event loop instead.
@wraps(method)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
try:
return method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
self.loop.add_callback(_force_close, self, reason)
raise
return wrapper # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    ``KeyboardInterrupt`` and ``SystemExit`` always propagate.  When any
    ``Scheduler`` instance exists in this process, the failure is re-raised
    instead of exiting.
    """
    close_timeout = 30
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            close_timeout,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        in_unit_test = bool(Scheduler._instances)
        if in_unit_test:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of threads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's :attr:`local_directory` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…otal_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the worker that runs the current task. It is set
# around task execution and read back by ``get_worker``.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker that is currently running this task.

    Raises
    ------
    ValueError
        If no worker is registered in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the internal LookupError; callers only deal with ValueError.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    Candidates are tried in this order: the client of the worker running the
    current task, then the current global client, then a new ``Client``
    connected to ``address``. A candidate is accepted when no ``address`` was
    requested or when its scheduler address equals ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If no matching worker or global client exists and no ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    def _targets_requested_scheduler(candidate) -> bool:
        # With no explicit address any candidate is acceptable.
        return not address or candidate.scheduler.address == address

    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        if _targets_requested_scheduler(worker):
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and _targets_requested_scheduler(client):
        return client
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The worker is informed through a ``SecedeEvent`` scheduled on its event
    loop. The event carries the task key and the time elapsed since the
    task's recorded start time.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: after a reply whose
    status is ``"OK"`` it writes ``"OK"`` back on the same comm.

    Parameters
    ----------
    rpc
        Connection pool used to obtain the comm; the comm is handed back to it
        with ``rpc.reuse`` in every case.
    keys
        Keys to request.
    worker
        Address of the worker to fetch from.
    who
        Forwarded unchanged as the ``who`` field of the ``get_data`` request.
    serializers, deserializers
        Default to ``rpc.serializers`` / ``rpc.deserializers`` when ``None``.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    if serializers is None:
        serializers = rpc.serializers
    if deserializers is None:
        deserializers = rpc.deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            delivered = response["status"] == "OK"
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if delivered:
            # Second half of the handshake: acknowledge receipt to the sender.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions, keyed by the function object itself
# (``LRU`` constructed with ``maxsize=100``). Filled by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Held by ``dumps_function`` around every read and write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func`` to bytes, reusing a cached payload when one exists.

    Payloads shorter than 100000 bytes are stored in ``cache_dumps`` for later
    calls. Functions that cannot be used as a cache key (``TypeError`` on
    lookup) are pickled on every call and never cached.
    """
    try:
        with _cache_lock:
            payload = cache_dumps[func]
    except TypeError:  # Unhashable function
        payload = pickle.dumps(func)
    except KeyError:
        # Cache miss: serialize outside the lock, then store small results only.
        payload = pickle.dumps(func)
        cacheable = len(payload) < 100000
        if cacheable:
            with _cache_lock:
                cache_dumps[func] = payload
    return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    The calling thread is registered in ``active_threads`` under its thread
    ident for the duration of the call. The worker found in
    ``execution_state["worker"]`` is published through ``_worker_cvar`` while
    the task runs.

    Parameters
    ----------
    task, data, time_delay
        Forwarded unchanged to ``_run_task_simple``.
    execution_state
        Must contain a ``"worker"`` entry; also stored in the thread state.
    key
        Key of the task; recorded in ``active_threads`` and in the thread state.
    active_threads, active_threads_lock
        Mapping of thread ident to key, and the lock guarding it.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                return _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread, even when an exception propagates, so
        # that no stale ident -> key entry is left behind in ``active_threads``.
        with active_threads_lock:
            del active_threads[ident]
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct ``str``-based type for shuffle identifiers (see ``ShuffleRun.id``).
ShuffleId = NewType("ShuffleId", str)
# Variable-length tuple of ints; used as the key type for disk reads and writes.
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ``ShuffleRun`` class.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, e.g. of the callable passed to ``ShuffleRun.offload``.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` field.

    The field holds a ``ShuffleRunSpec``, either directly or wrapped in ``ToPickle``.
    """

    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's collaborators and build its disk and comm buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
        Stored unchanged on the instance under the same names.
    directory, memory_limiter_disk
        Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
    memory_limiter_comms
        Passed to ``CommShardsBuffer``.
    disk
        Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer`` (false)
        as the receiving buffer.
    loop
        Stored as ``self._loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_bytes_limit = dask.config.get(
                "distributed.p2p.comm.message-bytes-limit"
            )
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=parse_bytes(message_bytes_limit),
                memory_limiter=memory_limiter_comms,
                concurrency_limit=dask.config.get("distributed.p2p.comm.concurrency"),
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy for ``send``, read once from the dask configuration.
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    retry_delay_min = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(retry_delay_min, default="s")
    retry_delay_max = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(retry_delay_max, default="s")
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    and forward each one to ``self.digest_metric`` as soon as it is issued.
    A leading ``"p2p-"`` on the first label element is dropped.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    redundant_prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalize scalar labels to a 1-tuple so they can be spliced below.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(redundant_prefix):
            parts = (head[len(redundant_prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make one ``shuffle_receive`` call on the peer at ``address``.

    ``shards`` is wrapped with ``to_serialize``; the request also carries this
    run's ``id`` and ``run_id``. Raises through ``raise_if_closed`` when the
    run is closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    reply = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return reply
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the ``RETRY_*`` settings.

    When ``_mean_shard_size(shards)`` is below 65536 the shards are pickled
    into a single bytes blob before sending; otherwise they are sent as-is.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    send_as_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if send_as_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # Called once per attempt, so it must build a fresh coroutine each time.
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label. Raises through
    ``raise_if_closed`` when the run is closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
        return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a snapshot of this run's buffer statistics.

    Returns
    -------
    dict
        ``"disk"`` and ``"comm"`` hold the respective buffer heartbeats; the
        comm entry additionally carries ``"read"`` set to
        ``self.total_recvd``. ``"start"`` is ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer unless the run is closed."""
    self.raise_if_closed()
    pending_write = self._comm_buffer.write(data)
    await pending_write
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer unless the run is closed.

    Each index key is turned into a string by joining its elements with
    ``"_"`` before the mapping is passed to the buffer.
    """
    self.raise_if_closed()
    flattened = {}
    for index, shard in data.items():
        flattened["_".join(map(str, index))] = shard
    await self._disk_buffer.write(flattened)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        The stored ``self._exception`` when the run is closed and one is set.
    ShuffleClosedError
        When the run is closed and no exception is stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    After the flush, any exception reported by the comm buffer's
    ``raise_on_exception`` is stored on ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Keep the failure so that later ``raise_if_closed`` calls can see it.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer unless the run is closed."""
    self.raise_if_closed()
    pending_flush = self._comm_buffer.flush()
    await pending_flush
async def flush_receive(self) -> None:
    """Flush the disk buffer unless the run is closed."""
    self.raise_if_closed()
    pending_flush = self._disk_buffer.flush()
    await pending_flush
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    A call made while the run is already marked closed only waits for
    ``_closed_event`` and returns.

    The disk buffer is closed and the event is set even when closing the
    comm buffer raises. Previously such a failure skipped both steps, which
    left the disk buffer open and made every later ``close()`` call wait on
    the event forever. The exception still propagates to the caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            # Always release the disk buffer, even after a comm failure.
            await self._disk_buffer.close()
    finally:
        # Wake up concurrent closers regardless of how closing went.
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run; a closed run is left untouched."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the disk-buffer entry for ``id``.

    The lookup key is the elements of ``id`` converted to strings and joined
    with ``"_"``. Raises whatever ``raise_if_closed`` raises when the run is
    closed.
    """
    self.raise_if_closed()
    disk_key = "_".join(map(str, id))
    return self._disk_buffer.read(disk_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``_receive``.

    Parameters
    ----------
    data
        Either the list of shards or a pickled blob of that list.

    Returns
    -------
    ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
    ``P2PConsistencyError`` is raised while handling the data.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # presumably relies on the sender being a trusted peer - confirm.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as consistency_error:
        return error_message(consistency_error)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36839', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:37591', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:44443', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = 'idx', how = 'left'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` is set to False unless the caller supplied it. Any
    exception from ``pydevd.settrace`` propagates unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done."""
    if not ensure_logging.ensured:
        # Flip the flag first so that later calls return immediately.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        log_dir = log.log_dir
        if log_dir is not None:
            pydevd.log_to(log_dir + "/debugpy.pydevd.log")


# Function attribute used as the "already started" flag.
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on
    ``log.stderr``; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and store debug configuration properties in ``_config``.

    Parameters
    ----------
    properties
        Optional mapping (or iterable of pairs) of property names to values.
    **kwargs
        Additional properties; they override entries of ``properties``.

    Raises
    ------
    ValueError
        For an unknown property, a value whose type differs from the type of
        the current ``_config`` entry, or a value outside the list in
        ``_config_valid_values``. Properties accepted before the failing one
        have already been stored.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    requested = {} if properties is None else dict(properties)
    requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # The type of the current value doubles as the property's schema.
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Wrap ``func`` with address validation and pydevd setup.

    The returned ``debug(address, **kwargs)`` accepts a bare port or a
    ``(host, port)`` pair; a bare port is paired with ``"127.0.0.1"``. It
    then builds the ``settrace`` keyword arguments from ``_config`` and calls
    ``func(address, settrace_kwargs, **kwargs)``.

    ``debug`` raises ``ValueError`` when the port is not int-like or lies
    outside ``0 <= port < 2**16``. Exceptions from ``func`` are passed to
    ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a pair: treat the whole argument as the port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")

        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_backend = _config.get("qt", "none")
        if qt_backend != "none":
            pydevd.enable_qt_support(qt_backend)

        settrace_kwargs = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )

        if hide_debugpy_internals():
            # Keep debugpy's own frames out of the traced code.
            package_dir = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (package_dir,)
            settrace_kwargs["dont_trace_end_patterns"] = ("debugpy_launcher.py",)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[idx-right] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger.
logger = logging.getLogger(__name__)

# Value of the ``distributed.admin.pdb-on-err`` config key, read once at import.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

# Extension name -> extension class.
# NOTE(review): presumably the default for the Worker ``extensions``
# argument - the consuming code is outside this view; confirm.
DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}

# Default values of the ``metrics`` and ``startup_information`` arguments of
# ``Worker.__init__``; both map a name to a callable that receives the worker.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

# Statuses grouped together as "running in some form".
WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Typed ``OKMessage`` whose ``op`` is ``"task-finished"``.

    NOTE(review): presumably produced when a task run succeeds; the
    producing code is outside this view - confirm.
    """

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    # NOTE(review): presumably start/stop timestamps of the run - confirm.
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Typed ``ErrorMessage`` whose ``op`` is ``"task-erred"``.

    NOTE(review): presumably produced when a task run raises; the producing
    code is outside this view - confirm.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    # NOTE(review): presumably start/stop timestamps of the run - confirm.
    start: float
    stop: float
    thread: int
    # The exception object itself, typed to allow any ``BaseException``.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed dict with the single key ``status`` fixed to ``"busy"``.

    NOTE(review): presumably the reply of ``get_data`` when the worker
    refuses a transfer - confirm against the handler.
    """

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed dict with ``status`` fixed to ``"OK"`` and a ``data`` mapping.

    NOTE(review): presumably the successful reply of ``get_data`` - confirm
    against the handler.
    """

    status: Literal["OK"]
    # Task key -> object stored under that key.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Coroutine functions and plain functions are both supported. When the
    wrapped method raises an ``Exception`` and the worker is neither closed
    nor closing, the error is logged as a ``"worker-fail-hard"`` event and
    ``_force_close`` is run: awaited by the coroutine wrapper, scheduled on
    ``self.loop`` by the synchronous one. The exception is always re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    if iscoroutinefunction(method):

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status in (Status.closed, Status.closing):
                    # Already shutting down: nothing more to tear down.
                    raise
                self.log_event("worker-fail-hard", error_message(exc))
                logger.exception(exc)
                await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status in (Status.closed, Status.closing):
                    # Already shutting down: nothing more to tear down.
                    raise
                self.log_event("worker-fail-hard", error_message(exc))
                logger.exception(exc)
                # Cannot await here, so hand the close over to the loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1. Wait for a worker to close
    2. If it doesn't, log and kill the process

    Parameters
    ----------
    self
        The worker to close; it is passed explicitly because this is a
        module-level function.
    reason
        Forwarded to ``self.close`` as ``reason``.
    """
    try:
        # Closing is limited by the timeout of 30 passed to ``wait_for``.
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            30,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        # Let interpreter shutdown requests through untouched.
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…tal_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker stored in the current context.

    Raises
    ------
    ValueError
        If the context variable holding the current worker is unset.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError: callers only care that no worker is available.
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    The lookup order is: the current worker's client, then the current global
    client, then a brand-new ``Client`` connected to ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker nor a global client is available and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # 1. Prefer the client owned by the worker running this task.
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        worker = None
    if worker is not None and (not address or worker.scheduler.address == address):
        return worker._get_client(timeout=timeout)

    # Imported here rather than at module level (circular import).
    from distributed.client import Client

    # 2. Fall back to the current global client.
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None
    if current and (not address or current.scheduler.address == address):
        return current

    # 3. Nothing reusable: connect to the requested scheduler, if any.
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    # Time spent computing so far, measured from the thread's recorded start.
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its event loop.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: after a response with
    status ``"OK"`` is received, ``"OK"`` is written back on the same comm.

    ``serializers`` / ``deserializers`` default to the ones configured on
    ``rpc``. The comm is always handed back to the pool, even on error.

    Raises
    ------
    ValueError
        If the response carries no ``"status"`` key.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)

        # Second half of the handshake: acknowledge successful delivery.
        if status == "OK":
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions keyed by the function object, capped at
# maxsize=100 entries; filled and read by dumps_function().
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Lock held by dumps_function() around every access to cache_dumps.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when one exists.

    Payloads are only stored in the cache when ``func`` is hashable and the
    pickled form is shorter than 100000 bytes.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: serialize now and remember reasonably small payloads.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    Registers the calling thread in ``active_threads`` under ``key`` for the
    duration of the call, exposes ``execution_state["worker"]`` through the
    worker context variable, and delegates the actual execution to
    ``_run_task_simple``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    with set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    ):
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            outcome = _run_task_simple(task, data, time_delay)
        finally:
            # Always restore the previous value of the context variable.
            _worker_cvar.reset(cvar_token)

    with active_threads_lock:
        del active_threads[thread_id]
    return outcome
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
ShuffleId = NewType("ShuffleId", str)
# N-dimensional index expressed as a tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# Type variables used by the generic ShuffleRun class and its helpers.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run specification."""

    # Either the spec itself or the spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its shard buffers.

    ``disk`` selects the implementation behind ``_disk_buffer``: a
    ``DiskShardsBuffer`` rooted at ``directory`` and limited by
    ``memory_limiter_disk`` when true, otherwise a ``MemoryShardsBuffer``.
    ``_comm_buffer`` is a ``CommShardsBuffer`` limited by
    ``memory_limiter_comms`` and sized from the ``distributed.p2p.comm.*``
    configuration. Retry settings are read from the same configuration tree.
    """
    # Identity and collaborators. Set before the buffers are built because the
    # metrics callbacks installed below read ``span_id`` and ``digest_metric``.
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Progress bookkeeping.
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by send().
    retry_cfg = "distributed.p2p.comm.retry"
    self.RETRY_COUNT = dask.config.get(f"{retry_cfg}.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get(f"{retry_cfg}.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get(f"{retry_cfg}.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalize to a tuple and drop a leading "p2p-" from the first item,
        # since the metric name is already namespaced under "p2p".
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    ``consistent`` is sent as true only when every entry of ``run_ids``
    equals this run's own ``run_id``.
    """
    self.raise_if_closed()

    consistent = True
    for other in run_ids:
        if not other == self.run_id:
            consistent = False
            break

    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the worker at ``address``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured policy.

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob before sending; otherwise they are sent as-is.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry.
        return self._send(address, payload)

    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot with ``disk``, ``comm`` and ``start`` entries.

    The comm buffer's heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_status = self._comm_buffer.heartbeat()
    comm_status["read"] = self.total_recvd
    disk_status = self._disk_buffer.heartbeat()
    return {
        "disk": disk_status,
        "comm": comm_status,
        "start": self.start_time,
    }
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; do nothing otherwise.

    Raises
    ------
    Exception
        The stored ``_exception``, when one is set.
    ShuffleClosedError
        When the run is closed and no exception is stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception the comm buffer reports afterwards is stored on
    ``_exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so later raise_if_closed() calls can surface it.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer, then the disk buffer, and signal completion.

    If the run is already marked closed, this only waits for the close
    started earlier to finish.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept incoming shards and report the outcome as a status message.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised; any other
    exception propagates.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from another party;
            # confirm that only trusted peers can reach this handler.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as exc:
        return error_message(exc)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35469', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:39595', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:39153', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = 'idx', how = 'right'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage object; its users are outside this excerpt.
_tls = threading.local()
# Current debug configuration. configure() only accepts keys already present
# here, and requires each new value to have the same type as the default.
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller supplies it.
    Exceptions from pydevd propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The function's own ``ensured`` attribute records whether this has run, so
    repeated calls are no-ops.
    """
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Flag flipped by the first ensure_logging() call.
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every log level on stderr; any other value
    is stored as the log directory.

    Raises
    ------
    RuntimeError
        If ``ensure_logging()`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
    else:
        log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    ``properties`` (a mapping, optional) and ``kwargs`` are merged, with
    ``kwargs`` taking precedence. Properties are validated and applied one at
    a time, so earlier ones stay applied if a later one is rejected.

    Raises
    ------
    ValueError
        If a property is unknown, its value has a different type than the
        current value in ``_config``, or its value is not in the allowed set.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    settings = kwargs if properties is None else {**dict(properties), **kwargs}

    for name, value in settings.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match on purpose: subclasses are rejected too.
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate an API entry point that starts a debug session.

    The returned ``debug(address, **kwargs)`` normalizes ``address`` (a bare
    port becomes ``("127.0.0.1", port)``), validates the port, makes sure
    logging is on, applies the Qt setting from ``_config``, builds the
    ``settrace`` keyword arguments and then calls
    ``func(address, settrace_kwargs, **kwargs)``. Failures of ``func`` are
    passed to ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        # Accept either ``(host, port)`` or a bare port.
        try:
            _, port = address
        except Exception:
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        port_in_range = 0 <= port < 2**16
        if not port_in_range:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }
        if hide_debugpy_internals():
            # Exclude debugpy's own files from tracing.
            debugpy_dir = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_dir,)
            settrace_kwargs["dont_trace_end_patterns"] = ("debugpy_launcher.py",)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[idx-outer] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" config key, read once at import.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by extension name.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Registries of callables that take a Worker, keyed by name; empty by default.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Statuses grouped under "any running": running, paused, closing gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message shape for a task that finished (``op == "task-finished"``)."""

    op: Literal["task-finished"]
    # Result payload together with its size and type.
    result: object
    nbytes: int
    type: type
    # Timing information and the thread involved.
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message shape for a task that failed (``op == "task-erred"``)."""

    op: Literal["task-erred"]
    # Same fields as the success message.
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    # The exception object itself, in addition to the ErrorMessage fields.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """``get_data`` response whose only field is ``status == "busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """``get_data`` response with ``status == "OK"`` and the requested data."""

    status: Literal["OK"]
    # Mapping of key to the object held for it.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both plain and coroutine methods. When the wrapped method raises
    an ``Exception`` and the worker is neither closed nor closing, the error is
    logged as a ``"worker-fail-hard"`` event and the worker is force-closed:
    awaited directly for coroutine methods, scheduled on the worker's loop for
    plain ones. The original exception is always re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    already_stopping = (Status.closed, Status.closing)

    if not iscoroutinefunction(method):

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status not in already_stopping:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                    # Cannot await here; let the event loop run the close.
                    self.loop.add_callback(_force_close, self, reason)
                raise

    else:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status not in already_stopping:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                    await _force_close(self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    The close is given 30 seconds. ``KeyboardInterrupt`` and ``SystemExit``
    propagate untouched. Any other failure is re-raised when a ``Scheduler``
    instance exists in this process; otherwise the process exits with
    status 1.
    """
    close_timeout = 30
    closing = self.close(nanny=False, executor_wait=False, reason=reason)
    try:
        await wait_for(closing, close_timeout)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…tal_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the Worker on whose behalf the current code runs.
# It is read by ``get_worker`` and is set (then reset) by ``_run_task`` around
# each task execution.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError: callers only need to know no worker is set
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Return a client that can be used from within a task.

    Candidates are tried in this order:

    1. the client of the worker running the current task, if there is such a
       worker and ``address`` is empty or equals its scheduler's address;
    2. the current global client (``Client.current()``), under the same
       address condition;
    3. a new ``Client(address, timeout=timeout)`` when ``address`` was given.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If none of the candidates applies and no address was provided.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    raw_timeout = (
        dask.config.get("distributed.comm.timeouts.connect")
        if timeout is None
        else timeout
    )
    timeout = parse_timedelta(raw_timeout, "s")

    if resolve_address and address:
        address = comm_resolve_address(address)

    try:
        worker = get_worker()
    except ValueError:
        # Not running on a worker: fall through to the global client
        pass
    else:
        same_scheduler = not address or worker.scheduler.address == address
        if same_scheduler:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and (not address or client.scheduler.address == address):
        return client
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Remove the calling task from the worker's thread pool

    Seceding frees a scheduling slot and a thread for another task, so the
    client can schedule more work on this node. This is especially useful
    while waiting for other jobs to finish (e.g., with ``client.gather``).

    The time elapsed since ``thread_state.start_time`` is reported to the
    worker as the task's compute duration through a ``SecedeEvent`` scheduled
    on the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # detach this thread from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch keys from a worker

    Data delivery is acknowledged with a two step handshake, implemented here:
    a ``get_data`` request is sent, and when the reply has status ``"OK"`` an
    ``"OK"`` message is written back on the same comm.

    ``serializers`` and ``deserializers`` default to the ones of ``rpc``. The
    comm is handed back to ``rpc`` with ``rpc.reuse`` whether or not the
    exchange succeeds.

    Raises
    ------
    ValueError
        If the reply has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: confirm receipt to the sender
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of serialized callables (callable -> pickled bytes), holding at most
# 100 entries; filled and read by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Guards every access ``dumps_function`` makes to ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Serialize a function to bytes, caching the result

    The serialized form is looked up in ``cache_dumps`` first. On a miss the
    function is pickled, and the result is stored in the cache only when it is
    shorter than 100000 bytes. Functions that cannot be used as a cache key
    are pickled on every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
    except KeyError:
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the calling thread is registered in
    ``active_threads`` under its thread identifier, the thread state carries
    the start time, ``execution_state`` and ``key``, and ``_worker_cvar`` is
    set to ``execution_state["worker"]``.

    The ``active_threads`` entry is removed even if running the task raises,
    so that a failed run cannot leave a stale thread-id -> key mapping behind.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister: thread identifiers may be recycled, so a leftover
        # entry could later be attributed to an unrelated thread.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers, built on ``str``.
ShuffleId = NewType("ShuffleId", str)
# Index tuple of arbitrary length; ``ShuffleRun._write_to_disk`` and
# ``ShuffleRun._read_from_disk`` join its items with "_" to name buffers.
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ``ShuffleRun`` class.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Return type of the callable handed to ``ShuffleRun.offload``.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` field."""

    # The run specification, either directly or wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its shard buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address
        Stored unchanged on the instance.
    directory
        Passed to ``DiskShardsBuffer``; not used when ``disk`` is false.
    executor, rpc, digest_metric, scheduler
        Stored unchanged on the instance.
    memory_limiter_disk
        Passed to ``DiskShardsBuffer``; not used when ``disk`` is false.
    memory_limiter_comms
        Passed to ``CommShardsBuffer``.
    disk
        If true, ``_disk_buffer`` is a ``DiskShardsBuffer``; otherwise it is a
        ``MemoryShardsBuffer``.
    loop
        Stored as ``_loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        # Metrics emitted while this buffer is created are captured under the
        # "background-disk" label (see _capture_metrics)
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        # Same for the comm buffer, under the "background-comms" label
        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Set to True by inputs_done()
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    # Error that raise_if_closed() re-raises once the run is closed
    self._exception = None
    # Set at the end of close(); later close() calls wait on it
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry settings used by send(), read from the dask config
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics recorded inside the block as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` is stripped from the first label element.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalize to a tuple so that it can be spliced into the metric name
        parts = label if isinstance(label, tuple) else (label,)
        first = parts[0]
        if isinstance(first, str) and first.startswith("p2p-"):
            parts = (first[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Call ``shuffle_barrier`` on the scheduler and return this run's id.

    The call is flagged as consistent only when every entry of ``run_ids``
    equals ``self.run_id``. Raises if the run is closed.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call on the peer at ``address``.

    ``shards`` is wrapped with ``to_serialize`` and sent together with this
    run's shuffle id and run id. Raises if the run is closed.
    """
    self.raise_if_closed()
    remote_receive = self.rpc(address).shuffle_receive
    return await remote_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the peer at ``address``, retrying on failure.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob before sending; ``receive`` unpickles such blobs on the
    other side. Retries use ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
    ``RETRY_DELAY_MAX``.
    """
    small_shard_limit = 65536
    # Below the limit, don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    shards_or_bytes: list | bytes = (
        pickle.dumps(shards)
        if _mean_shard_size(shards) < small_shard_limit
        else shards
    )

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label. Raises if the run is
    closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Build a status snapshot of this run.

    Returns
    -------
    dict
        ``"disk"``: the disk buffer's heartbeat; ``"comm"``: the comm buffer's
        heartbeat with an added ``"read"`` entry set to ``self.total_recvd``;
        ``"start"``: ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer, after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer under string keys.

    Each index key is turned into a string by joining ``str()`` of its
    elements with ``"_"``, e.g. ``(0, 1)`` becomes ``"0_1"``.
    """
    self.raise_if_closed()
    write = self._disk_buffer.write
    renamed = {}
    for index, shard in data.items():
        renamed["_".join(map(str, index))] = shard
    await write(renamed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    Raises
    ------
    The stored ``self._exception`` when one is set, otherwise
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush outstanding comm writes.

    Sets ``self.transferred`` to True, flushes the comm buffer, then asks the
    buffer to raise any exception it holds. Such an exception is stored in
    ``self._exception`` before being re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as err:
        # Keep the failure so later raise_if_closed() calls can re-raise it.
        self._exception = err
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer, after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer, after checking the run is still open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers, once.

    The first call sets ``self.closed``, closes the comm buffer and then the
    disk buffer, and finally sets ``self._closed_event``. Any later call only
    waits on that event.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run, unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored for index ``id``.

    The lookup key is ``str()`` of each element of ``id`` joined with ``"_"``.
    """
    self.raise_if_closed()
    read = self._disk_buffer.read
    return read("_".join(map(str, id)))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Handle incoming shards and return a status message.

    ``bytes`` input is unpickled first (the blob format produced by ``send``).
    The shards are then passed to ``self._receive``.

    Returns
    -------
    ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
    ``P2PConsistencyError`` is raised. Other exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42723', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:43615', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:36587', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = 'idx', how = 'outer'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage for this module; its users are not visible in this excerpt.
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration. configure() rejects unknown keys and values whose
# type differs from the default stored here.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to False unless the caller passes it explicitly.
    Exceptions from ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensure_logging.ensured`` function attribute records whether the
    setup has run; repeated calls return immediately.
    """
    if not ensure_logging.ensured:
        # Flip the flag first, as the original does, before any logging call.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Logging has not been set up at import time.
ensure_logging.ensured = False
def log_to(path):
    """Choose where debugpy logs go; must be called before logging starts.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on the stderr
    log. Any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate debug properties and store them in ``_config``.

    Parameters
    ----------
    properties
        Optional mapping (or anything ``dict()`` accepts) of property names
        to values. Keyword arguments are merged on top of it.

    Raises
    ------
    ValueError
        For an unknown property, a value whose exact type differs from the
        current ``_config`` entry, or a value outside the list in
        ``_config_valid_values``.

    Properties are applied one at a time, so entries validated before a
    failing one are already stored when the error is raised.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    settings = kwargs if properties is None else dict(properties, **kwargs)

    for name, value in settings.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match on purpose: subclasses of the default's type are rejected.
        wanted = type(_config[name])
        if type(value) is not wanted:
            raise ValueError("{0!r} must be a {1}".format(name, wanted.__name__))

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorator for entry points that start a debug session.

    The returned ``debug(address, **kwargs)`` wrapper:

    1. Accepts either a bare port or a ``(host, port)`` pair; a bare port is
       paired with host ``"127.0.0.1"``.
    2. Checks the port is int-like and in ``[0, 2**16)``, raising
       ``ValueError`` otherwise.
    3. Sets up logging, enables Qt support when configured, and builds the
       keyword arguments meant for ``settrace``.
    4. Calls ``func(address, settrace_kwargs, **kwargs)``; exceptions go
       through ``log.reraise_exception`` at level "info".
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a two-item sequence: treat the argument as a bare port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }
        if hide_debugpy_internals():
            # Exclude debugpy's own package and launcher from tracing.
            debugpy_root = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_root,),
                dont_trace_end_patterns=(str("debugpy_launcher.py"),),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on1-inner] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by extension name.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Default for the ``metrics`` parameter of Worker.__init__.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
# Default for the ``startup_information`` parameter of Worker.__init__.
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Statuses grouped under "any running": running, paused and closing gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Typed message with ``op="task-finished"``, extending ``OKMessage``.

    NOTE(review): field meanings are not established by this excerpt; the
    names suggest the task result, its size and type, start/stop times and
    the executing thread. Confirm against the code that builds the message.
    """

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Typed message with ``op="task-erred"``, extending ``ErrorMessage``.

    Declares the same fields as ``RunTaskSuccess`` plus ``actual_exception``.

    NOTE(review): presumably ``actual_exception`` is the exception object
    raised by the task; confirm against the code that builds the message.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed dict whose only field is ``status``, fixed to ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed dict with ``status`` fixed to ``"OK"`` and a ``data`` mapping of key to object."""

    status: Literal["OK"]
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Supports both coroutine functions and plain functions. When the wrapped
    method raises ``Exception`` and the worker is not already closed or
    closing, the error is recorded with ``log_event`` and the logger, and
    ``_force_close`` is triggered: awaited directly in the coroutine case,
    scheduled on ``self.loop`` in the synchronous case. The original
    exception is always re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    # Nothing to tear down when the worker is already in one of these states.
    shutting_down = (Status.closed, Status.closing)

    if iscoroutinefunction(method):

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status not in shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                    await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status not in shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                    # Cannot await from sync code: hand the close to the loop.
                    self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    The close is given 30 (passed to ``wait_for``) to complete.
    ``KeyboardInterrupt`` and ``SystemExit`` propagate. On any other failure
    the exception is re-raised when a ``Scheduler`` instance exists; otherwise
    the process exits via ``os._exit(1)``.
    """
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        in_test_suite = bool(Scheduler._instances)
        if in_test_suite:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…l_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable read by get_worker(); it has no default, so .get() raises
# LookupError while unset. Where it is set is not visible in this excerpt.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task.

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError: callers only need to know there is no worker.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Obtain a client from inside a task.

    The lookup order is: the client of the worker running the current task,
    then the current global client, then a brand new ``Client(address)``.
    A candidate is only used when no ``address`` was requested or when its
    scheduler address equals the requested one.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a global client matches and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    try:
        worker = get_worker()
    except ValueError:
        # Not running inside a task: fall through to the global client.
        pass
    else:
        same_scheduler = not address or worker.scheduler.address == address
        if same_scheduler:
            return worker._get_client(timeout=timeout)

    # Imported here rather than at module level (Client is a circular import).
    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and (not address or client.scheduler.address == address):
        return client
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    After leaving the pool, a ``SecedeEvent`` carrying the time spent so far
    is handed to the worker through its event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Delivered via the worker's loop callback rather than called directly.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: it sends a
    ``get_data`` request and, when the response status is ``"OK"``, writes
    ``"OK"`` back on the same comm.  The comm is always handed back to the
    pool with ``rpc.reuse``.

    Parameters
    ----------
    rpc
        Connection pool used to open (and later reuse) the comm.
    keys
        Keys to request.
    worker
        Address of the worker to ask.
    who
        Forwarded as the ``who`` field of the request.
    serializers, deserializers
        Default to ``rpc.serializers`` / ``rpc.deserializers`` when ``None``.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions keyed by the function object itself; created as
# an LRU mapping with maxsize=100.  Populated and read by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Held by ``dumps_function`` around every read and write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when one exists.

    Payloads shorter than 100000 bytes are stored in ``cache_dumps``; larger
    ones are returned without being cached.  Functions that cannot be used as
    a cache key (unhashable) are pickled on every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
    except KeyError:
        # Cache miss: pickle now and remember the result if it is small enough.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    Registers the current thread in ``active_threads`` under ``key``, sets the
    thread state and the worker context variable, and delegates the actual
    execution to ``_run_task_simple``.

    Parameters
    ----------
    task, data, time_delay
        Forwarded unchanged to ``_run_task_simple``.
    execution_state
        Stored in the thread state; its ``"worker"`` entry is published through
        ``_worker_cvar`` for the duration of the task.
    key
        Key of the task being run.
    active_threads, active_threads_lock
        Mapping of thread ident to key, and the lock guarding it.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread, even if _run_task_simple (or the setup
        # around it) raises; otherwise a stale ident -> key entry would remain
        # in active_threads after the task is gone.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers, based on str.
ShuffleId = NewType("ShuffleId", str)
# Tuple of integer indices; used as the key type in ShuffleRun._write_to_disk
# and ShuffleRun._read_from_disk.
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of ShuffleRun (partition id and partition payload).
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, used by ShuffleRun.offload.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with the run spec of a shuffle."""

    # Either the spec itself or the spec wrapped in ToPickle.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's collaborators and create its disk and comm buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address
        Identity of this run; stored as attributes of the same name.
    directory
        Passed to ``DiskShardsBuffer`` when ``disk`` is true.
    executor, rpc, digest_metric, scheduler, loop
        Stored on the instance for later use by the other methods.
    memory_limiter_disk, memory_limiter_comms
        Limiters handed to the disk buffer and the comm buffer respectively.
    disk
        Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer`` (false).
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_bytes_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=comm_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Progress bookkeeping
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by send()
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    retry_delay_min = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(retry_delay_min, default="s")
    retry_delay_max = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(retry_delay_max, default="s")
def __repr__(self) -> str:
    """Return a debugging representation listing id, run id, address and state flags."""
    description = f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    return description
def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    return label
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` prefix on the first label element is stripped before
    the metric is published.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple so it can be spliced into the name.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith("p2p-"):
            parts = (head[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` equal this run's id.

    Returns this run's ``run_id``.  Raises if the run is closed.
    """
    self.raise_if_closed()
    all_match = all(other == self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=all_match,
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address`` with a single ``shuffle_receive`` call.

    Raises if the run is closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    payload = to_serialize(shards)
    return await peer.shuffle_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured policy.

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob first; otherwise they are sent as they are.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    merge_into_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = (
        pickle.dumps(shards) if merge_into_blob else shards
    )

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label.  Raises if the run is
    closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers.

    Returns a dict with the disk buffer heartbeat under ``"disk"``, the comm
    buffer heartbeat (plus a ``"read"`` entry set to ``total_recvd``) under
    ``"comm"``, and ``start_time`` under ``"start"``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer; raises if the run is closed.

    Each index tuple is turned into a string key by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    Raises
    ------
    Exception
        The stored ``_exception``, when one has been recorded.
    ShuffleClosedError
        When the run is closed and no exception has been recorded.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush the comm buffer.

    Any exception raised by the comm buffer's ``raise_on_exception`` is stored
    in ``_exception`` and re-raised.  Raises if the run is closed.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so raise_if_closed can surface it later.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    If the run is already marked closed, only wait for ``_closed_event``.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
        return
    # Already closed (or closing) elsewhere: wait for that close to finish.
    await self._closed_event.wait()  # pragma: no cover
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry keyed by ``id`` joined with ``"_"``.

    Raises if the run is closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Handle incoming shards and report the outcome as a message.

    ``bytes`` input is unpickled first (see ``send``).  Returns
    ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is converted
    into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:34695', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:41053', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:38821', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx'], how = 'inner'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage for this module.
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration.  configure() validates new values against the
# types of these defaults and overwrites entries in place.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller passes it.
    Exceptions from ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensure_logging.ensured`` function attribute records whether this
    has already run; later calls return immediately.
    """
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Logging has not been started yet at import time.
ensure_logging.ensured = False
def log_to(path):
    """Choose where debugpy logs go; must be called before logging starts.

    Passing ``sys.stderr`` adds every level in ``log.LEVELS`` to the stderr
    log; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    ``properties`` (a mapping, optional) and ``kwargs`` are merged, with
    ``kwargs`` taking precedence.  Each property is checked and then written
    to ``_config`` in turn, so properties preceding an invalid one stay
    applied.

    Raises
    ------
    ValueError
        For an unknown property, a value whose type differs from the type of
        the default in ``_config``, or a value outside the allowed list in
        ``_config_valid_values``.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, expected_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate ``func`` with address validation and debugger setup.

    The returned ``debug(address, **kwargs)``:

    1. accepts either ``(host, port)`` or a bare port, in which case the host
       becomes ``"127.0.0.1"``;
    2. raises ``ValueError`` unless the port is int-like and within
       ``0 <= port < 2**16``;
    3. starts logging, enables Qt support when the ``"qt"`` setting is not
       ``"none"``, and builds the settrace keyword arguments;
    4. calls ``func(address, settrace_kwargs, **kwargs)``; on an exception it
       calls ``log.reraise_exception`` at level ``"info"``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a 2-item sequence: treat the argument as a bare port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        patch_multiprocessing = _config.get("subProcess", True)
        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": patch_multiprocessing,
        }

        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
            settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on1-left] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" option, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by name.
# NOTE(review): presumably instantiated per Worker; the consumer is not
# visible in this excerpt, so confirm.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Registries of callables taking a Worker, keyed by name; both start empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Worker statuses grouped as "running" in any form: running, paused, or
# closing gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message describing a task that finished, extending ``OKMessage``."""

    # Always the literal "task-finished".
    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message describing a task that erred, extending ``ErrorMessage``."""

    # Always the literal "task-erred".
    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    # The exception object itself, alongside the ErrorMessage fields.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Response shape whose only field is a ``"busy"`` status."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Response shape with an ``"OK"`` status and the requested data."""

    status: Literal["OK"]
    # Mapping of key to the object stored under it.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Coroutine functions get an async wrapper that awaits ``_force_close``
    before re-raising; other callables get a sync wrapper that schedules
    ``_force_close`` on the worker's loop and re-raises.  The error is logged
    (event ``"worker-fail-hard"``) only while the worker is neither closed nor
    closing.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    is_async = iscoroutinefunction(method)

    if is_async:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as e:
                already_closing = self.status in (Status.closed, Status.closing)
                if not already_closing:
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                already_closing = self.status in (Status.closed, Status.closing)
                if not already_closing:
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                # Cannot await here, so hand the close to the event loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait up to 30 seconds for the worker to close
    2.  If closing fails, log and kill the process with ``os._exit(1)``,
        unless any ``Scheduler`` instance exists (likely a unit test), in
        which case the exception is re-raised instead

    ``KeyboardInterrupt`` and ``SystemExit`` are always re-raised.
    """
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…al_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the worker for the task currently being run.
# ``_run_task`` sets it around the task call and resets it afterwards;
# ``get_worker`` reads it.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Return a client that can be used from within a task.

    Candidates are tried in this order: the client of the worker running the
    current task, the current client returned by ``Client.current()``, and
    finally a new ``Client`` created for ``address``. A worker or existing
    client is only used when no ``address`` was given or when its scheduler
    address equals ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If no suitable worker or client exists and no ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the worker that runs the current task, if there is one.
    worker = None
    with suppress(ValueError):  # could not find worker
        worker = get_worker()
    if worker is not None and (not address or worker.scheduler.address == address):
        return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: an already existing client.
    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None
    if client and (not address or client.scheduler.address == address):
        return client

    # Last resort: connect to the explicitly requested scheduler.
    if not address:
        raise ValueError("No global client found and no address provided")
    return Client(address, timeout=timeout)
def secede():
    """
    Remove the current task from the worker's thread pool

    Seceding frees a scheduling slot and a thread so that another task can
    run on this worker. This is especially useful while the task is merely
    waiting for other work to finish (e.g., with ``client.gather``).

    The worker's state machine is informed through a ``SecedeEvent`` that
    carries the time the task has spent computing so far.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # The event must be handled on the worker's event loop, not in this thread.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``

    Delivery is acknowledged with a two step handshake: after a reply with
    status ``"OK"`` has been received, ``"OK"`` is written back on the same
    comm. The comm is handed back to the pool in every case.

    ``serializers`` and ``deserializers`` default to those of ``rpc``.

    Raises
    ------
    ValueError
        If the reply has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            delivered = response["status"] == "OK"
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if delivered:
            # Second half of the handshake: confirm receipt to the sender.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# LRU cache (at most 100 entries) mapping a function to its pickled bytes;
# filled and read by ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Held by ``dumps_function`` around every read and write of ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when available

    Payloads shorter than 100000 bytes are remembered in ``cache_dumps``.
    Functions that cannot be used as a cache key are pickled on every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
    except KeyError:
        # Cache miss: serialize below and keep the payload if it is small.
        pass

    payload = pickle.dumps(func)
    if len(payload) < 100000:
        with _cache_lock:
            cache_dumps[func] = payload
    return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the current thread is registered in
    ``active_threads`` under its thread ident, and ``execution_state["worker"]``
    is published through ``_worker_cvar`` so that ``get_worker`` can find it.
    Both are undone again when the task is finished, also when an exception
    propagates out of the task.

    Parameters
    ----------
    task, data, time_delay
        Forwarded unchanged to ``_run_task_simple``.
    execution_state
        Passed to ``set_thread_state``; must contain a ``"worker"`` entry.
    key
        Key of the task; recorded in ``active_threads`` and the thread state.
    active_threads
        Mapping ``{thread ident: key}`` that is updated in place.
    active_threads_lock
        Lock held while ``active_threads`` is modified.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread; otherwise an exception escaping the
        # task would leave a stale ``ident -> key`` entry behind.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct ``str`` type used for the ``id`` of a shuffle
ShuffleId = NewType("ShuffleId", str)
# Tuple of ints; ShuffleRun joins its elements with "_" to build the names
# used with the disk buffer
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ``ShuffleRun`` class
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, e.g. for ``ShuffleRun.offload``
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` entry."""

    # The run spec itself, or the run spec wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address
        Stored as attributes of the same name.
    directory
        Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
    executor
        Executor used by ``offload``.
    rpc
        Callable returning an RPC object for a given address.
    digest_metric
        Callable receiving ``(name, value)`` for every captured metric.
    scheduler
        RPC object for the scheduler; used by ``barrier``.
    memory_limiter_disk
        Limiter handed to ``DiskShardsBuffer``.
    memory_limiter_comms
        Limiter handed to ``CommShardsBuffer``.
    disk
        If true, use a ``DiskShardsBuffer``; otherwise a
        ``MemoryShardsBuffer``.
    loop
        Stored as ``_loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            # NOTE(review): self.read / self.deserialize are not defined in the
            # visible part of this class -- presumably provided elsewhere
            # (e.g. by subclasses); confirm.
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        with self._capture_metrics("background-comms"):
            # Both limits for the comm buffer come from the dask config
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry settings used by ``send``; read once from the dask config
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics recorded inside the block as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` on the first label element is stripped before the
    metric is reported.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** Metrics are handed to ``digest_metric`` immediately.
    They are not temporarily stored under ShuffleRun, as those recorded between
    the heartbeat and the deletion of the ShuffleRun object at the end of a run
    would be lost.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        first = parts[0]
        if isinstance(first, str) and first.startswith(prefix):
            parts = (first[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    is_background = "background" in where
    with context_meter.add_callback(callback, allow_offload=is_background):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the peer at ``address``.

    ``shards`` is wrapped with ``to_serialize`` and sent together with this
    run's shuffle id and run id.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    payload = to_serialize(shards)
    return await peer.shuffle_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the peer at ``address``, retrying via ``retry``.

    The retry count and delays come from ``RETRY_COUNT``, ``RETRY_DELAY_MIN``
    and ``RETRY_DELAY_MAX``.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    send_as_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if send_as_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    outcome = await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
    return outcome
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on this run's executor and return its result.

    The call is timed under the ``"offload"`` context meter.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        result = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return result
def heartbeat(self) -> dict[str, Any]:
    """Collect a status snapshot of this run.

    Returns
    -------
    dict
        ``"disk"``: the disk buffer's heartbeat; ``"comm"``: the comm
        buffer's heartbeat with an added ``"read"`` entry set to
        ``self.total_recvd``; ``"start"``: ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer after checking the run is open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer under string keys.

    Each key of ``data`` is iterated and its elements are joined with
    ``"_"`` (e.g. ``(0, 1)`` becomes ``"0_1"``) before the write.
    """
    self.raise_if_closed()
    renamed = {}
    for index, payload in data.items():
        renamed["_".join(map(str, index))] = payload
    await self._disk_buffer.write(renamed)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        ``self._exception`` when the run is closed and that attribute is
        truthy.
    ShuffleClosedError
        When the run is closed and no exception has been stored.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Sets ``self.transferred`` to True, awaits ``self._flush_comm()``, then
    asks the comm buffer to raise any error it holds.

    Raises
    ------
    Exception
        Any exception raised by the comm buffer's ``raise_on_exception``;
        it is stored on ``self._exception`` before being re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Remember the failure so later ``raise_if_closed`` calls can surface it.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer after checking the run is open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers; later calls wait for the first.

    The first call sets ``self.closed``, closes the comm buffer and then the
    disk buffer, and finally sets ``self._closed_event``. A call made once
    ``self.closed`` is already True only waits on that event.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
        return
    await self._closed_event.wait()  # pragma: no cover
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer under the ``"_"``-joined form of ``id``.

    The name is built the same way ``_write_to_disk`` builds its keys:
    the elements of ``id`` are stringified and joined with ``"_"``.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards from a peer and hand them to ``self._receive``.

    Parameters
    ----------
    data
        Either the shard list itself or a pickled bytes blob of it
        (the counterpart of the small-shard path in ``send``).

    Returns
    -------
    ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
    ``P2PConsistencyError`` is raised while unpacking or receiving.
    """
    try:
        # Unpack opaque blob. See send()
        # NOTE(review): ``pickle.loads`` runs on bytes received over the
        # network; presumably only trusted cluster peers can reach this.
        shards = (
            cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
            if isinstance(data, bytes)
            else data
        )
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:44849', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:34477', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:37977', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx'], how = 'left'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Call ``pydevd.settrace`` with ``notify_stdin`` defaulted to False.

    All positional and keyword arguments are forwarded unchanged, except
    that ``notify_stdin=False`` is added when the caller did not supply it.
    Any exception raised by ``pydevd.settrace`` propagates to the caller.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The function records that it has run in its own ``ensured`` attribute,
    so only the first call does any work.
    """
    if not ensure_logging.ensured:
        # Flip the flag first so the setup below runs at most once.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Parameters
    ----------
    path
        ``sys.stderr`` to enable every level in ``log.LEVELS`` on the stderr
        log, or any other value to store as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    Parameters
    ----------
    properties
        Optional mapping (or anything ``dict()`` accepts) of property names
        to values.
    **kwargs
        Additional properties; they override same-named entries of
        ``properties``.

    Raises
    ------
    ValueError
        If a name is not a key of ``_config``, if a value's exact type
        differs from the type of the current value in ``_config``, or if the
        value is not among the choices listed in ``_config_valid_values``.

    Notes
    -----
    Properties are validated and stored one at a time, so entries processed
    before a failing one have already been written to ``_config``.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    requested = kwargs if properties is None else dict(properties, **kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match (not isinstance) against the current value's type.
        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, expected_type.__name__)
            )

        valid_values = _config_valid_values.get(name)
        if valid_values is not None and value not in valid_values:
            raise ValueError(
                "{0!r} must be one of: {1!r}".format(name, valid_values)
            )

        _config[name] = value
def _starts_debugging(func):
    """Wrap ``func`` with address validation and ``settrace`` setup.

    The returned ``debug(address, **kwargs)`` function:

    1. accepts either ``(host, port)`` or a bare port (which is paired with
       host ``"127.0.0.1"``),
    2. raises ``ValueError`` if the port is not int-like or lies outside
       ``0 <= port < 2**16``,
    3. starts logging, enables Qt support when ``_config["qt"]`` is not
       ``"none"``, and builds the keyword arguments for ``pydevd.settrace``,
    4. calls ``func(address, settrace_kwargs, **kwargs)`` and returns its
       result; exceptions are passed to ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a 2-item sequence: treat the whole argument as the port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if port < 0 or port >= 2**16:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }
        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_path,),
                dont_trace_end_patterns=("debugpy_launcher.py",),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on1-right] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Read once, at import time, from the dask config key
# "distributed.admin.pdb-on-err".
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension name -> extension class.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Default value of the ``metrics`` argument of ``Worker.__init__``, which
# copies it with ``dict(metrics)`` rather than storing this object.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
# Default value of the ``startup_information`` argument of ``Worker.__init__``;
# likewise copied with ``dict(...)`` there.
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# The three ``Status`` members grouped under "any running":
# running, paused and closing_gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """``OKMessage`` extended with the fields of a ``"task-finished"`` message.

    Presumably built when a task run succeeds (inferred from the name;
    confirm against the code that constructs it).
    """

    # Fixed message discriminator.
    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """``ErrorMessage`` extended with the fields of a ``"task-erred"`` message.

    Carries the same fields as ``RunTaskSuccess`` plus
    ``actual_exception``. Presumably built when a task run fails (inferred
    from the name; confirm against the code that constructs it).
    """

    # Fixed message discriminator.
    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Dict shape whose only key is ``status``, fixed to ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Dict shape with ``status`` fixed to ``"OK"`` and a ``data`` mapping."""

    status: Literal["OK"]
    # Mapping from task key to an arbitrary object.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Handles both coroutine functions and plain functions. When the wrapped
    method raises an ``Exception``:

    * unless ``self.status`` is already ``closed`` or ``closing``, a
      ``"worker-fail-hard"`` event is recorded and the exception is logged;
    * ``_force_close`` is invoked (awaited directly by the async wrapper,
      scheduled on ``self.loop`` by the sync wrapper);
    * the original exception is re-raised to the caller.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    shutting_down = (Status.closed, Status.closing)

    if iscoroutinefunction(method):

        @wraps(method)
        async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status not in shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                await _force_close(self, reason)
                raise

        return async_wrapper  # type: ignore

    @wraps(method)
    def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return method(self, *args, **kwargs)
        except Exception as exc:
            if self.status not in shutting_down:
                self.log_event("worker-fail-hard", error_message(exc))
                logger.exception(exc)
            # Cannot await here, so hand the close over to the event loop.
            self.loop.add_callback(_force_close, self, reason)
            raise

    return sync_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    ``KeyboardInterrupt`` and ``SystemExit`` are re-raised untouched. Any
    other failure of the close, including exceeding the 30 (presumably
    seconds) passed to ``wait_for``, is re-raised when a ``Scheduler``
    instance exists in this process, and otherwise ends the process with
    ``os._exit(1)``.
    """
    closing = self.close(nanny=False, executor_wait=False, reason=reason)
    try:
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…l_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task.

    The worker is looked up in the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If no worker has been set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    # A private sentinel distinguishes "variable not set" from any stored value
    unset = object()
    current = _worker_cvar.get(unset)
    if current is unset:
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    The lookup proceeds in this order:

    1. the client of the worker running the current task, if any, provided no
       ``address`` was given or it matches the worker's scheduler address;
    2. the current global client, under the same address condition;
    3. a brand new ``Client`` connected to ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a global client is available and no
        ``address`` was provided.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    try:
        current_worker = get_worker()
    except ValueError:
        # Not running inside a task: fall through to the global client
        running_in_task = False
    else:
        running_in_task = True

    if running_in_task and (
        not address or current_worker.scheduler.address == address
    ):
        return current_worker._get_client(timeout=timeout)

    # Imported here because distributed.client is a circular import for this module
    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and (not address or client.scheduler.address == address):
        return client
    if not address:
        raise ValueError("No global client found and no address provided")
    return Client(address, timeout=timeout)
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    After seceding, a ``SecedeEvent`` carrying the task key and the time spent
    computing so far is scheduled on the worker's event loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # The state machine must be driven from the event loop thread
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: a ``get_data`` request
    is sent and, if the response status is ``"OK"``, an ``"OK"`` is written back
    on the same comm.

    Parameters
    ----------
    rpc
        Connection pool used to open the comm; the comm is handed back to it
        with ``rpc.reuse`` when done, whether or not the exchange succeeded.
    keys
        Keys requested from the remote worker.
    worker
        Address of the worker to fetch from.
    who
        Forwarded verbatim as the ``who`` field of the request.
    serializers, deserializers
        Default to ``rpc.serializers`` / ``rpc.deserializers`` when None.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` key.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            delivered = response["status"] == "OK"
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if delivered:
            # Second step of the handshake: acknowledge receipt
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions (at most 100 entries), filled and read by
# ``dumps_function``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Lock held by ``dumps_function`` around every access to ``cache_dumps``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Dump a function to bytes, cache functions

    The pickled form is looked up in ``cache_dumps`` first. On a miss the
    function is pickled and, when the result is smaller than 100000 bytes,
    stored in the cache. Unhashable functions are pickled without caching.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        cacheable = True
    except TypeError:  # Unhashable function
        cacheable = False

    payload = pickle.dumps(func)
    # Keep large payloads out of the cache
    if cacheable and len(payload) < 100000:
        with _cache_lock:
            cache_dumps[func] = payload
    return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    Registers the current thread in ``active_threads`` for the duration of the
    call, installs the thread state and the worker context variable, and
    delegates the actual execution to ``_run_task_simple``.

    Parameters
    ----------
    task
        The task to execute.
    data
        Passed through to ``_run_task_simple``.
    execution_state
        Stored in the thread state; its ``"worker"`` entry is published through
        ``_worker_cvar`` while the task runs.
    key
        Key of the task; recorded in ``active_threads`` and the thread state.
    active_threads
        Mapping of thread ident to key, updated under ``active_threads_lock``.
    active_threads_lock
        Lock protecting ``active_threads``.
    time_delay
        Passed through to ``_run_task_simple``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread, even if something escaped from
        # _run_task_simple; otherwise a stale entry would be left behind in
        # active_threads for this thread ident.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run specification."""

    # Either the spec itself or the spec wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Set up the state and the buffers of a shuffle run.

    Parameters
    ----------
    id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler, loop
        Stored on the instance unchanged.
    directory
        Directory handed to the ``DiskShardsBuffer`` (only used if ``disk``).
    memory_limiter_disk
        Limiter handed to the ``DiskShardsBuffer`` (only used if ``disk``).
    memory_limiter_comms
        Limiter handed to the ``CommShardsBuffer``.
    disk
        If true, received shards are buffered in a ``DiskShardsBuffer``;
        otherwise in a ``MemoryShardsBuffer``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if not disk:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
            else:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )

        with self._capture_metrics("background-comms"):
            message_bytes_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=comm_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by send()
    retry_prefix = "distributed.p2p.comm.retry"
    self.RETRY_COUNT = dask.config.get(f"{retry_prefix}.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get(f"{retry_prefix}.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get(f"{retry_prefix}.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return a debug representation listing the run's identifying fields."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    fields = ", ".join(f"{name}={getattr(self, name)!r}" for name in shown)
    return f"<{self.__class__.__name__}: {fields}>"
def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"
    offloadable = "background" in where

    def publish(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and drop a leading "p2p-" from its
        # first element, since "p2p" is already the first item of the key.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(publish, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Notify the scheduler that the barrier of this shuffle run was reached.

    The scheduler is told whether every id in ``run_ids`` equals this run's
    ``run_id``. Returns this run's ``run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    mismatch = any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the worker at ``address``.

    Raises whatever ``raise_if_closed`` raises if this run is already closed.
    """
    self.raise_if_closed()
    remote = self.rpc(address)
    payload = to_serialize(shards)
    return await remote.shuffle_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the worker at ``address``, retrying on failure.

    Retries are governed by ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
    ``RETRY_DELAY_MAX``.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small = _mean_shard_size(shards) < 65536
    payload: list | bytes = pickle.dumps(shards) if small else shards

    def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry
        return self._send(address, payload)

    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot of this run.

    The result has the keys ``"disk"`` and ``"comm"`` (the heartbeats of the
    respective buffers, the latter extended with ``"read"`` =
    ``total_recvd``) and ``"start"`` (``start_time``).
    """
    comm_status = self._comm_buffer.heartbeat()
    comm_status["read"] = self.total_recvd
    disk_status = self._disk_buffer.heartbeat()
    return {"disk": disk_status, "comm": comm_status, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer, unless this run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, unless this run is closed.

    Each index tuple is turned into a string key by joining its elements
    with underscores.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        The stored ``_exception``, if one was recorded.
    ShuffleClosedError
        If the run is closed and no exception was recorded.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as done and flush the comm buffer.

    Any exception reported by the comm buffer is stored in ``_exception``
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so raise_if_closed() can surface it later
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer, unless this run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer, unless this run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    If the run is already marked closed, only wait for the close that is in
    progress (or done) to finish.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on this run; ignored if the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the shards stored under index ``id`` from the disk buffer.

    The index is turned into a string key by joining its elements with
    underscores, mirroring ``_write_to_disk``.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Handle shards received from another worker.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is not
    raised but returned to the caller as an error message.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): pickle.loads on data received over the network;
            # presumably only trusted peers can reach this - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:42023', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:36677', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:42913', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx'], how = 'right'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage (its users are not visible in this excerpt).
_tls = threading.local()
# Debug configuration. configure() validates new values against the types of
# these defaults and stores them here; the wrapper built by _starts_debugging
# reads "qt" and "subProcess" from it.
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to False unless the caller passes it explicitly.
    Exceptions raised by ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    if "notify_stdin" not in kwargs:
        kwargs["notify_stdin"] = False
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensure_logging.ensured`` function attribute records whether this has
    already happened, so that repeated calls are no-ops.
    """
    if not ensure_logging.ensured:
        # Flip the flag first, before any logging call is made
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        pydevd_log_dir = log.log_dir
        if pydevd_log_dir is not None:
            pydevd.log_to(pydevd_log_dir + "/debugpy.pydevd.log")


# Logging has not been started yet at import time
ensure_logging.ensured = False
def log_to(path):
    """Select where logs go; must be called before logging has begun.

    Passing ``sys.stderr`` enables all log levels on stderr; any other value
    is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
    else:
        log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and store debug configuration properties in ``_config``.

    Parameters
    ----------
    properties
        Optional mapping (or iterable of pairs) of property names to values.
    **kwargs
        Further properties; they take precedence over ``properties``.

    Raises
    ------
    ValueError
        If a property is unknown, its value's type differs from the type of
        the default in ``_config``, or the value is not one of the values
        listed for it in ``_config_valid_values``.

    Properties are applied one at a time, so those processed before an invalid
    one remain stored.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    requested = kwargs if properties is None else dict(properties, **kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match is required; subclasses are rejected as well
        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, expected_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorator for the functions that start a debug session.

    The returned ``debug(address, **kwargs)`` wrapper normalises ``address``
    (a bare port becomes ``("127.0.0.1", port)``), validates the port, makes
    sure logging is started, enables Qt support if configured, builds the
    ``settrace_kwargs`` and finally calls
    ``func(address, settrace_kwargs, **kwargs)``. Exceptions raised by ``func``
    are passed to ``log.reraise_exception`` at level "info".
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a (host, port) pair: treat the argument as the port itself
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )

        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_path,),
                dont_trace_end_patterns=(str("debugpy_launcher.py"),),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on1-outer] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger; used by fail_hard and _force_close below.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" setting, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes by name (presumably instantiated for every Worker - verify
# against Worker.__init__, which is not visible here).
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Registries of callables taking a Worker; both start out empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Set of the statuses running, paused and closing_gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message type for a task run that finished; ``op`` is ``"task-finished"``.

    It is one of the two return types of ``_run_task``.
    """

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message type for a task run that erred; ``op`` is ``"task-erred"``.

    It is one of the two return types of ``_run_task``.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    # The exception object itself, in addition to the ErrorMessage fields
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """``get_data`` response whose ``status`` is ``"busy"``; carries no data."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """``get_data`` response whose ``status`` is ``"OK"``, with the data by key."""

    status: Literal["OK"]
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both coroutine functions and plain methods. When the wrapped
    method raises an ``Exception``, the error is logged (unless the worker is
    already closed or closing), ``_force_close`` is invoked - awaited directly
    for coroutines, scheduled on the worker's loop otherwise - and the
    exception is re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    def report(worker, exc: Exception) -> None:
        # Runs inside the except block, so logger.exception sees the traceback
        if worker.status not in (Status.closed, Status.closing):
            worker.log_event("worker-fail-hard", error_message(exc))
            logger.exception(exc)

    if not iscoroutinefunction(method):

        @wraps(method)
        def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                report(self, e)
                self.loop.add_callback(_force_close, self, reason)
                raise

        return sync_wrapper  # type: ignore

    @wraps(method)
    async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return await method(self, *args, **kwargs)  # type: ignore
        except Exception as e:
            report(self, e)
            await _force_close(self, reason)
            raise

    return async_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    The close is given 30 (seconds, going by ``wait_for``) to complete. If it
    fails and no ``Scheduler`` instance lives in this process, the process is
    terminated with ``os._exit(1)``; otherwise the exception is re-raised.
    """
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if not Scheduler._instances:
            logger.critical(
                "Error trying close worker in response to broken internal state. "
                "Forcibly exiting worker NOW",
                exc_info=True,
            )
            # use `os._exit` instead of `sys.exit` because of uncertainty
            # around propagating `SystemExit` from asyncio callbacks
            os._exit(1)

        # We're likely in a unit test. Don't kill the whole test suite!
        raise
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…l_in_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Worker running the current task. ``_run_task`` sets it around task execution
# and resets it afterwards; ``get_worker`` reads it.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Get the worker currently running this task

    Raises
    ------
    ValueError
        If no worker is set in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError; callers only expect ValueError
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    This client connects to the same scheduler to which the worker is connected

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a current client matches and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker running this task
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        if not address or worker.scheduler.address == address:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    # Second choice: the current client, if it targets the requested scheduler
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None

    if current and (not address or current.scheduler.address == address):
        return current
    if address:
        # Last resort: open a fresh connection to the requested scheduler
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    # Time spent computing so far, reported to the worker with the event
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event over to the worker on its own event loop
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake.

    ``serializers`` and ``deserializers`` default to the ones configured on
    ``rpc``. The connection is handed back to ``rpc`` for reuse whether or not
    the exchange succeeds.

    Raises
    ------
    ValueError
        If the response carries no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second step of the handshake: acknowledge receipt
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Pickled payloads of functions, filled and read by ``dumps_function``
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Held by ``dumps_function`` around every read and write of ``cache_dumps``
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Dump a function to bytes, cache functions

    Payloads shorter than 100000 bytes are stored in ``cache_dumps``. Functions
    that cannot be used as a cache key (``TypeError`` on lookup) are pickled on
    every call and never cached.
    """
    try:
        with _cache_lock:
            payload = cache_dumps[func]
    except TypeError:  # Unhashable function
        payload = pickle.dumps(func)
    except KeyError:
        # Cache miss: serialize now and remember the payload if it is small
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
    return payload
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    The calling thread is registered in ``active_threads`` under ``key`` while
    the task runs, and ``execution_state["worker"]`` is exposed through
    ``_worker_cvar`` for the duration of the call.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    thread_ctx = set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    )
    with thread_ctx:
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            outcome = _run_task_simple(task, data, time_delay)
        finally:
            # Always restore the previous context value
            _worker_cvar.reset(cvar_token)

    with active_threads_lock:
        del active_threads[thread_id]
    return outcome
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers; a plain ``str`` at runtime
ShuffleId = NewType("ShuffleId", str)
# Index expressed as a tuple of integers of arbitrary length
NDIndex: TypeAlias = tuple[int, ...]

# The two type parameters of ``ShuffleRun``
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Return type of the callable given to ``ShuffleRun.offload``
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` entry.

    The entry holds a ``ShuffleRunSpec`` either directly or wrapped in
    ``ToPickle``.
    """

    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` when true, otherwise a ``MemoryShardsBuffer``. Retry settings
    are read from the ``distributed.p2p.comm.retry`` configuration.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=parse_bytes(
                    dask.config.get("distributed.p2p.comm.message-bytes-limit")
                ),
                memory_limiter=memory_limiter_comms,
                concurrency_limit=dask.config.get("distributed.p2p.comm.concurrency"),
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Progress bookkeeping
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()

    # Failure and shutdown bookkeeping
    self._exception = None
    self._closed_event = asyncio.Event()

    self._loop = loop

    # Retry policy used by ``send``
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def record(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and drop a leading "p2p-" marker,
        # since "p2p" already is the first element of the digest name
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(record, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` match this run.

    Returns this run's ``run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = all(other == own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=consistent,
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to the peer at ``address`` via ``shuffle_receive``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured policy.

    When the mean shard size is below 65536 the shards are pickled into a
    single blob before sending; otherwise they are sent as they are.
    """
    payload: list | bytes = shards
    if _mean_shard_size(shards) < 65536:
        # Don't send buffers individually over the tcp comms.
        # Instead, merge everything into an opaque bytes blob, send it all at once,
        # and unpickle it on the other side.
        # Performance tests informing the size threshold:
        # https://github.com/dask/distributed/pull/8318
        payload = pickle.dumps(shards)

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry attempt
        return self._send(address, payload)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The time spent is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        pending = run_in_executor_with_context(self.executor, func, *args, **kwargs)
        return await pending
def heartbeat(self) -> dict[str, Any]:
    """Collect a status snapshot of this run.

    Returns
    -------
    dict
        ``"disk"``: the disk buffer's heartbeat; ``"comm"``: the comm buffer's
        heartbeat with ``"read"`` set to ``self.total_recvd``; ``"start"``:
        ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd

    snapshot: dict[str, Any] = {}
    snapshot["disk"] = self._disk_buffer.heartbeat()
    snapshot["comm"] = comm_stats
    snapshot["start"] = self.start_time
    return snapshot
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    The stored ``self._exception`` when it is truthy, else a
    ``ShuffleClosedError`` naming this object.
    """
    if not self.closed:
        return
    # Prefer the recorded failure over the generic "closed" error.
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush outgoing data.

    Sets ``self.transferred`` to ``True``, awaits ``self._flush_comm()`` and
    then asks the comm buffer to raise any exception it holds. Such an
    exception is stored on ``self._exception`` before it propagates.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()

    comm_buffer = self._comm_buffer
    try:
        comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Remember the failure so raise_if_closed() can surface it later.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is not closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and mark this run as closed.

    The first caller sets ``self.closed``, closes the comm buffer, then the
    disk buffer, and finally sets ``self._closed_event``. Any later caller
    only waits for that event and returns.

    The disk buffer is closed and the event is set even when closing a buffer
    raises. Without that, the disk buffer would be left open and a concurrent
    ``close()`` waiting on the event would never wake up. The exception
    still propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        await self._comm_buffer.close()
    finally:
        try:
            await self._disk_buffer.close()
        finally:
            # Always release waiters, whatever happened above.
            self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on this run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``self._receive``.

    Returns
    -------
    ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
    ``P2PConsistencyError`` is raised while handling the data. Other
    exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): unpickling is only safe for trusted input; this
            # assumes the blob comes from a peer's send() - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:34987', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:39521', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:45961', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx'], how = 'outer'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage created at import time (no use of it is visible in this
# excerpt).
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration. configure() only accepts keys that already exist
# here, and requires each new value to have exactly the type of the default below.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Forward to ``pydevd.settrace`` with debugpy's defaults applied.

    The call is logged at debug level first. ``notify_stdin`` is set to
    ``False`` unless the caller passed it explicitly. Any exception raised by
    ``pydevd.settrace`` propagates unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    # No try/except here: a handler that only re-raises adds nothing.
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    The ``ensured`` attribute on this function records whether setup has run,
    so every call after the first is a no-op.
    """
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        # pydevd gets its own log file, but only when a log directory is set.
        if log.log_dir is not None:
            pydevd_log_path = log.log_dir + "/debugpy.pydevd.log"
            pydevd.log_to(pydevd_log_path)


ensure_logging.ensured = False
def log_to(path):
    """Choose where debugpy logs go; must be called before logging starts.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on the stderr
    log; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
    else:
        log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties.

    Parameters
    ----------
    properties
        Optional source of property values, copied with ``dict()``; keyword
        arguments are merged on top of it.
    **kwargs
        Individual properties. Used directly when ``properties`` is ``None``.

    Raises
    ------
    ValueError
        For a name not present in ``_config``, a value whose type is not
        exactly the type of the current value, or a value outside the list in
        ``_config_valid_values``. Properties processed before the offending
        one have already been applied.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match on purpose: subclasses are rejected as well.
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Wrap a debug entry point with address validation and common setup.

    The returned ``debug(address, **kwargs)`` normalises ``address``, makes
    sure logging is set up, enables Qt support when configured, builds the
    keyword arguments for settrace and then calls
    ``func(address, settrace_kwargs, **kwargs)``.
    """

    def debug(address, **kwargs):
        # Accept either (host, port) or a bare port; a bare port gets the
        # host 127.0.0.1.
        try:
            _, port = address
        except Exception:
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        port_in_range = 0 <= port < 2**16
        if not port_in_range:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        trace_options = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )

        if hide_debugpy_internals():
            debugpy_dir = os.path.dirname(absolute_path(debugpy.__file__))
            trace_options.update(
                dont_trace_start_patterns=(debugpy_dir,),
                dont_trace_end_patterns=(str("debugpy_launcher.py"),),
            )

        try:
            return func(address, trace_options, **kwargs)
        except Exception:
            # Logged at "info" level only, then handed to the log helper.
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on2-inner] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
# Names needed only for type annotations, guarded by TYPE_CHECKING.
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
# Type variables used in the signature of the fail_hard decorator.
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension name -> extension class.
# NOTE(review): presumably the extensions a Worker loads by default - confirm.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Defaults for the `metrics` and `startup_information` parameters of
# Worker.__init__, which copies them with dict() before storing them.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Set of Status values grouped under the name "any running": running, paused and
# closing_gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message schema for ``op == "task-finished"``, extending ``OKMessage``."""

    op: Literal["task-finished"]
    # Payload and its reported size/type.
    result: object
    nbytes: int
    type: type
    # NOTE(review): presumably timestamps and thread ID of the run - confirm.
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message schema for ``op == "task-erred"``, extending ``ErrorMessage``."""

    op: Literal["task-erred"]
    # Same fields as RunTaskSuccess ...
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    # ... plus the exception object itself.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed dict whose only key, ``status``, is the literal ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed dict with ``status == "OK"`` and a ``data`` mapping keyed by ``Key``."""

    status: Literal["OK"]
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to close the worker if this method encounters an exception.

The returned wrapper calls ``method``. When an ``Exception`` escapes it, the
wrapper reports the error through ``self.log_event("worker-fail-hard", ...)``
and ``logger.exception``, hands the worker to ``_force_close`` with a reason
derived from the method name, and re-raises. Coroutine functions get an
``async`` wrapper that awaits ``_force_close``; other callables get a
synchronous wrapper that schedules it via ``self.loop.add_callback``.
"""
# Reason string passed to _force_close: "worker-<method name>-fail-hard".
reason = f"worker-{method.__name__}-fail-hard"
if iscoroutinefunction(method):
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
try:
return await method(self, *args, **kwargs) # type: ignore
except Exception as e:
# Reporting is skipped when the worker is already closed or closing.
# NOTE(review): this listing has lost its indentation, so which of the
# statements below sit under this `if` cannot be read off here - confirm
# against the source file.
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
await _force_close(self, reason)
raise
else:
@wraps(method)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
try:
return method(self, *args, **kwargs)
except Exception as e:
# Same handling as the async wrapper; the same nesting caveat applies.
if self.status not in (Status.closed, Status.closing):
self.log_event("worker-fail-hard", error_message(e))
logger.exception(e)
# Cannot await here, so the close is scheduled on the worker's loop.
self.loop.add_callback(_force_close, self, reason)
raise
return wrapper # type: ignore
async def _force_close(self, reason: str):
"""
Used with the fail_hard decorator defined above

1. Wait for a worker to close
2. If it doesn't, log and kill the process

Closing is attempted with ``self.close(nanny=False, executor_wait=False,
reason=reason)`` wrapped in ``wait_for(..., 30)``. ``KeyboardInterrupt`` and
``SystemExit`` are re-raised. Any other failure is re-raised when
``Scheduler._instances`` is non-empty; otherwise it is logged as critical and
the process exits with status 1 via ``os._exit``.
"""
try:
# NOTE(review): 30 is presumably a timeout in seconds - confirm against wait_for.
await wait_for(
self.close(nanny=False, executor_wait=False, reason=reason),
30,
)
except (KeyboardInterrupt, SystemExit): # pragma: nocover
raise
except BaseException: # pragma: nocover
# Worker is in a very broken state if closing fails. We need to shut down
# immediately, to ensure things don't get even worse and this worker potentially
# deadlocks the cluster.
# Imported here rather than at module level.
from distributed import Scheduler
if Scheduler._instances:
# We're likely in a unit test. Don't kill the whole test suite!
raise
logger.critical(
"Error trying close worker in response to broken internal state. "
"Forcibly exiting worker NOW",
exc_info=True,
)
# use `os._exit` instead of `sys.exit` because of uncertainty
# around propagating `SystemExit` from asyncio callbacks
os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task.

    The worker is read from the ``_worker_cvar`` context variable.

    Raises
    ------
    ValueError
        If the context variable holds no worker.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address
    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the LookupError; callers only ever see the ValueError.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Obtain a client from inside a task.

    The lookup order is: the client of the worker running this task, then the
    current global client, and finally a brand new ``Client`` connected to
    ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If there is no suitable worker or global client and no ``address``
        was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    try:
        worker = get_worker()
    except ValueError:
        # Not running on a worker: fall back to the global client below.
        pass
    else:
        same_scheduler = not address or worker.scheduler.address == address
        if same_scheduler:
            return worker._get_client(timeout=timeout)

    from distributed.client import Client

    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None

    if current and (not address or current.scheduler.address == address):
        return current
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    After leaving the pool, a ``SecedeEvent`` carrying the task key and the
    time elapsed since the task started is scheduled on the worker's event
    loop.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # The event is handled on the worker's loop, not on this task thread.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``.

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: it sends a ``get_data``
    request and, when the response status is ``"OK"``, writes ``"OK"`` back on
    the same comm.

    ``serializers`` and ``deserializers`` default to the ones configured on
    ``rpc``. The comm is always handed back to ``rpc`` for reuse.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions, guarded by ``_cache_lock``.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when one exists.

    Payloads shorter than 100000 bytes are stored in ``cache_dumps``.
    Unhashable functions cannot be used as cache keys and are pickled on
    every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: pickle now and remember the payload if it is small.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the calling thread is registered in
    ``active_threads`` (thread ident -> ``key``), the thread state is set, and
    ``_worker_cvar`` points at ``execution_state["worker"]``. The actual
    execution is delegated to ``_run_task_simple``.

    The ``active_threads`` entry is removed in a ``finally`` clause, so the
    thread is not left registered against ``key`` if ``_run_task_simple``
    raises.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister, even if the task runner itself raised.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run specification."""

    # The spec itself, or the spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """State of one shuffle run, generic over partition id and partition type.

    NOTE(review): judging by ``local_address`` and the methods of this class,
    one instance appears to live on each participating worker -- confirm.
    """

    id: ShuffleId
    run_id: int
    span_id: str | None
    local_address: str
    executor: ThreadPoolExecutor
    # Called with a peer address to obtain an RPC handle to that peer.
    rpc: Callable[[str], PooledRPCCall]
    # Receives (metric name, value) pairs from ``_capture_metrics``.
    digest_metric: Callable[[Hashable, float], None]
    scheduler: PooledRPCCall
    # Set to True at the start of ``close``.
    closed: bool
    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    received: set[_T_partition_id]
    # Reported as ``comm["read"]`` by ``heartbeat``.
    total_recvd: int
    # ``time.time()`` at construction.
    start_time: float
    # Raised by ``raise_if_closed`` once the run is closed.
    _exception: Exception | None
    # Set when ``close`` has finished closing both buffers.
    _closed_event: asyncio.Event
    _loop: IOLoop
    # Retry settings for ``send``, read from ``distributed.p2p.comm.retry.*``.
    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's collaborators and create its disk and comm buffers.

    When ``disk`` is true, received shards go to a ``DiskShardsBuffer`` rooted
    at ``directory``; otherwise a ``MemoryShardsBuffer`` is used. Comm limits
    and retry settings are read from the ``distributed.p2p.comm.*``
    configuration.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_bytes_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=comm_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by ``send``.
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    min_delay = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(min_delay, default="s")
    max_delay = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(max_delay, default="s")
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while active, as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` on the first label element is stripped.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    ``consistent`` is sent as True only if every entry of ``run_ids`` equals
    this run's ``run_id``.
    """
    self.raise_if_closed()

    consistent = True
    for run_id in run_ids:
        if not (run_id == self.run_id):
            consistent = False
            break

    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the peer at ``address``.

    Raises whatever ``raise_if_closed`` raises when the run is closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the peer at ``address``, retrying via ``retry``.

    When the mean shard size is below 65536, all shards are pickled into a
    single bytes blob before sending.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small = _mean_shard_size(shards) < 65536
    payload: list | bytes = pickle.dumps(shards) if small else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, payload)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is timed under the ``"offload"`` context meter.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return the disk and comm buffer heartbeats plus the run's start time.

    The comm heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_state = self._comm_buffer.heartbeat()
    comm_state["read"] = self.total_recvd
    disk_state = self._disk_buffer.heartbeat()
    return {"disk": disk_state, "comm": comm_state, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer.

    Each index tuple is turned into a string key by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if the run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        The stored ``_exception``, when one is set.
    ShuffleClosedError
        When the run is closed and no exception is stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception surfaced by the comm buffer after the flush is stored in
    ``_exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so later ``raise_if_closed`` calls can raise it.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if the run is closed."""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer, then the disk buffer, exactly once.

    A call made after closing has begun only waits for ``_closed_event``.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored under ``id``.

    The buffer key is the elements of ``id`` joined with ``"_"``.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand received shards to ``_receive`` and report the outcome.

    Returns ``{"status": "OK"}`` on success, or an error message built from
    the exception when a ``P2PConsistencyError`` is raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm --
            # confirm that only trusted peers can reach this handler.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as exc:
        return error_message(exc)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35747', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:36017', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:40183', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx', 'k'], how = 'inner'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage object.
_tls = threading.local()

# Current debug configuration. configure() validates new values against the
# types of these defaults before storing them here.
# TODO: "gevent", if possible.
_config = {
    "qt": "none",
    "subProcess": True,
    "python": sys.executable,
    "pythonEnv": {},
}

_config_valid_values = {
    # If property is not listed here, any value is considered valid, so long as
    # its type matches that of the default value in _config.
    "qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}

# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Call ``pydevd.settrace`` with ``notify_stdin`` defaulting to False.

    The call is logged at debug level with the arguments as passed in, before
    the default is applied. Exceptions raised by ``pydevd.settrace`` propagate
    unchanged; the former ``try``/``except Exception: raise`` wrapper was a
    no-op and has been removed.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Start logging to ``log.log_dir`` on the first call; later calls do nothing."""
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        log_dir = log.log_dir
        if log_dir is not None:
            # Send pydevd's own log to the same directory.
            pydevd.log_to(log_dir + "/debugpy.pydevd.log")


# Flag stored on the function itself: True once logging has been started.
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every log level on the stderr log; any
    other value is stored as the log directory.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and store debug configuration properties in ``_config``.

    ``properties`` (if given) is copied and overlaid with ``kwargs``. Each
    property must already exist in ``_config``, have exactly the type of its
    current value, and, when listed in ``_config_valid_values``, be one of
    the listed values. Properties are applied one at a time, so those
    preceding an invalid one remain set.

    Raises
    ------
    ValueError
        For an unknown property, a value of the wrong type, or a value that
        is not allowed.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = {**dict(properties), **kwargs}

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate ``func`` with address validation and settrace setup.

    The returned ``debug(address, **kwargs)`` accepts either a port or a
    ``(host, port)`` pair (a bare port is paired with host ``127.0.0.1``),
    validates the port, makes sure logging is started, enables Qt support
    when configured, and then calls
    ``func(address, settrace_kwargs, **kwargs)``. Exceptions raised by
    ``func`` are passed to ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a pair: treat the argument as a bare port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if port < 0 or port >= 2**16:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }
        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs.update(
                dont_trace_start_patterns=(debugpy_path,),
                dont_trace_end_patterns=(str("debugpy_launcher.py"),),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on2-left] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)

# Value of the "distributed.admin.pdb-on-err" config option, read once at import.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

# Extension classes keyed by name.
# NOTE(review): presumably these are attached to every Worker by default -- confirm.
DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}

# Named callables that take a Worker; both registries start out empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

# Statuses grouped as "any running": running, paused, and closing gracefully.
WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """``OKMessage`` whose ``op`` is ``"task-finished"``.

    One of the two message types that ``_run_task`` is annotated to return.
    """

    op: Literal["task-finished"]
    result: object
    # NOTE(review): presumably the size in bytes and the type of ``result`` -- confirm.
    nbytes: int
    type: type
    # NOTE(review): presumably the execution start/stop timestamps -- confirm.
    start: float
    stop: float
    # NOTE(review): presumably the ident of the executing thread -- confirm.
    thread: int
class RunTaskFailure(ErrorMessage):
    """``ErrorMessage`` whose ``op`` is ``"task-erred"``.

    One of the two message types that ``_run_task`` is annotated to return.
    Declares the same fields as ``RunTaskSuccess`` plus ``actual_exception``.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    # NOTE(review): presumably the exception object raised by the task -- confirm.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Response shape with status ``"busy"``; a possible result of ``get_data_from_worker``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Response shape with status ``"OK"``; a possible result of ``get_data_from_worker``."""

    status: Literal["OK"]
    # Mapping of key to the object returned for it.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both coroutine functions and plain functions. Unless the worker
    is already closed or closing, the exception is recorded as a
    ``"worker-fail-hard"`` event and logged, and ``_force_close`` is awaited
    (coroutine methods) or scheduled on the worker's loop (plain methods).
    The original exception is always re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    def _record_failure(worker, exc) -> bool:
        # Returns True when the worker still needs to be force-closed.
        if worker.status in (Status.closed, Status.closing):
            return False
        worker.log_event("worker-fail-hard", error_message(exc))
        logger.exception(exc)
        return True

    if not iscoroutinefunction(method):

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if _record_failure(self, exc):
                    self.loop.add_callback(_force_close, self, reason)
                raise

    else:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if _record_failure(self, exc):
                    await _force_close(self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process

    ``self`` is the worker to close and ``reason`` is forwarded to its
    ``close`` method. Closing is given 30 (passed to ``wait_for``) before it
    counts as failed. ``KeyboardInterrupt`` and ``SystemExit`` are re-raised
    untouched. Any other failure is re-raised when a ``Scheduler`` instance
    exists in this process; otherwise the process exits via ``os._exit(1)``.
    """
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            30,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of threads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's :attr:`local_directory` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…_connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the worker that is running the current task;
# set and reset around task execution in ``_run_task``.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
    """Return the worker that is running the current task.

    Raises
    ------
    ValueError
        If no worker has been registered in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the internal LookupError from the user-facing traceback
        raise ValueError("No worker found") from None
    else:
        return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Get a client while within a task.

    Candidates are tried in order: the client of the worker running the
    current task, then the current global client, then a brand new client
    connected to ``address``. A candidate is accepted when no ``address`` was
    requested or when its scheduler address equals the requested one.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If there is neither a usable worker/global client nor an ``address``.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    def talks_to_requested_scheduler(candidate) -> bool:
        return not address or candidate.scheduler.address == address

    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        worker = None
    if worker is not None and talks_to_requested_scheduler(worker):
        return worker._get_client(timeout=timeout)

    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and talks_to_requested_scheduler(client):
        return client
    if not address:
        raise ValueError("No global client found and no address provided")
    return Client(address, timeout=timeout)
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    The worker is informed through a ``SecedeEvent`` carrying the task key and
    the time the task has been running so far.

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # The state machine must be driven from the event loop, not from this thread
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: it sends a ``get_data``
    request and, when the response status is ``"OK"``, writes ``"OK"`` back on
    the same comm. The comm is handed back to ``rpc`` for reuse in every case.

    ``serializers`` and ``deserializers`` default to those of ``rpc``.

    Raises
    ------
    ValueError
        If the response carries no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            delivered = reply["status"] == "OK"
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if delivered:
            # second half of the handshake
            await comm.write("OK")
        return reply
    finally:
        rpc.reuse(worker, comm)
# Bounded cache of already-pickled functions, guarded by ``_cache_lock``
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
    """Pickle a function, reusing a cached result when one exists.

    Freshly pickled functions are stored in ``cache_dumps`` only when the
    payload is smaller than 100000 bytes. Functions that cannot be used as a
    cache key (unhashable) are pickled on every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    For the duration of the call the current thread is registered in
    ``active_threads`` under ``key``, the thread state is populated, and
    ``execution_state["worker"]`` is published through ``_worker_cvar``. The
    context variable and the thread registration are undone even when running
    the task raises.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    thread_id = threading.get_ident()
    with active_threads_lock:
        active_threads[thread_id] = key

    thread_state_ctx = set_thread_state(
        start_time=time(),
        execution_state=execution_state,
        key=key,
    )
    with thread_state_ctx:
        cvar_token = _worker_cvar.set(execution_state["worker"])
        try:
            return _run_task_simple(task, data, time_delay)
        finally:
            _worker_cvar.reset(cvar_token)
            with active_threads_lock:
                del active_threads[thread_id]
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers, based on ``str``
ShuffleId = NewType("ShuffleId", str)
# Tuple of integers of arbitrary length; used as the key type in
# ``ShuffleRun._write_to_disk`` and ``ShuffleRun._read_from_disk``
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ``ShuffleRun`` class
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Return type of the callable passed to ``ShuffleRun.offload``
_T = TypeVar("_T")
# Message type extending ``OKMessage`` with a ``run_spec`` field, which holds
# the run specification either directly or wrapped in ``ToPickle``.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
# State and helpers for a single shuffle run, identified by ``id`` and ``run_id``.
# Generic over the partition id type and the partition type.
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
# Identity of the shuffle and of this particular run
id: ShuffleId
run_id: int
# Used as the second element of the metric names built in ``_capture_metrics``
span_id: str | None
local_address: str
# Executor used by ``offload``
executor: ThreadPoolExecutor
# Maps a peer address to an RPC handle; used by ``_send``
rpc: Callable[[str], PooledRPCCall]
# Sink for the metrics captured by ``_capture_metrics``
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
# False after ``__init__``; set to True by ``close``
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
# Reported as the comm "read" figure in ``heartbeat``
total_recvd: int
# ``time.time()`` at construction
start_time: float
# Exception recorded by ``fail`` or ``inputs_done``; re-raised by ``raise_if_closed``
_exception: Exception | None
# Set at the end of ``close``; concurrent ``close`` calls wait on it
_closed_event: asyncio.Event
_loop: IOLoop
# Retry settings for ``send``, read from ``distributed.p2p.comm.retry.*``
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` when true, a ``MemoryShardsBuffer`` otherwise. The comm
    buffer limits and the retry settings are read from the
    ``distributed.p2p.comm.*`` configuration keys.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks.
    # Metrics issued by the background tasks must not be logged onto the dask
    # task that spawned this object, hence the callbacks are cleared first.
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_bytes_cap = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            max_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_cap,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=max_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    retry_cfg = "distributed.p2p.comm.retry"
    self.RETRY_COUNT = dask.config.get(f"{retry_cfg}.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get(f"{retry_cfg}.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get(f"{retry_cfg}.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` on the first label element is stripped before the
    name is built.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single attempt at delivering ``shards`` to the peer at ``address``.

    The payload is wrapped with ``to_serialize`` and passed to the peer's
    ``shuffle_receive`` handler together with this run's ``id`` and ``run_id``.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    payload = to_serialize(shards)
    return await peer.shuffle_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to the peer at ``address``, retrying on failure.

    When the mean shard size is below 65536 the shards are pickled into one
    bytes blob first; otherwise they are sent as they are. Retries follow
    ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
    """
    # Small shards: don't send buffers individually over the tcp comms.
    # Merge everything into an opaque bytes blob, send it all at once, and
    # unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    shards_or_bytes: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    # Each retry needs a fresh coroutine, so pass a factory rather than a coroutine
    attempt = partial(self._send, address, shards_or_bytes)
    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The time spent is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers.

    Returns
    -------
    dict
        ``"disk"``: the disk buffer's heartbeat; ``"comm"``: the comm
        buffer's heartbeat extended with a ``"read"`` entry holding
        ``self.total_recvd``; ``"start"``: ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        The stored ``self._exception``, when one is set.
    ShuffleClosedError
        When the run is closed and no exception was stored.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        # Surface the original failure rather than a generic "closed" error.
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush the comm buffer.

    Sets ``self.transferred`` and awaits ``self._flush_comm()``. If the comm
    buffer then reports an exception, it is stored on ``self._exception``
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Remember the failure so later ``raise_if_closed`` calls can surface it.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer, unless the run is already closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers once; later callers wait for that close to finish.

    The first caller sets ``self.closed``, closes the comm buffer and then
    the disk buffer, and finally sets ``self._closed_event``. Any call made
    after ``self.closed`` is set only waits on that event.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
        return
    await self._closed_event.wait()  # pragma: no cover
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run, unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``self._receive``.

    Returns
    -------
    dict
        ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
        ``P2PConsistencyError`` is raised while handling the data.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm;
            # presumably peers are trusted workers - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36907', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:39227', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:37361', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx', 'k'], how = 'left'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage for this module (its users are not visible in this excerpt).
_tls = threading.local()
# TODO: "gevent", if possible.
# Current debug configuration. configure() overwrites entries after checking
# that each new value has the same type as the value stored here.
_config = {
    "qt": "none",
    "subProcess": True,
    "python": sys.executable,
    "pythonEnv": {},
}
# Allowed values per property, consulted by configure().
_config_valid_values = {
    # If property is not listed here, any value is considered valid, so long as
    # its type matches that of the default value in _config.
    "qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller passed it
    explicitly. Any exception raised by ``pydevd.settrace`` propagates to
    the caller unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    # No try/except here: a handler that only re-raises adds nothing.
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done.

    Only the first call does any work; ``ensure_logging.ensured`` records
    that it has run.
    """
    if not ensure_logging.ensured:
        # Flip the flag first, before any of the logging calls below run.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        log_dir = log.log_dir
        if log_dir is not None:
            pydevd.log_to(log_dir + "/debugpy.pydevd.log")


# Function attribute used as the "already initialised" flag.
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on
    ``log.stderr``. Any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate the given properties and store them in ``_config``.

    ``properties`` (anything ``dict()`` accepts) and ``kwargs`` are merged,
    with ``kwargs`` taking precedence. Properties are applied one by one, so
    entries that come before an invalid one are already stored when the
    error is raised.

    Raises
    ------
    ValueError
        For an unknown property, for a value whose exact type differs from
        that of the current ``_config`` value, or for a value not listed in
        ``_config_valid_values`` for that property.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        merged = kwargs
    else:
        merged = {**dict(properties), **kwargs}

    for name, value in merged.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        # Exact type match: e.g. a bool is not accepted where a str is stored.
        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
def debug(address, **kwargs):
try:
_, port = address
except Exception:
port = address
address = ("127.0.0.1", port)
try:
port.__index__() # ensure it's int-like
except Exception:
raise ValueError("expected port or (host, port)")
if not (0 <= port < 2**16):
raise ValueError("invalid port number")
ensure_logging()
log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
log.info("Initial debug configuration: {0}", json.repr(_config))
qt_mode = _config.get("qt", "none")
if qt_mode != "none":
pydevd.enable_qt_support(qt_mode)
settrace_kwargs = {
"suspend": False,
"patch_multiprocessing": _config.get("subProcess", True),
}
if hide_debugpy_internals():
debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)
try:
return func(address, settrace_kwargs, **kwargs)
except Exception:
log.reraise_exception("{0}() failed:", func.__name__, level="info")
return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on2-right] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the "distributed.admin.pdb-on-err" config key, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension name -> extension class.
# NOTE(review): presumably the extensions a Worker gets when none are passed - confirm.
DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}
# Default for the ``metrics`` argument of ``Worker.__init__``.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
# Default for the ``startup_information`` argument of ``Worker.__init__``.
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Set of three worker statuses: running, paused and closing_gracefully.
WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Typed message whose ``op`` is ``"task-finished"``.

    Extends ``OKMessage`` with the fields declared below.
    """

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    # presumably the start/stop times of the task run - TODO confirm units
    start: float
    stop: float
    # presumably the ID of the thread that ran the task - TODO confirm
    thread: int
class RunTaskFailure(ErrorMessage):
    """Typed message whose ``op`` is ``"task-erred"``.

    Extends ``ErrorMessage`` with the fields declared below.
    """

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    # presumably the start/stop times of the task run - TODO confirm units
    start: float
    stop: float
    # presumably the ID of the thread that ran the task - TODO confirm
    thread: int
    # The exception object itself.
    # NOTE(review): presumably alongside the serialized error fields of
    # ErrorMessage - confirm.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """Typed dict whose only key, ``status``, is always ``"busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """Typed dict with ``status`` ``"OK"`` and a ``data`` mapping."""

    status: Literal["OK"]
    # Mapping of key -> object.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both plain and coroutine methods. When the wrapped method
    raises an ``Exception``, the failure is logged (unless the worker is
    already closed or closing), ``_force_close`` is invoked, and the
    original exception is re-raised. A coroutine wrapper awaits
    ``_force_close``; a plain wrapper schedules it on ``self.loop``.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    def _record_failure(worker, exc: Exception) -> None:
        # Do not log when the worker is already closed or closing.
        if worker.status in (Status.closed, Status.closing):
            return
        worker.log_event("worker-fail-hard", error_message(exc))
        logger.exception(exc)

    if not iscoroutinefunction(method):

        @wraps(method)
        def sync_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                _record_failure(self, exc)
                # Not in a coroutine here, so hand the close to the event loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

        return sync_wrapper  # type: ignore

    @wraps(method)
    async def async_wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return await method(self, *args, **kwargs)  # type: ignore
        except Exception as exc:
            _record_failure(self, exc)
            await _force_close(self, reason)
            raise

    return async_wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the ``fail_hard`` decorator

    1.  Give the worker up to 30 seconds to close
    2.  If closing fails, log and kill the process (unless a Scheduler
        lives in this process, in which case the error is re-raised)
    """
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task.

    The worker is read from the ``_worker_cvar`` context variable.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    Raises
    ------
    ValueError
        If no worker is bound to the current context.

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Surface a ValueError (without the LookupError chained to it);
        # get_client relies on catching ValueError here.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Return a client usable from inside a task.

    Lookup order:

    1. the client of the worker running the current task, if there is one and
       ``address`` is unset or equal to that worker's scheduler address;
    2. ``Client.current()``, under the same address condition;
    3. a new ``Client(address, timeout=timeout)`` when ``address`` is given.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If no client could be found and no ``address`` was provided.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        pass
    else:
        if not address or worker.scheduler.address == address:
            return worker._get_client(timeout=timeout)

    # Imported here rather than at module level (Client is only imported
    # under TYPE_CHECKING at the top of the file).
    from distributed.client import Client

    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None

    if current and (not address or current.scheduler.address == address):
        return current
    if not address:
        raise ValueError("No global client found and no address provided")
    return Client(address, timeout=timeout)
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    # Time elapsed since the start time recorded in the thread state.
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its event loop.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Get keys from worker

    The worker has a two step handshake to acknowledge when data has been fully
    delivered.  This function implements that handshake: it sends a
    ``get_data`` request and, when the response status is ``"OK"``, writes
    ``"OK"`` back on the same comm.

    ``serializers`` / ``deserializers`` default to the ones configured on
    ``rpc``. The comm is always handed back to the pool via ``rpc.reuse``.

    Raises
    ------
    ValueError
        If the response carries no ``"status"`` key.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = reply["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if status == "OK":
            # Second half of the handshake: acknowledge receipt.
            await comm.write("OK")
        return reply
    finally:
        rpc.reuse(worker, comm)
# LRU cache (maxsize=100) mapping a function to its pickled bytes; filled and
# read by `dumps_function`.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Lock held by `dumps_function` around every access to `cache_dumps`.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached payload when one exists.

    Payloads shorter than 100000 bytes are stored in ``cache_dumps``. A
    function that cannot be used as a cache key (``TypeError``) is pickled
    without touching the cache.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: pickle now and remember the payload if it is small.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the current thread is registered in
    ``active_threads`` (thread ident -> ``key``), the thread state is set via
    ``set_thread_state`` and ``execution_state["worker"]`` is bound to
    ``_worker_cvar``. All three are undone when the task finishes, including
    when ``_run_task_simple`` raises.

    Parameters
    ----------
    task, data, time_delay
        Passed through unchanged to ``_run_task_simple``.
    execution_state
        Must contain a ``"worker"`` entry; also stored in the thread state.
    key
        Key of the task being run.
    active_threads, active_threads_lock
        Shared thread registry and the lock guarding it.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread. Without this, an exception escaping
        # _run_task_simple would leave a stale ident -> key entry behind.
        with active_threads_lock:
            del active_threads[ident]
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers; a plain str at runtime.
ShuffleId = NewType("ShuffleId", str)
# Alias for a tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ShuffleRun class.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type, used by ShuffleRun.offload.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run spec."""

    # Either the spec itself or the spec wrapped in ToPickle.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
# Set up a shuffle run: store identifiers and collaborators, create the disk
# (or in-memory) buffer and the comm buffer, then initialise the bookkeeping
# attributes and the retry settings read from the dask config.
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# `disk` selects an on-disk buffer rooted at `directory`; otherwise shards are
# kept in a MemoryShardsBuffer.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Message size limit and concurrency of the comm buffer come from the
# `distributed.p2p.comm.*` configuration keys.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Bookkeeping state, all starting out empty / unset.
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry policy passed to `retry` by send(); delays are parsed with a default
# unit of seconds.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debug representation listing identity and lifecycle flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )
def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` prefix on the first label element is stripped before
    the metric is published through ``self.digest_metric``.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith("p2p-"):
            parts = (head[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    # Offloading is only allowed for the "background" capture scopes.
    is_background = "background" in where
    with context_meter.add_callback(callback, allow_offload=is_background):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    ``consistent`` is true only when every entry of ``run_ids`` equals this
    run's own ``run_id``. Raises if the run has already been closed.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    is_consistent = all(other == own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=is_consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make one ``shuffle_receive`` call to the peer at ``address``.

    ``shards`` is wrapped with ``to_serialize``; this run's shuffle id and run
    id are sent along. Raises if the run has already been closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured policy.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob first; otherwise they are passed on as a list.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    shards_or_bytes: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # Zero-argument factory: `retry` calls it once per attempt.
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is measured under the ``"offload"`` context meter. Raises if the
    run has already been closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
        return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a status snapshot of this run.

    The result has three keys: ``"disk"`` (disk buffer heartbeat), ``"comm"``
    (comm buffer heartbeat plus a ``"read"`` entry set to ``total_recvd``) and
    ``"start"`` (the run's start time).
    """
    comm_status = self._comm_buffer.heartbeat()
    comm_status["read"] = self.total_recvd
    disk_status = self._disk_buffer.heartbeat()
    return {"disk": disk_status, "comm": comm_status, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer; raises if the run is closed.

    Each index tuple is turned into a string key by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    Re-raises the stored exception when there is one, else raises
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception then reported by the comm buffer is stored on
    ``self._exception`` and re-raised. Raises if the run is already closed.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so later raise_if_closed() calls can re-raise it.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer and then the disk buffer.

    A call made after ``closed`` has been set only waits for the closed event
    and returns.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
        return
    # Already closing or closed: wait for the first closer to finish.
    await self._closed_event.wait()  # pragma: no cover
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the entry for ``id`` from the disk buffer; raises if the run is closed.

    The buffer key is the elements of ``id`` joined with ``"_"``.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``self._receive``.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm;
            # presumably only trusted peers can reach this handler - confirm.
            unpacked = pickle.loads(data)
            data = cast(list[tuple[_T_partition_id, Any]], unpacked)
        await self._receive(data)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36969', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:40257', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:44973', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx', 'k'], how = 'right'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
# Thread-local storage for this module.
_tls = threading.local()
# Current debug configuration. configure() only accepts keys present here and
# requires each new value to have the same type as the default below.
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
# Allowed values per property, checked by configure().
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller passes it. Any
    exception raised by pydevd propagates unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done."""
    if not ensure_logging.ensured:
        # Flip the flag first so that later calls return immediately.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        log_dir = log.log_dir
        if log_dir is not None:
            pydevd.log_to(log_dir + "/debugpy.pydevd.log")


# One-shot flag stored on the function object itself; also read by log_to().
ensure_logging.ensured = False
def log_to(path):
    """Choose where logs go; must be called before logging has begun.

    Passing ``sys.stderr`` enables every level on the stderr log; any other
    value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate the given properties and store them in ``_config``.

    ``properties`` (if given) is copied into a dict and then updated with
    ``kwargs``. Entries are validated and applied one at a time, so entries
    preceding an invalid one are already stored when the error is raised.

    Raises
    ------
    ValueError
        For an unknown property name, a value whose type differs from the
        type of the current ``_config`` entry, or a value outside the list
        in ``_config_valid_values``.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        required_type = type(_config[name])
        if type(value) is not required_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, required_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate an entry point that starts a debug session.

    The returned ``debug(address, **kwargs)`` function:

    * accepts either a bare port or a ``(host, port)`` pair; a bare port is
      paired with host ``"127.0.0.1"``;
    * validates the port (``ValueError`` if it is not int-like or is outside
      ``0 <= port < 2**16``);
    * ensures logging is started and applies the ``"qt"`` setting;
    * builds the settrace keyword arguments and calls
      ``func(address, settrace_kwargs, **kwargs)``, passing any exception to
      ``log.reraise_exception`` at level "info".
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a two-item pair: treat the argument as a bare port.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        in_range = 0 <= port < 2**16
        if not in_range:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = {
            "suspend": False,
            "patch_multiprocessing": _config.get("subProcess", True),
        }
        if hide_debugpy_internals():
            # Add patterns for debugpy's own files to the dont_trace settings.
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
            settrace_kwargs["dont_trace_end_patterns"] = (
                str("debugpy_launcher.py"),
            )

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on2-outer] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Value of the `distributed.admin.pdb-on-err` setting, read once at import time.
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# Extension classes keyed by extension name.
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
# Registries of callables taking a Worker, keyed by name; both start empty.
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
# Statuses grouped under "any running": running, paused and closing gracefully.
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
class RunTaskSuccess(OKMessage):
    """Message for a task that finished (``op`` is ``"task-finished"``)."""

    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    # Start / stop times of the run and the thread it ran on.
    start: float
    stop: float
    thread: int
class RunTaskFailure(ErrorMessage):
    """Message for a task that erred (``op`` is ``"task-erred"``)."""

    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    # Start / stop times of the run and the thread it ran on.
    start: float
    stop: float
    thread: int
    # The exception object itself.
    actual_exception: BaseException | Exception
class GetDataBusy(TypedDict):
    """``get_data`` response whose only field is ``status == "busy"``."""

    status: Literal["busy"]
class GetDataSuccess(TypedDict):
    """``get_data`` response with ``status == "OK"`` and the data by key."""

    status: Literal["OK"]
    # Mapping of key -> value.
    data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.

    Works for both plain and coroutine methods. When the wrapped method raises
    an ``Exception``:

    * unless the worker status is already ``closed`` or ``closing``, a
      ``"worker-fail-hard"`` event is logged and the exception is written to
      the module logger;
    * ``_force_close`` is awaited (coroutine methods) or scheduled on the
      worker's loop (plain methods);
    * the original exception is re-raised.
    """
    reason = f"worker-{method.__name__}-fail-hard"

    if not iscoroutinefunction(method):

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                # Cannot await here: schedule the close on the event loop.
                self.loop.add_callback(_force_close, self, reason)
                raise

    else:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                await _force_close(self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait (up to 30, passed to ``wait_for``) for the worker to close
    2.  If closing fails, log and kill the process, unless a Scheduler instance
        exists in this process, in which case the error is re-raised instead
    """
    try:
        closing = self.close(nanny=False, executor_wait=False, reason=reason)
        await wait_for(closing, 30)
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
# Context variable holding the worker associated with the task being executed.
# ``_run_task`` sets it for the duration of a task run and resets it afterwards;
# ``get_worker`` reads it.
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker that is running the current task

    Raises
    ------
    ValueError
        If no worker has been registered for the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current_worker = _worker_cvar.get()
    except LookupError:
        # The context variable was never set in this context
        raise ValueError("No worker found") from None
    else:
        return current_worker
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Obtain a client from inside a task.

    The lookup proceeds in this order:

    1. the client of the worker running the current task,
    2. the current global client,
    3. a brand new client connected to ``address``.

    Options 1 and 2 are only used when no ``address`` was requested or when
    their scheduler address equals the requested one.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a global client is usable and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    def _serves_requested_scheduler(node) -> bool:
        # Without an explicit address any scheduler is acceptable
        return not address or node.scheduler.address == address

    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        worker = None
    if worker is not None and _serves_requested_scheduler(worker):
        return worker._get_client(timeout=timeout)

    from distributed.client import Client

    try:
        client = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        client = None

    if client and _serves_requested_scheduler(client):
        return client
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    The current thread leaves the pool, which opens up a new scheduling slot
    and a new thread for a new task. This enables the client to schedule tasks
    on this node, which is especially useful while waiting for other jobs to
    finish (e.g., with ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool

    # Time elapsed since the start time recorded in the thread state
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event to the worker through its event loop
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: after a reply with
    status ``"OK"`` it writes ``"OK"`` back on the same comm.

    ``serializers`` and ``deserializers`` default to the ones configured on
    ``rpc``. The comm is handed back to the pool when done, whether or not the
    exchange succeeded.

    Raises
    ------
    ValueError
        If the reply carries no ``"status"`` entry.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    if serializers is None:
        serializers = rpc.serializers
    if deserializers is None:
        deserializers = rpc.deserializers

    connection = await rpc.connect(worker)
    connection.name = "Ephemeral Worker->Worker for gather"
    try:
        reply = await send_recv(
            connection,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = reply["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", reply)
        if status == "OK":
            # Second half of the handshake: acknowledge delivery
            await connection.write("OK")
        return reply
    finally:
        rpc.reuse(worker, connection)
# Cache mapping functions to their pickled bytes (``maxsize=100``); read and
# filled by ``dumps_function``
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Lock held by ``dumps_function`` around every access to ``cache_dumps``
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached result when one exists

    Results shorter than 100000 bytes are stored in ``cache_dumps``. Functions
    that cannot be used as a cache key are pickled on every call.
    """
    try:
        with _cache_lock:
            pickled = cache_dumps[func]
    except TypeError:  # Unhashable function
        pickled = pickle.dumps(func)
    except KeyError:
        # Cache miss: pickle now and remember the result if it is small enough
        pickled = pickle.dumps(func)
        cacheable = len(pickled) < 100000
        if cacheable:
            with _cache_lock:
                cache_dumps[func] = pickled
    return pickled
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a function, collect information

    While the task runs, the current thread is registered in
    ``active_threads`` under its thread identifier, the thread state carries
    the start time, ``execution_state`` and ``key``, and ``_worker_cvar``
    points at ``execution_state["worker"]``.

    All of this bookkeeping is undone when the run ends, including when it
    ends with an exception: thread identifiers may be recycled, so a leftover
    ``active_threads`` entry would attribute later activity of this thread to
    ``key``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                msg = _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Always unregister the thread; ``pop`` with a default so that cleanup
        # can never mask the exception that is propagating
        with active_threads_lock:
            active_threads.pop(ident, None)
    return msg
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers, based on ``str``
ShuffleId = NewType("ShuffleId", str)
# Index made of any number of ints
NDIndex: TypeAlias = tuple[int, ...]
# Type parameters of the generic ``ShuffleRun`` class
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Return type of the callable passed to ``ShuffleRun.offload``
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with a ``run_spec`` entry."""

    # The run specification, either as-is or wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address
        Stored unchanged on the instance.
    directory
        Passed to the ``DiskShardsBuffer``; only used when ``disk`` is true.
    executor
        Executor used by :meth:`offload`.
    rpc
        Callable returning an RPC handle for a given address.
    digest_metric
        Callable receiving every metric captured by ``_capture_metrics``.
    scheduler
        RPC handle used by :meth:`barrier`.
    memory_limiter_disk
        Limiter handed to the ``DiskShardsBuffer``.
    memory_limiter_comms
        Limiter handed to the ``CommShardsBuffer``.
    disk
        If true, received shards go to a ``DiskShardsBuffer``; otherwise to a
        ``MemoryShardsBuffer``.
    loop
        Stored as ``_loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        # Metrics emitted here are recorded under the "background-disk" label
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        # Metrics emitted here are recorded under the "background-comms" label
        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    # Set to True by ``inputs_done``
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    # Exception re-raised by ``raise_if_closed`` once the run is closed
    self._exception = None
    # Set at the end of ``close``; concurrent ``close`` calls wait on it
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by ``send``, read from the dask config
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``p2p-`` prefix on the first label element is stripped.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"
    # Offloading is only allowed for background captures
    offloadable = "background" in where

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to the peer at ``address`` with a single RPC call.

    Calls ``shuffle_receive`` on the peer, tagging the payload with this run's
    shuffle id and run id, and returns the peer's reply.

    Raises whatever :meth:`raise_if_closed` raises if the run is closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the peer at ``address``, retrying on failure.

    The retry policy comes from ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
    ``RETRY_DELAY_MAX``.
    """
    # When shards are small on average, don't send buffers individually over
    # the tcp comms. Instead, merge everything into an opaque bytes blob, send
    # it all at once, and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small_shards = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if small_shards else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on this run's executor and return its result.

    The call is metered under the ``"offload"`` label.

    Raises whatever :meth:`raise_if_closed` raises if the run is closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect a status snapshot of this run.

    Returns
    -------
    dict
        ``"disk"`` and ``"comm"`` hold the heartbeats of the respective
        buffers (the comm entry is extended with ``"read"``, taken from
        ``self.total_recvd``) and ``"start"`` holds ``self.start_time``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Hand ``data`` to the comm buffer, refusing if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer.

    Every index key is flattened into a string by joining the ``str()`` of
    its components with ``"_"`` before it is passed to the buffer.
    """
    self.raise_if_closed()
    flattened = {
        "_".join(map(str, index)): payload for index, payload in data.items()
    }
    await self._disk_buffer.write(flattened)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; otherwise do nothing.

    Raises
    ------
    Exception
        The stored ``self._exception`` when one is set.
    ShuffleClosedError
        When the run is closed and no exception is stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as done and flush the comm buffer.

    Sets ``self.transferred`` and flushes the comm buffer. If the buffer
    then reports an exception, it is stored on ``self._exception`` and
    re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as buffer_error:
        # Remember the failure so that it is available on the run afterwards.
        self._exception = buffer_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer, refusing if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer, refusing if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the run and both of its buffers.

    The first call flags the run as closed, closes the comm buffer and then
    the disk buffer, and finally sets ``_closed_event``. A call made while
    the run is already flagged as closed only waits for that event.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored for ``id``.

    The lookup key is built by joining the ``str()`` of each component of
    ``id`` with ``"_"``.
    """
    self.raise_if_closed()
    disk_key = "_".join(map(str, id))
    return self._disk_buffer.read(disk_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them on to ``_receive``.

    A ``bytes`` payload is treated as the pickled blob produced by
    ``send()`` and is unpickled first.

    Returns
    -------
    ``{"status": "OK"}`` on success, or ``error_message(...)`` built from a
    ``P2PConsistencyError`` raised while processing.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): unpickling is only safe for payloads from trusted
            # peers - confirm that holds for every caller.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as consistency_error:
        return error_message(consistency_error)
    else:
        return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:34075', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:44117', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:38359', name: 1, status: closed, stored: 0, running: 2/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['idx', 'k'], how = 'outer'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()
# TODO: "gevent", if possible.
_config = {
"qt": "none",
"subProcess": True,
"python": sys.executable,
"pythonEnv": {},
}
_config_valid_values = {
# If property is not listed here, any value is considered valid, so long as
# its type matches that of the default value in _config.
"qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}
# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward all arguments to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller supplies it.
    Exceptions raised by ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Start logging to ``log.log_dir`` the first time this is called.

    Later calls are no-ops; the ``ensured`` attribute on this function
    records whether the setup has already run.
    """
    if not ensure_logging.ensured:
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


ensure_logging.ensured = False
def log_to(path):
    """Choose where debugpy logs go; must be called before logging starts.

    Passing ``sys.stderr`` enables every level in ``log.LEVELS`` on the
    stderr log; any other value is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and apply debug configuration properties to ``_config``.

    Parameters
    ----------
    properties
        Optional properties accepted by ``dict()``; keyword arguments are
        merged on top of them.
    **kwargs
        Individual properties.

    Raises
    ------
    ValueError
        For an unknown property, a value whose exact type differs from the
        type of the current value in ``_config``, or a value that is not
        among those listed for the property in ``_config_valid_values``.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        merged = kwargs
    else:
        merged = dict(properties)
        merged.update(kwargs)

    for name, value in merged.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError(
                "{0!r} must be a {1}".format(name, expected_type.__name__)
            )

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        # Each property is stored as soon as it has passed validation.
        _config[name] = value
def _starts_debugging(func):
    """Wrap ``func`` with the shared set-up for starting a debug session.

    The returned ``debug(address, **kwargs)`` accepts a bare port or a
    ``(host, port)`` pair, validates the port, makes sure logging is set up,
    applies the Qt setting from ``_config`` and then calls
    ``func(address, settrace_kwargs, **kwargs)``. A failure inside ``func``
    is passed to ``log.reraise_exception``.
    """

    def debug(address, **kwargs):
        try:
            _, port = address
        except Exception:
            # Not a pair: treat the value as a port on the loopback interface.
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if port < 0 or port >= 2**16:
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )

        if hide_debugpy_internals():
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
            settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_merge_column_and_index
github-actions / Unit Test Results
All 8 runs failed: test_merge_known_to_unknown[on3-inner] (distributed.shuffle.tests.test_merge_column_and_index)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.13-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.13-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: debugpy.listen() has already been called on this process
from __future__ import annotations
import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Mapping,
MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
TextIO,
TypedDict,
TypeVar,
cast,
)
from tlz import keymap, pluck
from tornado.ioloop import IOLoop
import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
format_bytes,
funcname,
key_split,
parse_bytes,
parse_timedelta,
tmpdir,
typename,
)
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ConnectionPool,
ErrorMessage,
OKMessage,
PooledRPCCall,
Status,
coerce_to_address,
context_meter_to_server_digest,
error_message,
pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
TimeoutError,
get_ip,
has_arg,
in_async_call,
iscoroutinefunction,
json_load_robust,
log_errors,
offload,
parse_ports,
recursive_to_dict,
run_in_executor_with_context,
set_thread_state,
silence_logging_cmgr,
thread_state,
wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
WorkerDataParameter,
WorkerMemoryManager,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
DeprecatedWorkerStateAttribute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
PauseEvent,
RefreshWhoHasEvent,
RemoveReplicasEvent,
RemoveWorkerEvent,
RescheduleEvent,
RetryBusyWorkerEvent,
SecedeEvent,
StateMachineEvent,
StealRequestEvent,
TaskState,
UnpauseEvent,
UpdateDataEvent,
WorkerState,
)
if TYPE_CHECKING:
# FIXME import from typing (needs Python >=3.10)
from typing_extensions import ParamSpec
# Circular imports
from distributed.client import Client
from distributed.nanny import Nanny
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
DEFAULT_EXTENSIONS: dict[str, type] = {
"pubsub": PubSubWorkerExtension,
"spans": SpansWorkerExtension,
}
DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}
WORKER_ANY_RUNNING = {
Status.running,
Status.paused,
Status.closing_gracefully,
}
# Typed message shape that extends ``OKMessage`` and is tagged with
# ``op == "task-finished"``.
class RunTaskSuccess(OKMessage):
op: Literal["task-finished"]
result: object
nbytes: int
type: type
# NOTE(review): start/stop are presumably timestamps bracketing the task's
# execution - confirm against the code that builds this message.
start: float
stop: float
thread: int
# Typed message shape that extends ``ErrorMessage`` and is tagged with
# ``op == "task-erred"``. It declares the same fields as ``RunTaskSuccess``
# plus ``actual_exception``.
class RunTaskFailure(ErrorMessage):
op: Literal["task-erred"]
result: object
nbytes: int
type: type
# NOTE(review): start/stop are presumably timestamps bracketing the task's
# execution - confirm against the code that builds this message.
start: float
stop: float
thread: int
actual_exception: BaseException | Exception
# Typed dict for a reply whose ``status`` is "busy" and which carries no
# other fields.
class GetDataBusy(TypedDict):
status: Literal["busy"]
# Typed dict for a reply whose ``status`` is "OK"; ``data`` maps task keys
# to objects.
class GetDataSuccess(TypedDict):
status: Literal["OK"]
data: dict[Key, object]
def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator that shuts the worker down when the wrapped method raises.

    Coroutine functions and plain functions are both supported. If an
    ``Exception`` escapes ``method``, the failure is logged (unless the
    worker is already closed or closing), ``_force_close`` is awaited (async
    methods) or scheduled on the worker's loop (sync methods), and the
    exception is re-raised to the caller.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    is_async = iscoroutinefunction(method)

    if is_async:

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as exc:
                shutting_down = self.status in (Status.closed, Status.closing)
                if not shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as exc:
                shutting_down = self.status in (Status.closed, Status.closing)
                if not shutting_down:
                    self.log_event("worker-fail-hard", error_message(exc))
                    logger.exception(exc)
                # Cannot await here; let the event loop run the shutdown.
                self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore
async def _force_close(self, reason: str):
    """
    Shutdown helper for the ``fail_hard`` decorator.

    1. Ask the worker to close and wait for it, with a timeout of 30
    2. If closing fails, log the problem and kill the process - unless a
       ``Scheduler`` instance lives in this process, in which case the
       error is re-raised instead
    """
    close_timeout = 30
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            close_timeout,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        scheduler_in_process = bool(Scheduler._instances)
        if scheduler_in_process:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)
class Worker(BaseWorker, ServerNode):
"""Worker node in a Dask distributed cluster
Workers perform two functions:
1. **Serve data** from a local dictionary
2. **Perform computation** on that data and on data from peers
Workers keep the scheduler informed of their data and use that scheduler to
gather data from other workers when necessary to perform a computation.
You can start a worker with the ``dask worker`` command line application::
$ dask worker scheduler-ip:port
Use the ``--help`` flag to see more options::
$ dask worker --help
The rest of this docstring is about the internal state that the worker uses
to manage and track internal computations.
**State**
**Informational State**
These attributes don't change significantly during execution.
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executors:** ``dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``PooledRPCCall``:
Location of scheduler. See ``.ip/.port`` attributes.
* **name:** ``string``:
Alias
* **services:** ``{str: Server}``:
Auxiliary web servers running on this worker
* **service_ports:** ``{str: port}``:
* **transfer_outgoing_count_limit**: ``int``
The maximum number of concurrent outgoing data transfers.
See also
:attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`.
* **batched_stream**: ``BatchedSend``
A batched stream along which we communicate to the scheduler
* **log**: ``[(message)]``
A structured and queryable log. See ``Worker.story``
**Volatile State**
These attributes track the progress of tasks that this worker is trying to
complete. In the descriptions below a ``key`` is the name of a task that
we want to compute and ``dep`` is the name of a piece of dependent data
that we want to collect from others.
* **threads**: ``{key: int}``
The ID of the thread on which the task ran
* **active_threads**: ``{int: key}``
The keys currently running on active threads
* **state**: ``WorkerState``
Encapsulated state machine. See
:class:`~distributed.worker_state_machine.BaseWorker` and
:class:`~distributed.worker_state_machine.WorkerState`
Parameters
----------
scheduler_ip: str, optional
scheduler_port: int, optional
scheduler_file: str, optional
host: str, optional
data: MutableMapping, type, None
The object to use for storage, builds a disk-backed LRU dict by default.
If a callable to construct the storage object is provided, it
will receive the worker's attr:``local_directory`` as an
argument if the calling signature has an argument named
``worker_local_directory``.
nthreads: int, optional
local_directory: str, optional
Directory where we place local resources
name: str, optional
memory_limit: int, float, string
Number of bytes of memory that this worker should use.
Set to zero for no limit. Set to 'auto' to calculate
as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
Use strings or numbers like 5GB or 5e9
memory_target_fraction: float or False
Fraction of memory to try to stay beneath
(default: read from config key distributed.worker.memory.target)
memory_spill_fraction: float or False
Fraction of memory at which we start spilling to disk
(default: read from config key distributed.worker.memory.spill)
memory_pause_fraction: float or False
Fraction of memory at which we stop running new tasks
(default: read from config key distributed.worker.memory.pause)
max_spill: int, string or False
Limit of number of bytes to be spilled on disk.
(default: read from config key distributed.worker.memory.max-spill)
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refer to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
Address on which to contact nanny, if it exists
lifetime: str
Amount of time like "1 hour" after which we gracefully shut down the worker.
This defaults to None, meaning no explicit shutdown time.
lifetime_stagger: str
Amount of time like "5 minutes" to stagger the lifetime value
The actual lifetime will be selected uniformly at random between
lifetime +/- lifetime_stagger
lifetime_restart: bool
Whether or not to restart a worker after it has reached its lifetime
Default False
kwargs: optional
Additional parameters to ServerNode constructor
Examples
--------
Use the command line to start a worker::
$ dask scheduler
Start scheduler at 127.0.0.1:8786
$ dask worker 127.0.0.1:8786
Start worker at: 127.0.0.1:1234
Registered with scheduler at: 127.0.0.1:8786
See Also
--------
distributed.scheduler.Scheduler
distributed.nanny.Nanny
"""
_instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
_initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()
nanny: Nanny | None
_lock: threading.Lock
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
profile_recent: dict[str, Any]
profile_history: deque[tuple[float, dict[str, Any]]]
transfer_incoming_log: deque[dict[str, Any]]
transfer_outgoing_log: deque[dict[str, Any]]
#: Total number of data transfers to other workers since the worker was started
transfer_outgoing_count_total: int
#: Total size of data transfers to other workers (including in-progress and failed transfers)
transfer_outgoing_bytes_total: int
#: Current total size of open data transfers to other workers
transfer_outgoing_bytes: int
#: Current number of open data transfers to other workers
transfer_outgoing_count: int
bandwidth: float
latency: float
profile_cycle_interval: float
workspace: WorkSpace
_client: Client | None
bandwidth_workers: defaultdict[str, tuple[float, int]]
bandwidth_types: defaultdict[type, tuple[float, int]]
preloads: preloading.PreloadManager
contact_address: str | None
_start_port: int | str | Collection[int] | None = None
_start_host: str | None
_interface: str | None
_protocol: str
_dashboard_address: str | None
_dashboard: bool
_http_prefix: str
death_timeout: float | None
lifetime: float | None
lifetime_stagger: float | None
lifetime_restart: bool
extensions: dict
security: Security
connection_args: dict[str, Any]
loop: IOLoop
executors: dict[str, Executor]
batched_stream: BatchedSend
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
services: dict[str, Any] = {}
service_specs: dict[str, Any]
metrics: dict[str, Callable[[Worker], Any]]
startup_information: dict[str, Callable[[Worker], Any]]
low_level_profiler: bool
scheduler: PooledRPCCall
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
def __init__(
self,
scheduler_ip: str | None = None,
scheduler_port: int | None = None,
*,
scheduler_file: str | None = None,
nthreads: int | None = None,
loop: IOLoop | None = None, # Deprecated
local_directory: str | None = None,
services: dict | None = None,
name: Any | None = None,
reconnect: bool | None = None,
executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
resources: dict[str, float] | None = None,
silence_logs: int | None = None,
death_timeout: Any | None = None,
preload: list[str] | None = None,
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
extensions: dict[str, type] | None = None,
metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
startup_information: Mapping[
str, Callable[[Worker], Any]
] = DEFAULT_STARTUP_INFORMATION,
interface: str | None = None,
host: str | None = None,
port: int | str | Collection[int] | None = None,
protocol: str | None = None,
dashboard_address: str | None = None,
dashboard: bool = False,
http_prefix: str = "/",
nanny: Nanny | None = None,
plugins: tuple[WorkerPlugin, ...] = (),
low_level_profiler: bool | None = None,
validate: bool | None = None,
profile_cycle_interval=None,
lifetime: Any | None = None,
lifetime_stagger: Any | None = None,
lifetime_restart: bool | None = None,
transition_counter_max: int | Literal[False] = False,
###################################
# Parameters to WorkerMemoryManager
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
data: WorkerDataParameter = None,
# Deprecated parameters; please use dask config instead.
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
memory_pause_fraction: float | Literal[False] | None = None,
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
if reconnect:
raise ValueError(
"The `reconnect=True` option for `Worker` has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
else:
warnings.warn(
"The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
DeprecationWarning,
stacklevel=2,
)
if loop is not None:
warnings.warn(
"The `loop` argument to `Worker` is ignored, and will be removed in a future release. "
"The Worker always binds to the current loop",
DeprecationWarning,
stacklevel=2,
)
self.__exit_stack = stack = contextlib.ExitStack()
self.nanny = nanny
self._lock = threading.Lock()
transfer_incoming_count_limit = dask.config.get(
"distributed.worker.connections.outgoing"
)
self.transfer_outgoing_count_limit = dask.config.get(
"distributed.worker.connections.incoming"
)
transfer_message_bytes_limit = parse_bytes(
dask.config.get("distributed.worker.transfer.message-bytes-limit")
)
self.threads = {}
self.active_threads_lock = threading.Lock()
self.active_threads = {}
self.active_keys = set()
self.profile_keys = defaultdict(profile.create)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.profile_keys_history = deque(maxlen=maxlen)
self.profile_history = deque(maxlen=maxlen)
self.profile_recent = profile.create()
if validate is None:
validate = dask.config.get("distributed.worker.validate")
self.transfer_incoming_log = deque(maxlen=maxlen)
self.transfer_outgoing_log = deque(maxlen=maxlen)
self.transfer_outgoing_count_total = 0
self.transfer_outgoing_bytes_total = 0
self.transfer_outgoing_bytes = 0
self.transfer_outgoing_count = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.bandwidth_workers = defaultdict(
lambda: (0, 0)
) # bw/count recent transfers
self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers
self.latency = 0.001
self._client = None
if profile_cycle_interval is None:
profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
assert profile_cycle_interval
self._setup_logging(logger)
self.death_timeout = parse_timedelta(death_timeout)
self.contact_address = contact_address
self._start_port = port
self._start_host = host
if host:
# Helpful error message if IPv6 specified incorrectly
_, host_address = parse_address(host)
if host_address.count(":") > 1 and not host_address.startswith("["):
raise ValueError(
"Host address with IPv6 must be bracketed like '[::1]'; "
f"got {host_address}"
)
self._interface = interface
nthreads = nthreads or CPU_COUNT
if resources is None:
resources = dask.config.get("distributed.worker.resources")
assert isinstance(resources, dict)
self.extensions = {}
if silence_logs:
stack.enter_context(silence_logging_cmgr(level=silence_logs))
if isinstance(security, dict):
security = Security(**security)
self.security = security or Security()
assert isinstance(self.security, Security)
self.connection_args = self.security.get_connection_args("worker")
self.loop = self.io_loop = IOLoop.current()
if scheduler_sni:
self.connection_args["server_hostname"] = scheduler_sni
# Common executors always available
self.executors = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}
# Find the default executor
if executor == "offload":
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
nthreads, thread_name_prefix="Dask-Default-Threads"
)
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
self.plugins = {}
self._pending_plugins = plugins
self.services = {}
self.service_specs = services or {}
self._dashboard_address = dashboard_address
self._dashboard = dashboard
self._http_prefix = http_prefix
self.metrics = dict(metrics) if metrics else {}
self.startup_information = (
dict(startup_information) if startup_information else {}
)
if low_level_profiler is None:
low_level_profiler = dask.config.get("distributed.worker.profile.low-level")
self.low_level_profiler = low_level_profiler
handlers = {
"gather": self.gather,
"run": self.run,
"run_coroutine": self.run_coroutine,
"get_data": self.get_data,
"update_data": self.update_data,
"free_keys": self._handle_remote_stimulus(FreeKeysEvent),
"terminate": self.close,
"ping": pingpong,
"upload_file": self.upload_file,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
"profile_metadata": self.get_profile_metadata,
"get_logs": self.get_logs,
"keys": self.keys,
"versions": self.versions,
"actor_execute": self.actor_execute,
"actor_attribute": self.actor_attribute,
"plugin-add": self.plugin_add,
"plugin-remove": self.plugin_remove,
"get_monitor_info": self.get_monitor_info,
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
"get_story": self.get_story,
}
stream_handlers = {
"close": self.close,
"cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
"worker-status-change": self.handle_worker_status_change,
"remove-worker": self._handle_remove_worker,
}
ServerNode.__init__(
self,
handlers=handlers,
stream_handlers=stream_handlers,
connection_args=self.connection_args,
local_directory=local_directory,
**kwargs,
)
if not preload:
preload = dask.config.get("distributed.worker.preload")
if not preload_argv:
preload_argv = dask.config.get("distributed.worker.preload-argv")
assert preload is not None
assert preload_argv is not None
self.preloads = preloading.process_preloads(
self, preload, preload_argv, file_dir=self.local_directory
)
if scheduler_file:
cfg = json_load_robust(scheduler_file, timeout=self.death_timeout)
scheduler_addr = cfg["address"]
elif scheduler_ip is None and dask.config.get("scheduler-address", None):
scheduler_addr = dask.config.get("scheduler-address")
elif scheduler_port is None:
scheduler_addr = coerce_to_address(scheduler_ip)
else:
scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port))
if protocol is None:
protocol_address = scheduler_addr.split("://")
if len(protocol_address) == 2:
protocol = protocol_address[0]
assert protocol
self._protocol = protocol
self.memory_manager = WorkerMemoryManager(
self,
data=data,
nthreads=nthreads,
memory_limit=memory_limit,
memory_target_fraction=memory_target_fraction,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
…connections(self):
warnings.warn(
"The `Worker.total_in_connections` attribute has been renamed to "
"`Worker.transfer_outgoing_count_limit`",
DeprecationWarning,
stacklevel=2,
)
return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
def get_worker() -> Worker:
    """Return the worker bound to the task running in the current context.

    Raises
    ------
    ValueError
        If no worker has been registered in the current context.

    Examples
    --------
    >>> def f():
    ...     worker = get_worker()  # The worker on which this task is running
    ...     return worker.address

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    'tcp://127.0.0.1:47373'

    See Also
    --------
    get_client
    worker_client
    """
    try:
        current = _worker_cvar.get()
    except LookupError:
        # Hide the internal LookupError; callers only catch ValueError.
        raise ValueError("No worker found") from None
    return current
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
    """Obtain a client from inside a task.

    The lookup order is: the client of the worker running the current task,
    then the current global client, then a brand new ``Client`` connected to
    ``address``.

    Parameters
    ----------
    address : str, optional
        The address of the scheduler to connect to. Defaults to the scheduler
        the worker is connected to.
    timeout : int or str
        Timeout (in seconds) for getting the Client. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    resolve_address : bool, default True
        Whether to resolve `address` to its canonical form.

    Returns
    -------
    Client

    Raises
    ------
    ValueError
        If neither a worker client nor a global client matches and no
        ``address`` was given.

    Examples
    --------
    >>> def f():
    ...     client = get_client(timeout="10s")
    ...     futures = client.map(lambda x: x + 1, range(10))  # spawn many tasks
    ...     results = client.gather(futures)
    ...     return sum(results)

    >>> future = client.submit(f)  # doctest: +SKIP
    >>> future.result()  # doctest: +SKIP
    55

    See Also
    --------
    get_worker
    worker_client
    secede
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, "s")

    if address and resolve_address:
        address = comm_resolve_address(address)

    # First choice: the client owned by the worker executing this task.
    try:
        worker = get_worker()
    except ValueError:  # could not find worker
        worker = None
    if worker is not None and (not address or worker.scheduler.address == address):
        return worker._get_client(timeout=timeout)

    # Imported here rather than at module level, as in the original code.
    from distributed.client import Client

    # Second choice: the current global client.
    try:
        current = Client.current()  # TODO: assumes the same scheduler
    except ValueError:
        current = None
    if current and (not address or current.scheduler.address == address):
        return current

    # Last resort: open a new connection, which requires an explicit address.
    if address:
        return Client(address, timeout=timeout)
    raise ValueError("No global client found and no address provided")
def secede():
    """
    Have this task secede from the worker's thread pool

    This opens up a new scheduling slot and a new thread for a new task. This
    enables the client to schedule tasks on this node, which is
    especially useful while waiting for other jobs to finish (e.g., with
    ``client.gather``).

    Examples
    --------
    >>> def mytask(x):
    ...     # do some work
    ...     client = get_client()
    ...     futures = client.map(...)  # do some remote work
    ...     secede()  # while that work happens, remove ourself from the pool
    ...     return client.gather(futures)  # return gathered results

    See Also
    --------
    get_client
    get_worker
    """
    worker = get_worker()
    tpe_secede()  # have this thread secede from the thread pool
    # Time spent computing so far, measured from the start time recorded in
    # the thread state.
    elapsed = time() - thread_state.start_time
    event = SecedeEvent(
        key=thread_state.key,
        compute_duration=elapsed,
        stimulus_id=f"secede-{time()}",
    )
    # Hand the event over to the worker through its event loop.
    worker.loop.add_callback(worker.handle_stimulus, event)
async def get_data_from_worker(
    rpc: ConnectionPool,
    keys: Collection[Key],
    worker: str,
    *,
    who: str | None = None,
    serializers: list[str] | None = None,
    deserializers: list[str] | None = None,
) -> GetDataBusy | GetDataSuccess:
    """Fetch ``keys`` from the worker at address ``worker``.

    The worker has a two step handshake to acknowledge when data has been fully
    delivered. This function implements that handshake: after an ``"OK"``
    response it writes ``"OK"`` back on the same comm.

    Parameters
    ----------
    rpc
        Connection pool used to open the comm; the comm is handed back to it
        with ``rpc.reuse`` when done.
    keys
        Keys to request.
    worker
        Address of the worker to fetch from.
    who
        Forwarded as the ``who`` field of the ``get_data`` request.
    serializers, deserializers
        Default to ``rpc.serializers`` / ``rpc.deserializers`` when ``None``.

    Raises
    ------
    ValueError
        If the response has no ``"status"`` key.

    See Also
    --------
    Worker.get_data
    Worker.gather_dep
    utils_comm.gather_data_from_workers
    """
    serializers = rpc.serializers if serializers is None else serializers
    deserializers = rpc.deserializers if deserializers is None else deserializers

    comm = await rpc.connect(worker)
    comm.name = "Ephemeral Worker->Worker for gather"
    try:
        response = await send_recv(
            comm,
            serializers=serializers,
            deserializers=deserializers,
            op="get_data",
            keys=keys,
            who=who,
        )
        try:
            status = response["status"]
        except KeyError:  # pragma: no cover
            raise ValueError("Unexpected response", response)
        if status == "OK":
            # Second half of the handshake: confirm receipt to the sender.
            await comm.write("OK")
        return response
    finally:
        rpc.reuse(worker, comm)
# Cache of pickled functions used by ``dumps_function``, holding at most
# 100 entries.
cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)
# Guards every access to ``cache_dumps`` made by ``dumps_function``.
_cache_lock = threading.Lock()
def dumps_function(func) -> bytes:
    """Pickle ``func``, reusing a cached result when one exists.

    Results shorter than 100000 bytes are stored in ``cache_dumps``.
    Unhashable functions cannot be used as cache keys and are pickled on
    every call.
    """
    try:
        with _cache_lock:
            return cache_dumps[func]
    except KeyError:
        # Cache miss: pickle now and remember the result if it is small.
        payload = pickle.dumps(func)
        if len(payload) < 100000:
            with _cache_lock:
                cache_dumps[func] = payload
        return payload
    except TypeError:  # Unhashable function
        return pickle.dumps(func)
def _run_task(
    task: GraphNode,
    data: dict,
    execution_state: dict,
    key: Key,
    active_threads: dict,
    active_threads_lock: threading.Lock,
    time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
    """Run a task on the calling thread while it is registered as active.

    The calling thread is recorded in ``active_threads`` under its thread
    ident for the duration of the call, the thread state is populated, and
    ``_worker_cvar`` is pointed at ``execution_state["worker"]`` so that
    ``get_worker`` works inside the task. The actual execution is delegated to
    ``_run_task_simple``.

    Parameters
    ----------
    task, data, time_delay
        Forwarded unchanged to ``_run_task_simple``.
    execution_state
        Stored in the thread state; its ``"worker"`` entry is bound to
        ``_worker_cvar``.
    key
        Key of the task; recorded in ``active_threads`` and the thread state.
    active_threads
        Mapping of thread ident to key, mutated under ``active_threads_lock``.
    active_threads_lock
        Lock guarding ``active_threads``.

    Returns
    -------
    msg: dictionary with status, result/error, timings, etc..
    """
    ident = threading.get_ident()
    with active_threads_lock:
        active_threads[ident] = key
    try:
        with set_thread_state(
            start_time=time(),
            execution_state=execution_state,
            key=key,
        ):
            token = _worker_cvar.set(execution_state["worker"])
            try:
                return _run_task_simple(task, data, time_delay)
            finally:
                _worker_cvar.reset(token)
    finally:
        # Unregister the thread even when an exception escapes
        # _run_task_simple; otherwise the thread stays listed as running
        # ``key`` after it has finished.
        with active_threads_lock:
            del active_threads[ident]
def _run_task_simple(
task: GraphNode,
data: dict,
time_delay: float,
) -> RunTaskSuccess | RunTaskFailure:
"""Run a function, collect information
Returns
-------
msg: dictionary with status, result/error, timings, etc..
"""
# meter("thread-cpu").delta
# Difference in thread_time() before and after function call, minus user calls
# to context_meter inside the function. Published to Server.digests as
# {("execute", <prefix>, "thread-cpu", "seconds"): <value>}
# m.delta
# Difference in wall time before and after function call, minus thread-cpu,
# minus user calls to context_meter. Published to Server.digests as
# {("execute", <prefix>, "thread-noncpu", "seconds"): <value>}
# m.stop - m.start
# Difference in wall time before and after function call, without subtracting
# anything. This is used in scheduler heuristics, e.g. task stealing.
with (
context_meter.meter("thread-noncpu", func=time) as m,
context_meter.meter("thread-cpu", func=thread_time),
):
try:
> result = task(data)
distributed/worker.py:2982:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/dask/_task_spec.py:741: in __call__
return self.func(*new_argspec)
distributed/shuffle/_merge.py:53: in merge_unpack
left = ext.get_output_partition(shuffle_id_left, barrier_left, output_partition)
distributed/shuffle/_worker_plugin.py:432: in get_output_partition
return shuffle_run.get_output_partition(
distributed/shuffle/_core.py:381: in get_output_partition
sync(self._loop, self._ensure_output_worker, partition_id, key)
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/tornado/gen.py:766: in run
value = future.result()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Distinct static type for shuffle identifiers, declared over ``str``.
ShuffleId = NewType("ShuffleId", str)
# Tuple of ints; ``_write_to_disk`` / ``_read_from_disk`` join its elements
# with "_" to build the name used in the disk buffer.
NDIndex: TypeAlias = tuple[int, ...]

# Type parameters of ``ShuffleRun``.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Return type of the function passed to ``ShuffleRun.offload``.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    # The shuffle run spec, either directly or wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    # Identifiers of this run; ``span_id`` is embedded in the metric names
    # built by ``_capture_metrics``.
    id: ShuffleId
    run_id: int
    span_id: str | None
    # Address compared against the assigned worker in ``_ensure_output_worker``.
    local_address: str
    # Executor on which ``offload`` runs functions.
    executor: ThreadPoolExecutor
    # Called with a peer address to obtain the RPC used by ``_send``.
    rpc: Callable[[str], PooledRPCCall]
    # Receives ``(name, value)`` for every metric captured by ``_capture_metrics``.
    digest_metric: Callable[[Hashable, float], None]
    # RPC to the scheduler (``shuffle_barrier``, ``shuffle_restrict_task``).
    scheduler: PooledRPCCall
    # Set to True at the start of ``close``.
    closed: bool

    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    received: set[_T_partition_id]
    # Reported as the "read" entry of the comm section in ``heartbeat``.
    total_recvd: int
    # ``time.time()`` taken in ``__init__``.
    start_time: float
    # Exception re-raised by ``raise_if_closed`` once the run is closed.
    _exception: Exception | None
    # Set at the end of ``close``; concurrent ``close`` callers wait on it.
    _closed_event: asyncio.Event
    _loop: IOLoop

    # Retry settings for ``send``, read from ``distributed.p2p.comm.retry.*``.
    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's collaborators and create its disk and comm buffers.

    Parameters
    ----------
    id, run_id, span_id, local_address
        Stored on the instance unchanged.
    directory
        Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
    executor
        Executor used by ``offload``.
    rpc
        Callable returning the RPC for a peer address.
    digest_metric
        Callback receiving each captured metric as ``(name, value)``.
    scheduler
        RPC to the scheduler.
    memory_limiter_disk
        Limiter passed to the ``DiskShardsBuffer``.
    memory_limiter_comms
        Limiter passed to the ``CommShardsBuffer``.
    disk
        If true, shards are buffered in a ``DiskShardsBuffer``; otherwise in a
        ``MemoryShardsBuffer``.
    loop
        Stored as ``_loop``.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry policy used by send(); delays are parsed with seconds as the
    # default unit.
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return a debugging representation listing ids, address and state flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )
def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash the run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and drop a leading "p2p-" from its
        # first element, since the metric name already starts with "p2p".
        parts = label if isinstance(label, tuple) else (label,)
        first = parts[0]
        if isinstance(first, str) and first.startswith(prefix):
            parts = (first[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    is_background = "background" in where
    with context_meter.add_callback(callback, allow_offload=is_background):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier for this run to the scheduler.

    ``consistent`` is true only when every entry of ``run_ids`` equals this
    run's ``run_id``. Returns this run's ``run_id``.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make one ``shuffle_receive`` call delivering ``shards`` to ``address``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the worker at ``address``, retrying on failure.

    The retry policy comes from ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
    ``RETRY_DELAY_MAX``.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, payload)

    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is metered under the ``"offload"`` label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return the heartbeats of both buffers plus this run's start time.

    The comm buffer heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer.

    Each index tuple is turned into a string key by joining its elements
    with ``"_"``.
    """
    self.raise_if_closed()
    named = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(named)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed.

    Raises
    ------
    Exception
        The stored ``_exception``, when one is set.
    ShuffleClosedError
        When the run is closed and no exception is stored.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception reported by the comm buffer after the flush is stored in
    ``_exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as e:
        # Remember the failure so raise_if_closed() can re-raise it later.
        self._exception = e
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if the run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if the run is closed."""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal that the run is closed.

    The first caller performs the shutdown; any later caller waits until that
    shutdown has finished. The disk buffer is closed and ``_closed_event`` is
    set even if closing a buffer raises, so that waiting callers are never
    left blocked. An exception raised while closing a buffer still propagates
    to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            # Close the disk buffer even if closing the comm buffer failed.
            await self._disk_buffer.close()
    finally:
        # Always release concurrent close() callers.
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named by ``id``.

    The name is built by joining the elements of ``id`` with ``"_"``.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by a peer and pass them to ``_receive``.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
> raise Reschedule()
E distributed.exceptions.Reschedule
distributed/shuffle/_core.py:344: Reschedule
During handling of the above exception, another exception occurred:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:37137', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:34299', name: 0, status: closed, stored: 0, running: 1/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:35357', name: 1, status: closed, stored: 0, running: 1/2, ready: 0, comm: 0, waiting: 0>
df_left = k v1
idx
0 0 0
0 1 1
0 2 2
1 0 3
1 1 4
1 2 5
1 3 6
2 0 7
2 1 ... 0 37
9 1 38
9 2 39
9 3 40
9 4 41
9 5 42
9 6 43
10 0 44
10 1 45
10 2 46
10 3 47
df_right = k v1
idx
0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
2 0 6
2 1 7
2 2 ... 1 42
9 2 43
9 3 44
10 0 45
10 1 46
10 2 47
10 3 48
10 4 49
10 5 50
10 6 51
10 7 52
ddf_left = Dask DataFrame Structure:
k v1
npartitions=10
0 int64 int64
1 ... ... ...
10 ... ...
11 ... ...
Dask Name: from_pd_divs, 1 expression
Expr=df
ddf_right_unknown = Dask DataFrame Structure:
k v1
npartitions=10
int64 int64
... ... ...
... ...
Dask Name: return_input, 2 expressions
Expr=ClearDivisions(frame=df)
on = ['k', 'idx'], how = 'inner'
@gen_cluster(client=True)
async def test_merge_known_to_unknown(
c,
s,
a,
b,
df_left,
df_right,
ddf_left,
ddf_right_unknown,
on,
how,
):
# Compute expected
expected = df_left.merge(df_right, on=on, how=how)
# Perform merge
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
result_graph = ddf_left.merge(ddf_right_unknown, on=on, how=how)
> result = await c.compute(result_graph)
distributed/shuffle/tests/test_merge_column_and_index.py:129:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:409: in _result
raise exc.with_traceback(tb)
distributed/utils.py:1507: in run_in_executor_with_context
return await loop.run_in_executor(
distributed/_concurrent_futures_thread.py:65: in run
result = self.fn(*self.args, **self.kwargs)
distributed/utils.py:1508: in <lambda>
executor, lambda: context.run(func, *args, **kwargs)
distributed/worker.py:2946: in _run_task
msg = _run_task_simple(task, data, time_delay)
distributed/worker.py:2998: in _run_task_simple
debugpy.post_mortem()
distributed/debugpy.py:60: in post_mortem
host, port = _ensure_debugpy_listens()
distributed/debugpy.py:36: in _ensure_debugpy_listens
endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/public_api.py:31: in wrapper
return wrapped(*args, **kwargs)
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:132: in debug
log.reraise_exception("{0}() failed:", func.__name__, level="info")
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:130: in debug
return func(address, settrace_kwargs, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.
import codecs
import os
import pydevd
import socket
import sys
import threading
import debugpy
from debugpy import adapter
from debugpy.common import json, log, sockets
from _pydevd_bundle.pydevd_constants import get_global_debugger
from pydevd_file_utils import absolute_path
from debugpy.common.util import hide_debugpy_internals
_tls = threading.local()

# TODO: "gevent", if possible.
# Current debug configuration. configure() validates new values against the
# type of the defaults below before storing them here.
_config = {
    "qt": "none",
    "subProcess": True,
    "python": sys.executable,
    "pythonEnv": {},
}

_config_valid_values = {
    # If property is not listed here, any value is considered valid, so long as
    # its type matches that of the default value in _config.
    "qt": ["auto", "none", "pyside", "pyside2", "pyqt4", "pyqt5"],
}

# This must be a global to prevent it from being garbage collected and triggering
# https://bugs.python.org/issue37380.
_adapter_process = None
def _settrace(*args, **kwargs):
    """Log the call and forward it to ``pydevd.settrace``.

    ``notify_stdin`` defaults to ``False`` unless the caller sets it.
    Exceptions from ``pydevd.settrace`` propagate unchanged.
    """
    log.debug("pydevd.settrace(*{0!r}, **{1!r})", args, kwargs)
    # The stdin in notification is not acted upon in debugpy, so, disable it.
    kwargs.setdefault("notify_stdin", False)
    pydevd.settrace(*args, **kwargs)
def ensure_logging():
    """Starts logging to log.log_dir, if it hasn't already been done."""
    if not ensure_logging.ensured:
        # Flip the flag first so later calls return immediately.
        ensure_logging.ensured = True
        log.to_file(prefix="debugpy.server")
        log.describe_environment("Initial environment:")
        if log.log_dir is not None:
            pydevd.log_to(log.log_dir + "/debugpy.pydevd.log")


# Function attribute recording whether ensure_logging() has already run.
ensure_logging.ensured = False
def log_to(path):
    """Select where logs go; must be called before logging has started.

    Passing ``sys.stderr`` enables every log level on stderr. Any other value
    is stored as ``log.log_dir``.

    Raises
    ------
    RuntimeError
        If ``ensure_logging`` has already run.
    """
    if ensure_logging.ensured:
        raise RuntimeError("logging has already begun")

    log.debug("log_to{0!r}", (path,))
    if path is not sys.stderr:
        log.log_dir = path
        return
    log.stderr.levels |= set(log.LEVELS)
def configure(properties=None, **kwargs):
    """Validate and store debug configuration properties in ``_config``.

    ``properties`` and ``kwargs`` are merged, with ``kwargs`` winning on
    conflict. Properties are applied one at a time, so those preceding an
    invalid one stay applied.

    Raises
    ------
    ValueError
        If a property is unknown, if its value's type differs from that of the
        default in ``_config``, or if the value is not among the entries listed
        for it in ``_config_valid_values``.
    """
    ensure_logging()
    log.debug("configure{0!r}", (properties, kwargs))

    if properties is None:
        requested = kwargs
    else:
        requested = dict(properties)
        requested.update(kwargs)

    for name, value in requested.items():
        if name not in _config:
            raise ValueError("Unknown property {0!r}".format(name))

        expected_type = type(_config[name])
        if type(value) is not expected_type:
            raise ValueError("{0!r} must be a {1}".format(name, expected_type.__name__))

        allowed = _config_valid_values.get(name)
        if allowed is not None and value not in allowed:
            raise ValueError("{0!r} must be one of: {1!r}".format(name, allowed))

        _config[name] = value
def _starts_debugging(func):
    """Decorate an entry point that starts debugging.

    The returned ``debug(address, **kwargs)`` accepts a port or a
    ``(host, port)`` pair (the host defaults to ``127.0.0.1``). It validates
    the port, makes sure logging is set up, builds the ``settrace`` keyword
    arguments from ``_config``, and then calls
    ``func(address, settrace_kwargs, **kwargs)``. Failures of ``func`` are
    logged at "info" level and re-raised.
    """

    def debug(address, **kwargs):
        # Accept a bare port as well as a (host, port) pair.
        try:
            _, port = address
        except Exception:
            port = address
            address = ("127.0.0.1", port)

        try:
            port.__index__()  # ensure it's int-like
        except Exception:
            raise ValueError("expected port or (host, port)")
        if not (0 <= port < 2**16):
            raise ValueError("invalid port number")

        ensure_logging()
        log.debug("{0}({1!r}, **{2!r})", func.__name__, address, kwargs)
        log.info("Initial debug configuration: {0}", json.repr(_config))

        qt_mode = _config.get("qt", "none")
        if qt_mode != "none":
            pydevd.enable_qt_support(qt_mode)

        settrace_kwargs = dict(
            suspend=False,
            patch_multiprocessing=_config.get("subProcess", True),
        )
        if hide_debugpy_internals():
            # Exclude debugpy's own files from tracing.
            debugpy_path = os.path.dirname(absolute_path(debugpy.__file__))
            settrace_kwargs["dont_trace_start_patterns"] = (debugpy_path,)
            settrace_kwargs["dont_trace_end_patterns"] = (str("debugpy_launcher.py"),)

        try:
            return func(address, settrace_kwargs, **kwargs)
        except Exception:
            log.reraise_exception("{0}() failed:", func.__name__, level="info")

    return debug
@_starts_debugging
def listen(address, settrace_kwargs, in_process_debug_adapter=False):
# Errors below are logged with level="info", because the caller might be catching
# and handling exceptions, and we don't want to spam their stderr unnecessarily.
if listen.called:
# Multiple calls to listen() cause the debuggee to hang
> raise RuntimeError("debugpy.listen() has already been called on this process")
E RuntimeError: debugpy.listen() has already been called on this process
../../../miniconda3/envs/dask-distributed/lib/python3.13/site-packages/debugpy/server/api.py:144: RuntimeError