Skip to content

Commit 82716fb

Browse files
Add regression test mcmc seeding with Generators
Co-authored-by: ricardoV94 <[email protected]>
1 parent 4b2cc5b commit 82716fb

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/sampling/test_parallel.py

+20
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,23 @@ def logp(x, mu):
228228
with warnings.catch_warnings():
229229
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
230230
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
231+
232+
233+
@pytest.mark.parametrize("cores", (1, 2))
234+
def test_sampling_with_random_generator_matches(cores):
235+
# Regression test for https://github.com/pymc-devs/pymc/issues/7612
236+
kwargs = {
237+
"chains": 2,
238+
"cores": cores,
239+
"tune": 10,
240+
"draws": 10,
241+
"compute_convergence_checks": False,
242+
"progress_bar": False,
243+
}
244+
with pm.Model() as m:
245+
x = pm.Normal("x")
246+
247+
post1 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior
248+
post2 = pm.sample(random_seed=np.random.default_rng(42), **kwargs).posterior
249+
250+
assert post1.equals(post2), (post1["x"].mean().item(), post2["x"].mean().item())

0 commit comments

Comments
 (0)