From 02c1e43608329f86c8bd04c3b673604f82487425 Mon Sep 17 00:00:00 2001
From: Tom Augspurger <toaugspurger@nvidia.com>
Date: Tue, 4 Feb 2025 11:47:03 -0800
Subject: [PATCH] Poll in test_rmm_metrics test

xref https://github.com/rapidsai/dask-upstream-testing/issues/4
---
 distributed/diagnostics/tests/test_rmm_diagnostics.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/distributed/diagnostics/tests/test_rmm_diagnostics.py b/distributed/diagnostics/tests/test_rmm_diagnostics.py
index f26255e737a..8944a07e223 100644
--- a/distributed/diagnostics/tests/test_rmm_diagnostics.py
+++ b/distributed/diagnostics/tests/test_rmm_diagnostics.py
@@ -7,6 +7,7 @@
 from dask import delayed
 from dask.utils import parse_bytes
 
+from distributed.utils import Deadline
 from distributed.utils_test import gen_cluster
 
 pytestmark = pytest.mark.gpu
@@ -32,6 +33,13 @@ async def test_rmm_metrics(c, s, *workers):
     assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")
     result = delayed(rmm.DeviceBuffer)(size=10)
     result = result.persist()
-    await asyncio.sleep(1)
+
+    deadline = Deadline.after(5)
+
+    while not deadline.expired:
+        if w.metrics["rmm"]["rmm-used"] != 0:
+            break
+        await asyncio.sleep(0.25)
+
     assert w.metrics["rmm"]["rmm-used"] != 0
     assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")