From 02c1e43608329f86c8bd04c3b673604f82487425 Mon Sep 17 00:00:00 2001 From: Tom Augspurger <toaugspurger@nvidia.com> Date: Tue, 4 Feb 2025 11:47:03 -0800 Subject: [PATCH] Poll in test_rmm_metrics test xref https://github.com/rapidsai/dask-upstream-testing/issues/4 --- distributed/diagnostics/tests/test_rmm_diagnostics.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/distributed/diagnostics/tests/test_rmm_diagnostics.py b/distributed/diagnostics/tests/test_rmm_diagnostics.py index f26255e737a..8944a07e223 100644 --- a/distributed/diagnostics/tests/test_rmm_diagnostics.py +++ b/distributed/diagnostics/tests/test_rmm_diagnostics.py @@ -7,6 +7,7 @@ from dask import delayed from dask.utils import parse_bytes +from distributed.utils import Deadline from distributed.utils_test import gen_cluster pytestmark = pytest.mark.gpu @@ -32,6 +33,13 @@ async def test_rmm_metrics(c, s, *workers): assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB") result = delayed(rmm.DeviceBuffer)(size=10) result = result.persist() - await asyncio.sleep(1) + + deadline = Deadline.after(5) + + while not deadline.expired: + if w.metrics["rmm"]["rmm-used"] != 0: + break + await asyncio.sleep(0.25) + assert w.metrics["rmm"]["rmm-used"] != 0 assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")