diff --git a/distributed/diagnostics/tests/test_rmm_diagnostics.py b/distributed/diagnostics/tests/test_rmm_diagnostics.py
index f26255e737a..8944a07e223 100644
--- a/distributed/diagnostics/tests/test_rmm_diagnostics.py
+++ b/distributed/diagnostics/tests/test_rmm_diagnostics.py
@@ -7,6 +7,7 @@
 from dask import delayed
 from dask.utils import parse_bytes
 
+from distributed.utils import Deadline
 from distributed.utils_test import gen_cluster
 
 pytestmark = pytest.mark.gpu
@@ -32,6 +33,13 @@ async def test_rmm_metrics(c, s, *workers):
     assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")
     result = delayed(rmm.DeviceBuffer)(size=10)
     result = result.persist()
-    await asyncio.sleep(1)
+
+    deadline = Deadline.after(5)
+
+    while not deadline.expired:
+        if w.metrics["rmm"]["rmm-used"] != 0:
+            break
+        await asyncio.sleep(0.25)
+
     assert w.metrics["rmm"]["rmm-used"] != 0
     assert w.metrics["rmm"]["rmm-total"] == parse_bytes("10MiB")