Skip to content

Commit 6a8f3b5

Browse files
committed
Prototypical implementation
1 parent 55bb639 commit 6a8f3b5

File tree

4 files changed

+100
-2
lines changed

4 files changed

+100
-2
lines changed

distributed/debugpy.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
import logging
5+
import sys
6+
import threading
7+
8+
import dask.config
9+
10+
logger = logging.getLogger(__name__)
11+
12+
DEBUGPY_ENABLED: bool = dask.config.get("distributed.diagnostics.debugpy.enabled")
13+
DEBUGPY_PORT: int = dask.config.get("distributed.diagnostics.debugpy.port")
14+
15+
16+
def _check_debugpy_installed():
17+
if importlib.util.find_spec("debugpy") is None:
18+
raise ModuleNotFoundError(
19+
"Dask debugger requires debugpy. Please make sure it is installed."
20+
)
21+
22+
23+
LOCK = threading.Lock()
24+
25+
26+
def _ensure_debugpy_listens() -> tuple[str, int]:
27+
import debugpy
28+
29+
from distributed.worker import get_worker
30+
31+
worker = get_worker()
32+
33+
with LOCK:
34+
if endpoint := worker.extensions.get("debugpy", None):
35+
return endpoint
36+
endpoint = debugpy.listen(DEBUGPY_PORT)
37+
worker.extensions["debugpy"] = endpoint
38+
return endpoint
39+
40+
41+
def breakpointhook() -> None:
42+
import debugpy
43+
44+
host, port = _ensure_debugpy_listens()
45+
if not debugpy.is_client_connected():
46+
logger.warning(
47+
"Breakpoint encountered; waiting for client to attach to %s:%d...",
48+
host,
49+
port,
50+
)
51+
debugpy.wait_for_client()
52+
53+
debugpy.breakpoint()
54+
55+
56+
def post_mortem() -> None:
57+
# Based on https://github.com/microsoft/debugpy/issues/723
58+
import debugpy
59+
60+
host, port = _ensure_debugpy_listens()
61+
if not debugpy.is_client_connected():
62+
logger.warning(
63+
"Exception encountered; waiting for client to attach to %s:%d...",
64+
host,
65+
port,
66+
)
67+
debugpy.wait_for_client()
68+
69+
import pydevd
70+
71+
py_db = pydevd.get_global_debugger()
72+
thread = threading.current_thread()
73+
additional_info = py_db.set_additional_thread_info(thread)
74+
additional_info.is_tracing += 1
75+
try:
76+
error = sys.exc_info()
77+
py_db.stop_on_unhandled_exception(py_db, thread, additional_info, error)
78+
finally:
79+
additional_info.is_tracing -= 1
80+
81+
82+
if DEBUGPY_ENABLED:
83+
_check_debugpy_installed()
84+
sys.breakpointhook = breakpoint

distributed/distributed-schema.yaml

+10-1
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,16 @@ properties:
10611061
minimum: 0
10621062
description: |
10631063
The maximum number of erred tasks to remember.
1064-
1064+
debugpy:
1065+
type: object
1066+
description: Configuration settings for Dask's remote debugger
1067+
properties:
1068+
enabled:
1069+
type: boolean
1070+
description: Enable remote debugging.
1071+
port:
1072+
type: integer
1073+
description: Port used by the debug adapter to listen on.
10651074
p2p:
10661075
type: object
10671076
description: Configuration for P2P shuffles

distributed/distributed.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ distributed:
305305
- get_output_via_markers\.py
306306
erred-tasks:
307307
max-history: 100
308+
debugpy:
309+
enabled: True
310+
port: 5678
308311

309312
p2p:
310313
comm:

distributed/worker.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
typename,
5959
)
6060

61-
from distributed import preloading, profile, utils
61+
from distributed import debugpy, preloading, profile, utils
6262
from distributed.batched import BatchedSend
6363
from distributed.collections import LRU
6464
from distributed.comm import Comm, connect, get_address_host, parse_address
@@ -2995,6 +2995,7 @@ def _run_task_simple(
29952995
msg: RunTaskFailure = error_message(e) # type: ignore
29962996
msg["op"] = "task-erred"
29972997
msg["actual_exception"] = e
2998+
debugpy.post_mortem()
29982999
else:
29993000
msg: RunTaskSuccess = { # type: ignore
30003001
"op": "task-finished",
@@ -3041,6 +3042,7 @@ async def _run_task_async(
30413042
msg: RunTaskFailure = error_message(e) # type: ignore
30423043
msg["op"] = "task-erred"
30433044
msg["actual_exception"] = e
3045+
debugpy.post_mortem()
30443046
else:
30453047
msg: RunTaskSuccess = { # type: ignore
30463048
"op": "task-finished",

0 commit comments

Comments
 (0)