
Commit cad1467

Warn if tasks are submitted with identical keys but different run_spec
1 parent 045dc64 commit cad1467
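
For context, a minimal reproducer of the pattern this commit warns about (a hedged sketch; the local `Client()` and the `inc` helper are illustrative and not part of the commit, and the behaviour mirrors the `add_deps=False` case of `test_resubmit_different_task_same_key` added below):

from dask import delayed
from dask.distributed import Client

def inc(i):
    return i + 1

client = Client()  # assumed local cluster, for illustration only

y1 = client.submit(inc, 1, key="y")       # first definition of key "y": inc(1)
y2 = delayed(inc)(2, dask_key_name="y")   # same key, different run_spec
z = delayed(inc)(y2, dask_key_name="z")
fut = client.compute(z)                   # second update_graph call; the scheduler logs the new warning

print(y1.result())   # 2: the scheduler keeps the first run_spec for "y"
print(fut.result())  # 3: "z" is computed from the kept run_spec, not the resubmitted one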

File tree: 4 files changed, +167 -5 lines changed


.github/workflows/tests.yaml

+1 -1
@@ -154,7 +154,7 @@ jobs:
           # Increase this value to reset cache if
           # continuous_integration/environment-${{ matrix.environment }}.yaml has not
           # changed. See also same variable in .pre-commit-config.yaml
-          CACHE_NUMBER: 0
+          CACHE_NUMBER: 1
         id: cache

       - name: Update environment

.pre-commit-config.yaml

+1 -1
@@ -69,4 +69,4 @@ repos:

 # Increase this value to clear the cache on GitHub actions if nothing else in this file
 # has changed. See also same variable in .github/workflows/test.yaml
-# CACHE_NUMBER: 0
+# CACHE_NUMBER: 1

distributed/scheduler.py

+43 -3
@@ -53,6 +53,7 @@

 import dask
 import dask.utils
+from dask.base import TokenizationError, normalize_token, tokenize
 from dask.core import get_deps, iskey, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
@@ -4752,6 +4753,7 @@ def _generate_taskstates(
        stack = list(keys)
        touched_keys = set()
        touched_tasks = []
+        tgs_with_bad_run_spec = set()
        while stack:
            k = stack.pop()
            if k in touched_keys:
@@ -4772,9 +4774,47 @@ def _generate_taskstates(
            # run_spec in the submitted graph may be None. This happens
            # when an already persisted future is part of the graph
            elif k in dsk:
-                # TODO run a health check to verify that run_spec and dependencies
-                # did not change. See https://github.com/dask/distributed/pull/8185
-                pass
+                try:
+                    tok_lhs: Any = tokenize(ts.run_spec, ensure_deterministic=True)
+                    tok_rhs: Any = tokenize(dsk[k], ensure_deterministic=True)
+                except TokenizationError:
+                    # Non-deterministic tokens; skip comparison
+                    tok_lhs = tok_rhs = None
+
+                # Additionally check dependency names. This should only be necessary
+                # if run_specs can't be tokenized deterministically.
+                deps_lhs = {dts.key for dts in ts.dependencies}
+                deps_rhs = dependencies.get(k, set())
+
+                # FIXME It would be a really healthy idea to change this to a hard
+                # failure. However, this is not possible at the moment because of
+                # https://github.com/dask/dask/issues/9888
+                if (
+                    tok_lhs != tok_rhs or deps_lhs != deps_rhs
+                ) and ts.group not in tgs_with_bad_run_spec:
+                    tgs_with_bad_run_spec.add(ts.group)
+                    logger.warning(
+                        f"Detected different `run_spec` for key {ts.key!r} between two "
+                        "consecutive calls to `update_graph`. This can cause failures "
+                        "and deadlocks down the line. Please ensure unique key names. "
+                        "If you are using a standard dask collections, consider "
+                        "releasing all the data before resubmitting another "
+                        "computation. More details and help can be found at "
+                        "https://github.com/dask/dask/issues/9888. "
+                        + textwrap.dedent(
+                            f"""
+                            Debugging information
+                            ---------------------
+                            old task state: {ts.state}
+                            old run_spec: {ts.run_spec!r}
+                            new run_spec: {dsk[k]!r}
+                            old token: {normalize_token(ts.run_spec)!r}
+                            new token: {normalize_token(dsk[k])!r}
+                            old dependencies: {deps_lhs}
+                            new dependencies: {deps_rhs}
+                            """
+                        )
+                    )

            if ts.run_spec:
                runnable.append(ts)
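
A standalone sketch (not part of the commit) of the comparison performed above: `tokenize(..., ensure_deterministic=True)` yields differing tokens when the task definition differs, and raises `TokenizationError` for objects that cannot be tokenized deterministically, in which case the scheduler falls back to comparing dependency names only. The run_spec tuples below are illustrative:

from dask.base import TokenizationError, tokenize

old_run_spec = (sum, [1, 2])   # hypothetical old task definition
new_run_spec = (sum, [1, 3])   # hypothetical resubmission under the same key

# Deterministic tokens differ when the definitions differ, so the warning fires.
assert tokenize(old_run_spec, ensure_deterministic=True) != tokenize(
    new_run_spec, ensure_deterministic=True
)

# A bare object() typically has no stable token; if tokenize raises here, the
# scheduler sets tok_lhs = tok_rhs = None and compares only dependency names.
try:
    tokenize(object(), ensure_deterministic=True)
except TokenizationError:
    print("non-deterministic run_spec; token comparison skipped")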

distributed/tests/test_scheduler.py

+122
@@ -4702,3 +4702,125 @@ async def test_html_repr(c, s, a, b):
         await asyncio.sleep(0.01)

     await f
+
+
+@pytest.mark.parametrize("add_deps", [False, True])
+@gen_cluster(client=True, nthreads=[])
+async def test_resubmit_different_task_same_key(c, s, add_deps):
+    """If an intermediate key has a different run_spec (either the callable function or
+    the dependencies / arguments) that will conflict with what was previously defined,
+    it should raise an error since this can otherwise break in many different places and
+    cause either spurious exceptions or even deadlocks.
+
+    For a real world example where this can trigger, see
+    https://github.com/dask/dask/issues/9888
+    """
+    y1 = c.submit(inc, 1, key="y")
+
+    x = delayed(inc)(1, dask_key_name="x") if add_deps else 2
+    y2 = delayed(inc)(x, dask_key_name="y")
+    z = delayed(inc)(y2, dask_key_name="z")
+
+    if add_deps:  # add_deps=True corrupts the state machine
+        s.validate = False
+
+    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
+        fut = c.compute(z)
+        await wait_for_state("z", "waiting", s)
+
+    assert "Detected different `run_spec` for key 'y'" in log.getvalue()
+
+    async with Worker(s.address):
+        if not add_deps:  # add_deps=True hangs
+            assert await y1 == 2
+            assert await fut == 3
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_resubmit_different_task_same_key_many_clients(c, s):
+    """Two different clients submit a task with the same key but different run_spec's."""
+    async with Client(s.address, asynchronous=True) as c2:
+        with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
+            x1 = c.submit(inc, 1, key="x")
+            x2 = c2.submit(inc, 2, key="x")
+
+            await wait_for_state("x", ("no-worker", "queued"), s)
+            who_wants = s.tasks["x"].who_wants
+            await async_poll_for(
+                lambda: {cs.client_key for cs in who_wants} == {c.id, c2.id}, timeout=5
+            )
+
+        assert "Detected different `run_spec` for key 'x'" in log.getvalue()
+
+        async with Worker(s.address):
+            assert await x1 == 2
+            assert await x2 == 2  # kept old run_spec
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_resubmit_nondeterministic_task_same_deps(c, s):
+    """Some run_specs can't be tokenized deterministically. Silently skip comparison on
+    the run_spec in those cases. Dependencies must be the same.
+    """
+    o = object()
+    # Round-tripping `o` through two separate cloudpickle.dumps() calls generates two
+    # different object instances, which yield different tokens.
+    x1 = c.submit(lambda x: x, o, key="x")
+    x2 = delayed(lambda x: x)(o, dask_key_name="x")
+    y = delayed(lambda x: x)(x2, dask_key_name="y")
+    fut = c.compute(y)
+    await async_poll_for(lambda: "y" in s.tasks, timeout=5)
+    async with Worker(s.address):
+        assert type(await fut) is object
+
+
+@pytest.mark.parametrize("add_deps", [False, True])
+@gen_cluster(client=True, nthreads=[])
+async def test_resubmit_nondeterministic_task_different_deps(c, s, add_deps):
+    """Some run_specs can't be tokenized deterministically. Silently skip comparison on
+    the run_spec in those cases. However, fail anyway if dependencies have changed.
+    """
+    o = object()
+    x1 = c.submit(inc, 1, key="x1") if not add_deps else 2
+    x2 = c.submit(inc, 2, key="x2")
+    y1 = delayed(lambda i, j: i)(x1, o, dask_key_name="y").persist()
+    y2 = delayed(lambda i, j: i)(x2, o, dask_key_name="y")
+    z = delayed(inc)(y2, dask_key_name="z")
+
+    if add_deps:  # add_deps=True corrupts the state machine and hangs
+        s.validate = False
+
+    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
+        fut = c.compute(z)
+        await wait_for_state("z", "waiting", s)
+    assert "Detected different `run_spec` for key 'y'" in log.getvalue()
+
+    if not add_deps:  # add_deps=True corrupts the state machine and hangs
+        async with Worker(s.address):
+            assert await fut == 3
+
+
+@gen_cluster(client=True, nthreads=[])
+async def test_resubmit_different_task_same_key_warns_only_once(c, s):
+    """If all tasks of a layer are affected by the same run_spec collision, warn
+    only once.
+    """
+    x1s = c.map(inc, [0, 1, 2], key=[("x", 0), ("x", 1), ("x", 2)])
+    dsk = {
+        ("x", 0): 3,
+        ("x", 1): 4,
+        ("x", 2): 5,
+        ("y", 0): (inc, ("x", 0)),
+        ("y", 1): (inc, ("x", 1)),
+        ("y", 2): (inc, ("x", 2)),
+    }
+    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
+        ys = c.get(dsk, [("y", 0), ("y", 1), ("y", 2)], sync=False)
+        await wait_for_state(("y", 2), "waiting", s)
+
+    assert (
+        len(re.findall("Detected different `run_spec` for key ", log.getvalue())) == 1
+    )
+
+    async with Worker(s.address):
+        assert await c.gather(ys) == [2, 3, 4]
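
As a usage note, here is a hedged sketch of the warning's primary advice ("Please ensure unique key names"): letting dask derive keys from the function and its arguments avoids reusing one key for a different run_spec. The `client` and `inc` names are illustrative only:

from dask.distributed import Client

def inc(i):
    return i + 1

client = Client()  # assumed local cluster, for illustration only

f1 = client.submit(inc, 1)   # key derived from the function and its arguments
f2 = client.submit(inc, 2)   # different arguments, hence a different key, no collision

assert f1.key != f2.key
assert (f1.result(), f2.result()) == (2, 3)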
