Remove expensive tokenization for key uniqueness check #9009

Merged
merged 3 commits on Feb 13, 2025
18 changes: 2 additions & 16 deletions distributed/scheduler.py
@@ -55,7 +55,6 @@
import dask
import dask.utils
from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
from dask.base import TokenizationError, normalize_token, tokenize
from dask.core import istask, validate_key
from dask.typing import Key, no_default
from dask.utils import (
@@ -4985,25 +4984,14 @@ def _generate_taskstates(
# run_spec in the submitted graph may be None. This happens
# when an already persisted future is part of the graph
elif k in dsk:
# If both tokens are non-deterministic, skip comparison
try:
tok_lhs = tokenize(ts.run_spec, ensure_deterministic=True)
except TokenizationError:
tok_lhs = ""
try:
tok_rhs = tokenize(dsk[k], ensure_deterministic=True)
except TokenizationError:
tok_rhs = ""

# Additionally check dependency names. This should only be necessary
# if run_specs can't be tokenized deterministically.
# Check dependency names.
deps_lhs = {dts.key for dts in ts.dependencies}
deps_rhs = dependencies[k]

# FIXME It would be a really healthy idea to change this to a hard
# failure. However, this is not possible at the moment because of
# https://github.com/dask/dask/issues/9888
if tok_lhs != tok_rhs or deps_lhs != deps_rhs:
if deps_lhs != deps_rhs:
# Retain old run_spec and dependencies; rerun them if necessary.
# This sweeps the issue of collision under the carpet as long as the
# old and new task produce the same output - such as in
@@ -5029,8 +5017,6 @@ def _generate_taskstates(
old task state: {ts.state}
old run_spec: {ts.run_spec!r}
new run_spec: {dsk[k]!r}
old token: {normalize_token(ts.run_spec)!r}
new token: {normalize_token(dsk[k])!r}
old dependencies: {deps_lhs}
new dependencies: {deps_rhs}
"""
174 changes: 0 additions & 174 deletions distributed/tests/test_scheduler.py
@@ -4946,147 +4946,6 @@ async def test_html_repr(c, s, a, b):
await f


@pytest.mark.parametrize("deps", ["same", "less", "more"])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_before_previous_is_done(c, s, deps):
"""If an intermediate key has a different run_spec (either the callable function or
the dependencies / arguments) that will conflict with what was previously defined,
it should raise an error since this can otherwise break in many different places and
cause either spurious exceptions or even deadlocks.

In this specific test, the previous run_spec has not been computed yet.
See also test_resubmit_different_task_same_key_after_previous_is_done.

For a real world example where this can trigger, see
https://github.com/dask/dask/issues/9888
"""
seen = False

def _match(event):
_, msg = event
return (
isinstance(msg, dict)
and msg.get("action", None) == "update-graph"
and msg["metrics"]["key_collisions"] > 0
)

def handler(ev):
if _match(ev):
nonlocal seen
seen = True

c.subscribe_topic("scheduler", handler)

x1 = c.submit(inc, 1, key="x1")
y_old = c.submit(inc, x1, key="y")

x1b = x1 if deps != "less" else 2
x2 = delayed(inc)(10, dask_key_name="x2") if deps == "more" else 11
y_new = delayed(sum)([x1b, x2], dask_key_name="y")
z = delayed(inc)(y_new, dask_key_name="z")

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(z)
await wait_for_state("z", "waiting", s)

assert "Detected different `run_spec` for key 'y'" in log.getvalue()

await async_poll_for(lambda: seen, timeout=5)

async with Worker(s.address):
# Used old run_spec
assert await y_old == 3
assert await fut == 4


@pytest.mark.parametrize("deps", ["same", "less", "more"])
@pytest.mark.parametrize("release_previous", [False, True])
@gen_cluster(client=True)
async def test_resubmit_different_task_same_key_after_previous_is_done(
c, s, a, b, deps, release_previous
):
"""Same as test_resubmit_different_task_same_key, but now the replaced task has
already been computed and is either in memory or released, and so are its old
dependencies, so they may need to be recomputed.
"""
x1 = delayed(inc)(1, dask_key_name="x1")
x1fut = c.compute(x1)
y_old = c.submit(inc, x1fut, key="y")
z1 = c.submit(inc, y_old, key="z1")
await wait(z1)
if release_previous:
del x1fut, y_old
await wait_for_state("x1", "released", s)
await wait_for_state("y", "released", s)

x1b = x1 if deps != "less" else 2
x2 = delayed(inc)(10, dask_key_name="x2") if deps == "more" else 11
y_new = delayed(sum)([x1b, x2], dask_key_name="y")
z2 = delayed(inc)(y_new, dask_key_name="z2")

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(z2)
# Used old run_spec
assert await fut == 4
assert "x2" not in s.tasks

# _generate_taskstates won't run for a dependency that's already in memory
has_warning = "Detected different `run_spec` for key 'y'" in log.getvalue()
assert has_warning is (release_previous or deps == "less")


@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_many_clients(c, s):
"""Two different clients submit a task with the same key but different run_spec's."""
async with Client(s.address, asynchronous=True) as c2:
with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
x1 = c.submit(inc, 1, key="x")
x2 = c2.submit(inc, 2, key="x")

await wait_for_state("x", ("no-worker", "queued"), s)
who_wants = s.tasks["x"].who_wants
await async_poll_for(
lambda: {cs.client_key for cs in who_wants} == {c.id, c2.id}, timeout=5
)

assert "Detected different `run_spec` for key 'x'" in log.getvalue()

async with Worker(s.address):
assert await x1 == 2
assert await x2 == 2 # kept old run_spec


@pytest.mark.parametrize(
"before,after,expect_msg",
[
(object(), 123, True),
(123, object(), True),
(o := object(), o, False),
],
)
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_same_deps(
c, s, before, after, expect_msg
):
"""Some run_specs can't be tokenized deterministically. Silently skip comparison on
the run_spec when both lhs and rhs are nondeterministic.
Dependencies must be the same.
"""
x1 = c.submit(lambda x: x, before, key="x")
x2 = delayed(lambda x: x)(after, dask_key_name="x")
y = delayed(lambda x: x)(x2, dask_key_name="y")

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(y)
await async_poll_for(lambda: "y" in s.tasks, timeout=5)

has_msg = "Detected different `run_spec` for key 'x'" in log.getvalue()
assert has_msg == expect_msg

async with Worker(s.address):
assert type(await fut) is type(before)


@pytest.mark.parametrize("add_deps", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_different_deps(c, s, add_deps):
@@ -5109,39 +4968,6 @@ async def test_resubmit_nondeterministic_task_different_deps(c, s, add_deps):
assert await fut == 3


@pytest.mark.parametrize(
"loglevel,expect_loglines", [(logging.DEBUG, 2), (logging.WARNING, 1)]
)
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_warns_only_once(
c, s, loglevel, expect_loglines
):
"""If all tasks of a layer are affected by the same run_spec collision, warn
only once.
"""
y1s = c.map(inc, [0, 1, 2], key=[("y", 0), ("y", 1), ("y", 2)])
dsk = {
"x": 3,
("y", 0): (inc, "x"), # run_spec and dependencies change
("y", 1): (inc, 4), # run_spec changes, dependencies don't
("y", 2): (inc, 2), # Doesn't change
("z", 0): (inc, ("y", 0)),
("z", 1): (inc, ("y", 1)),
("z", 2): (inc, ("y", 2)),
}
with captured_logger("distributed.scheduler", level=loglevel) as log:
zs = c.get(dsk, [("z", 0), ("z", 1), ("z", 2)], sync=False)
await wait_for_state(("z", 2), "waiting", s)

actual_loglines = len(
re.findall("Detected different `run_spec` for key ", log.getvalue())
)
assert actual_loglines == expect_loglines

async with Worker(s.address):
assert await c.gather(zs) == [2, 3, 4] # Kept old ys


def block(x, in_event, block_event):
in_event.set()
block_event.wait()