Skip to content

Commit 450477e

Browse files
committed
Stabilise flaky test
1 parent f32f081 commit 450477e

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

tests/test_mask.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ def test_convgnp_mask(nps):
1414
conv_receptive_field=0.5,
1515
conv_layers=1,
1616
conv_channels=1,
17-
# Dividing by the density channel makes the forward very sensitive to the
18-
# numerics.
19-
divide_by_density=False,
17+
# A large margin and `float64`s help with numerical stability.
18+
margin=2,
19+
dtype=nps.dtype64,
2020
)
21-
xc, yc, xt, yt = generate_data(nps)
21+
xc, yc, xt, yt = generate_data(nps, dtype=nps.dtype64)
2222

2323
# Predict without the final three points.
2424
pred = model(xc[:, :, :-3], yc[:, :, :-3], xt)

tests/util.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def approx(
5151
nps_tf.dtype64 = tf.float64
5252

5353

54-
@pytest.fixture(params=[nps_torch, nps_tf], scope="module")
54+
@pytest.fixture(params=[nps_tf, nps_torch], scope="module")
5555
def nps(request):
5656
return request.param
5757

@@ -64,14 +64,17 @@ def generate_data(
6464
n_context=5,
6565
n_target=7,
6666
binary=False,
67+
dtype=None,
6768
):
68-
xc = B.randn(nps.dtype, batch_size, dim_x, n_context)
69-
yc = B.randn(nps.dtype, batch_size, dim_y, n_context)
70-
xt = B.randn(nps.dtype, batch_size, dim_x, n_target)
71-
yt = B.randn(nps.dtype, batch_size, dim_y, n_target)
69+
if dtype is None:
70+
dtype = nps.dtype
71+
xc = B.randn(dtype, batch_size, dim_x, n_context)
72+
yc = B.randn(dtype, batch_size, dim_y, n_context)
73+
xt = B.randn(dtype, batch_size, dim_x, n_target)
74+
yt = B.randn(dtype, batch_size, dim_y, n_target)
7275
if binary:
73-
yc = B.cast(nps.dtype, yc >= 0)
74-
yt = B.cast(nps.dtype, yt >= 0)
76+
yc = B.cast(dtype, yc >= 0)
77+
yt = B.cast(dtype, yt >= 0)
7578
return xc, yc, xt, yt
7679

7780

0 commit comments

Comments
 (0)