@@ -4702,3 +4702,125 @@ async def test_html_repr(c, s, a, b):
4702
4702
await asyncio .sleep (0.01 )
4703
4703
4704
4704
await f
4705
+
4706
+
4707
@pytest.mark.parametrize("add_deps", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key(c, s, add_deps):
    """If an intermediate key has a different run_spec (either the callable function or
    the dependencies / arguments) that will conflict with what was previously defined,
    it should raise an error since this can otherwise break in many different places and
    cause either spurious exceptions or even deadlocks.

    For a real world example where this can trigger, see
    https://github.com/dask/dask/issues/9888
    """
    # First definition of key "y": inc(1), submitted directly.
    y1 = c.submit(inc, 1, key="y")

    # Second, conflicting definition of "y": same key, but either a different
    # argument (2) or an extra dependency "x" depending on the parametrization.
    x = delayed(inc)(1, dask_key_name="x") if add_deps else 2
    y2 = delayed(inc)(x, dask_key_name="y")
    z = delayed(inc)(y2, dask_key_name="z")

    if add_deps:  # add_deps=True corrupts the state machine
        s.validate = False

    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
        fut = c.compute(z)
        await wait_for_state("z", "waiting", s)

    # The scheduler must warn about the conflicting resubmission of "y".
    assert "Detected different `run_spec` for key 'y'" in log.getvalue()

    # nthreads=[] starts the cluster with no workers; add one now so tasks run.
    async with Worker(s.address):
        if not add_deps:  # add_deps=True hangs
            # The original run_spec (inc(1)) is kept, so y == 2 and z == 3.
            assert await y1 == 2
            assert await fut == 3
4737
+
4738
+
4739
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_many_clients(c, s):
    """Two different clients submit a task with the same key but different run_spec's."""
    async with Client(s.address, asynchronous=True) as c2:
        with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
            # Same key "x", different arguments: inc(1) from c, inc(2) from c2.
            x1 = c.submit(inc, 1, key="x")
            x2 = c2.submit(inc, 2, key="x")

            # No workers exist yet (nthreads=[]), so "x" parks unscheduled.
            await wait_for_state("x", ("no-worker", "queued"), s)
            who_wants = s.tasks["x"].who_wants
            # Both clients must end up registered as wanting the same task.
            await async_poll_for(
                lambda: {cs.client_key for cs in who_wants} == {c.id, c2.id}, timeout=5
            )

        # The second, conflicting submission triggers a scheduler warning.
        assert "Detected different `run_spec` for key 'x'" in log.getvalue()

        async with Worker(s.address):
            assert await x1 == 2
            assert await x2 == 2  # kept old run_spec
4758
+
4759
+
4760
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_same_deps(c, s):
    """A run_spec that can't be tokenized deterministically is silently exempted
    from the collision check, provided the dependencies are unchanged.
    """
    sentinel = object()
    # Each cloudpickle.dumps() round-trip of `sentinel` yields a brand-new object
    # instance, so the two submissions of key "x" below produce different tokens.
    first = c.submit(lambda v: v, sentinel, key="x")
    second = delayed(lambda v: v)(sentinel, dask_key_name="x")
    downstream = delayed(lambda v: v)(second, dask_key_name="y")
    result = c.compute(downstream)
    await async_poll_for(lambda: "y" in s.tasks, timeout=5)
    async with Worker(s.address):
        # The identity chain must deliver the (round-tripped) sentinel object.
        assert type(await result) is object
4775
+
4776
+
4777
@pytest.mark.parametrize("add_deps", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_different_deps(c, s, add_deps):
    """Some run_specs can't be tokenized deterministically. Silently skip comparison on
    the run_spec in those cases. However, fail anyway if dependencies have changed.
    """
    # `o` can't be tokenized deterministically (see sibling test above), which
    # would normally exempt "y" from the run_spec comparison.
    o = object()
    x1 = c.submit(inc, 1, key="x1") if not add_deps else 2
    x2 = c.submit(inc, 2, key="x2")
    # Two definitions of key "y" sharing the untokenizable argument `o`, but
    # depending on different upstream keys (x1 vs. x2).
    y1 = delayed(lambda i, j: i)(x1, o, dask_key_name="y").persist()
    y2 = delayed(lambda i, j: i)(x2, o, dask_key_name="y")
    z = delayed(inc)(y2, dask_key_name="z")

    if add_deps:  # add_deps=True corrupts the state machine and hangs
        s.validate = False

    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
        fut = c.compute(z)
        await wait_for_state("z", "waiting", s)
    # Changed dependencies must still be reported even though the run_spec
    # itself could not be compared.
    assert "Detected different `run_spec` for key 'y'" in log.getvalue()

    if not add_deps:  # add_deps=True corrupts the state machine and hangs
        async with Worker(s.address):
            # Old run_spec wins: y == inc(1) == 2, hence z == 3.
            assert await fut == 3
4801
+
4802
+
4803
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_warns_only_once(c, s):
    """A run_spec collision hitting every task of a layer produces a single
    warning rather than one warning per key.
    """
    # Original definitions: ("x", i) = inc(i) for i in 0..2.
    futures = c.map(inc, [0, 1, 2], key=[("x", i) for i in range(3)])
    # Conflicting graph: every ("x", i) is redefined as a plain constant,
    # with a dependent ("y", i) layer on top.
    graph = {("x", i): i + 3 for i in range(3)}
    graph.update({("y", i): (inc, ("x", i)) for i in range(3)})
    with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
        outs = c.get(graph, [("y", i) for i in range(3)], sync=False)
        await wait_for_state(("y", 2), "waiting", s)

    hits = re.findall("Detected different `run_spec` for key ", log.getvalue())
    assert len(hits) == 1

    async with Worker(s.address):
        # The original inc(i) run_specs are kept, so y_i == i + 2.
        assert await c.gather(outs) == [2, 3, 4]
0 commit comments