diff --git a/distributed/debugpy.py b/distributed/debugpy.py new file mode 100644 index 0000000000..615d6de0e9 --- /dev/null +++ b/distributed/debugpy.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import importlib.util +import logging +import sys +import threading + +import dask.config + +logger = logging.getLogger(__name__) + +DEBUGPY_ENABLED: bool = dask.config.get("distributed.diagnostics.debugpy.enabled") +DEBUGPY_PORT: int = dask.config.get("distributed.diagnostics.debugpy.port") + + +def _check_debugpy_installed(): + if importlib.util.find_spec("debugpy") is None: + raise ModuleNotFoundError( + "Dask debugger requires debugpy. Please make sure it is installed." + ) + + +LOCK = threading.Lock() + + +def _ensure_debugpy_listens() -> tuple[str, int]: + import debugpy + + from distributed.worker import get_worker + + worker = get_worker() + + with LOCK: + if endpoint := worker.extensions.get("debugpy", None): + return endpoint + endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT)) + worker.extensions["debugpy"] = endpoint + return endpoint + + +def breakpointhook() -> None: + import debugpy + + host, port = _ensure_debugpy_listens() + if not debugpy.is_client_connected(): + logger.warning( + "Breakpoint encountered; waiting for client to attach to %s:%d...", + host, + port, + ) + debugpy.wait_for_client() + + debugpy.breakpoint() + + +def post_mortem() -> None: + # Based on https://github.com/microsoft/debugpy/issues/723 + import debugpy + + host, port = _ensure_debugpy_listens() + if not debugpy.is_client_connected(): + logger.warning( + "Exception encountered; waiting for client to attach to %s:%d...", + host, + port, + ) + debugpy.wait_for_client() + + import pydevd + + py_db = pydevd.get_global_debugger() + thread = threading.current_thread() + additional_info = py_db.set_additional_thread_info(thread) + additional_info.is_tracing += 1 + try: + error = sys.exc_info() + py_db.stop_on_unhandled_exception(py_db, thread, additional_info, error) + finally: + additional_info.is_tracing -= 1 + + +if DEBUGPY_ENABLED: + _check_debugpy_installed() + sys.breakpointhook = breakpointhook diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 55f20deabe..816cc218bd 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -1061,7 +1061,16 @@ properties: minimum: 0 description: | The maximum number of erred tasks to remember. - + debugpy: + type: object + description: Configuration settings for Dask's remote debugger + properties: + enabled: + type: boolean + description: Enable remote debugging. + port: + type: integer + description: Port used by the debug adapter to listen on. p2p: type: object description: Configuration for P2P shuffles diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index a7624dd9ef..4f489cd8ed 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -305,6 +305,9 @@ distributed: - get_output_via_markers\.py erred-tasks: max-history: 100 + debugpy: + enabled: True + port: 5678 p2p: comm: diff --git a/distributed/worker.py b/distributed/worker.py index c8a4bbd574..3960598eb6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -58,7 +58,7 @@ typename, ) -from distributed import preloading, profile, utils +from distributed import debugpy, preloading, profile, utils from distributed.batched import BatchedSend from distributed.collections import LRU from distributed.comm import Comm, connect, get_address_host, parse_address @@ -2995,6 +2995,7 @@ def _run_task_simple( msg: RunTaskFailure = error_message(e) # type: ignore msg["op"] = "task-erred" msg["actual_exception"] = e + debugpy.post_mortem() else: msg: RunTaskSuccess = { # type: ignore "op": "task-finished", @@ -3041,6 +3042,7 @@ async def _run_task_async( msg: RunTaskFailure = error_message(e) # type: ignore msg["op"] = "task-erred" msg["actual_exception"] = e + debugpy.post_mortem() else: msg: RunTaskSuccess = { # type: ignore "op": "task-finished",