From d21a17925a6b8461d6ef4fff1a4b5b370a406162 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Tue, 6 Feb 2024 17:37:58 +0000
Subject: [PATCH 1/2] Tweaks to update_graph (backport from #8185)

---
 distributed/client.py    |  2 +-
 distributed/scheduler.py | 64 ++++++++++++++++++++++++++--------------
 2 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/distributed/client.py b/distributed/client.py
index 004f64c2834..cb522630b2e 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -1576,7 +1576,7 @@ async def _handle_report(self):
                     breakout = False
                     for msg in msgs:
-                        logger.debug("Client receives message %s", msg)
+                        logger.debug("Client %s receives message %s", self.id, msg)
 
                         if "status" in msg and "error" in msg["status"]:
                             typ, exc, tb = clean_exception(**msg)
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 1b51a0e34bc..6e131e00c41 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -52,7 +52,8 @@
 from tornado.ioloop import IOLoop
 
 import dask
+import dask.utils
-from dask.core import get_deps, validate_key
+from dask.core import get_deps, iskey, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
     ensure_dict,
@@ -4721,6 +4722,7 @@ async def update_graph(
                     stimulus_id=stimulus_id or f"update-graph-{start}",
                 )
             except RuntimeError as e:
+                logger.error(str(e))
                 err = error_message(e)
                 for key in keys:
                     self.report(
@@ -4729,7 +4731,10 @@ async def update_graph(
                             "key": key,
                             "exception": err["exception"],
                             "traceback": err["traceback"],
-                        }
+                        },
+                        # This informs all clients in who_wants plus the current client
+                        # (which may not have been added to who_wants yet)
+                        client=client,
                     )
         end = time()
         self.digest_metric("update-graph-duration", end - start)
@@ -4755,8 +4760,21 @@ def _generate_taskstates(
             if ts is None:
                 ts = self.new_task(k, dsk.get(k), "released", computation=computation)
                 new_tasks.append(ts)
-            elif not ts.run_spec:
+            # It is possible to create the TaskState object before its runspec is known
+            # to the scheduler. For instance, this is possible when using a Variable:
+            # `f = c.submit(foo); await Variable().set(f)` since the Variable uses a
+            # different comm channel, so the `client_desires_key` message could arrive
+            # before `update_graph`.
+            # There are also anti-pattern processes possible;
+            # see for example test_scatter_creates_ts
+            elif ts.run_spec is None:
                 ts.run_spec = dsk.get(k)
+            # run_spec in the submitted graph may be None. This happens
+            # when an already persisted future is part of the graph
+            elif k in dsk:
+                # TODO run a health check to verify that run_spec and dependencies
+                # did not change. See https://github.com/dask/distributed/pull/8185
+                pass
 
             if ts.run_spec:
                 runnable.append(ts)
@@ -5538,28 +5556,28 @@ def report(
                 tasks: dict = self.tasks
                 ts = tasks.get(msg_key)
 
-        client_comms: dict = self.client_comms
-        if ts is None:
+        if ts is None and client is None:
             # Notify all clients
-            client_keys = list(client_comms)
-        elif client:
-            # Notify clients interested in key
-            client_keys = [cs.client_key for cs in ts.who_wants or ()]
+            client_keys = list(self.client_comms)
+        elif ts is None:
+            client_keys = [client]
         else:
             # Notify clients interested in key (including `client`)
+            # Note that, if report() was called by update_graph(), `client` won't be in
+            # ts.who_wants yet.
             client_keys = [
                 cs.client_key for cs in ts.who_wants or () if cs.client_key != client
             ]
-            client_keys.append(client)
+            if client is not None:
+                client_keys.append(client)
 
-        k: str
         for k in client_keys:
-            c = client_comms.get(k)
+            c = self.client_comms.get(k)
             if c is None:
                 continue
             try:
                 c.send(msg)
-                # logger.debug("Scheduler sends message to client %s", msg)
+                # logger.debug("Scheduler sends message to client %s: %s", k, msg)
             except CommClosedError:
                 if self.status == Status.running:
                     logger.critical(
@@ -8724,26 +8742,28 @@ def _materialize_graph(
     dsk2 = {}
     fut_deps = {}
     for k, v in dsk.items():
-        dsk2[k], futs = unpack_remotedata(v, byte_keys=True)
+        v, futs = unpack_remotedata(v, byte_keys=True)
         if futs:
             fut_deps[k] = futs
+
+        # Remove aliases {x: x}.
+        # FIXME: This is an artifact generated by unpack_remotedata when using persisted
+        # collections. There should be a better way to achieve that tasks are not self
+        # referencing themselves.
+        if not iskey(v) or v != k:
+            dsk2[k] = v
+
     dsk = dsk2
 
     # - Add in deps for any tasks that depend on futures
     for k, futures in fut_deps.items():
-        dependencies[k].update(f.key for f in futures)
+        dependencies[k].update(f.key for f in futures if f.key != k)
 
     # Remove any self-dependencies (happens on test_publish_bag() and others)
     for k, v in dependencies.items():
         deps = set(v)
-        if k in deps:
-            deps.remove(k)
+        deps.discard(k)
         dependencies[k] = deps
 
-    # Remove aliases
-    for k in list(dsk):
-        if dsk[k] is k:
-            del dsk[k]
     dsk = valmap(_normalize_task, dsk)
-
     return dsk, dependencies, annotations_by_type

From 471715cbbf9419e21cbae64863addff726a64c2e Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 15 Feb 2024 11:28:27 +0000
Subject: [PATCH 2/2] Update distributed/scheduler.py

---
 distributed/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 6e131e00c41..be11bcd9ec8 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -8757,7 +8757,7 @@ def _materialize_graph(
 
     # - Add in deps for any tasks that depend on futures
     for k, futures in fut_deps.items():
-        dependencies[k].update(f.key for f in futures if f.key != k)
+        dependencies[k].update(f.key for f in futures)
 
     # Remove any self-dependencies (happens on test_publish_bag() and others)
     for k, v in dependencies.items():
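
For context on the `_generate_taskstates` comment in PATCH 1/2, a minimal client-side sketch of the Variable pattern it describes: because Variable traffic travels on its own comm channel, the scheduler can receive `client_desires_key` for a key before the `update_graph` message that carries its run_spec. This sketch is not part of either patch; the function names, the variable name, and the in-process cluster are illustrative only.

from distributed import Client, Variable


def inc(x):
    return x + 1


def producer(client: Client) -> None:
    # submit() sends the graph to the scheduler via update_graph on the
    # client's batched stream ...
    future = client.submit(inc, 1)
    # ... while Variable.set() travels on a separate comm channel, so the
    # scheduler may learn that a client wants this key before it has seen
    # the graph (and hence before the TaskState has a run_spec).
    Variable("shared-future", client=client).set(future)


def consumer(client: Client) -> int:
    # get() returns the shared Future; result() blocks until it is computed.
    future = Variable("shared-future", client=client).get()
    return future.result()


if __name__ == "__main__":
    with Client(processes=False) as client:
        producer(client)
        print(consumer(client))  # 2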