Skip to content

Commit 4b2cc5b

Browse files
lucianopazricardoV94
authored andcommitted
Ensure parallel sampling does not lose BitGenerator state
1 parent 6b3bda0 commit 4b2cc5b

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pymc/sampling/parallel.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333

3434
from pymc.blocking import DictToArrayBijection
3535
from pymc.exceptions import SamplingError
36-
from pymc.util import CustomProgress, default_progress_theme
36+
from pymc.util import (
37+
CustomProgress,
38+
RandomGeneratorState,
39+
default_progress_theme,
40+
get_state_from_generator,
41+
random_generator_from_state,
42+
)
3743

3844
logger = logging.getLogger(__name__)
3945

@@ -96,13 +102,12 @@ def __init__(
96102
shared_point,
97103
draws: int,
98104
tune: int,
99-
rng: np.random.Generator,
100-
seed_seq: np.random.SeedSequence,
105+
rng_state: RandomGeneratorState,
101106
blas_cores,
102107
):
103108
# For some strange reason, spawn multiprocessing doesn't copy the rng
104109
# seed sequence, so we have to rebuild it from scratch
105-
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
110+
rng = random_generator_from_state(rng_state)
106111
self._msg_pipe = msg_pipe
107112
self._step_method = step_method
108113
self._step_method_is_pickled = step_method_is_pickled
@@ -263,8 +268,7 @@ def __init__(
263268
self._shared_point,
264269
draws,
265270
tune,
266-
rng,
267-
rng.bit_generator.seed_seq,
271+
get_state_from_generator(rng),
268272
blas_cores,
269273
),
270274
)

0 commit comments

Comments
 (0)