diff --git a/distributed/diagnostics/tests/test_rmm_diagnostics.py b/distributed/diagnostics/tests/test_rmm_diagnostics.py index f26255e737a..8944a07e223 100644 --- a/distributed/diagnostics/tests/test_rmm_diagnostics.py +++ b/distributed/diagnostics/tests/test_rmm_diagnostics.py @@ -7,6 +7,7 @@ from dask import delayed from dask.utils import parse_bytes +from distributed.utils import Deadline from distributed.utils_test import gen_cluster pytestmark = pytest.mark.gpu @@ -32,6 +33,13 @@ async def test_rmm_metrics(c, s, *workers): assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB") result = delayed(rmm.DeviceBuffer)(size=10) result = result.persist() - await asyncio.sleep(1) + + deadline = Deadline.after(5) + + while not deadline.expired: + if w.metrics["rmm"]["rmm-used"] != 0: + break + await asyncio.sleep(0.25) + assert w.metrics["rmm"]["rmm-used"] != 0 assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")