|
33 | 33 |
|
34 | 34 | from pymc.blocking import DictToArrayBijection
|
35 | 35 | from pymc.exceptions import SamplingError
|
36 |
| -from pymc.util import CustomProgress, default_progress_theme |
| 36 | +from pymc.util import ( |
| 37 | + CustomProgress, |
| 38 | + RandomGeneratorState, |
| 39 | + default_progress_theme, |
| 40 | + get_state_from_generator, |
| 41 | + random_generator_from_state, |
| 42 | +) |
37 | 43 |
|
38 | 44 | logger = logging.getLogger(__name__)
|
39 | 45 |
|
@@ -96,13 +102,12 @@ def __init__(
|
96 | 102 | shared_point,
|
97 | 103 | draws: int,
|
98 | 104 | tune: int,
|
99 |
| - rng: np.random.Generator, |
100 |
| - seed_seq: np.random.SeedSequence, |
| 105 | + rng_state: RandomGeneratorState, |
101 | 106 | blas_cores,
|
102 | 107 | ):
|
103 | 108 | # For some strange reason, spawn multiprocessing doesn't copy the rng
|
104 | 109 | # seed sequence, so we have to rebuild it from scratch
|
105 |
| - rng = np.random.Generator(type(rng.bit_generator)(seed_seq)) |
| 110 | + rng = random_generator_from_state(rng_state) |
106 | 111 | self._msg_pipe = msg_pipe
|
107 | 112 | self._step_method = step_method
|
108 | 113 | self._step_method_is_pickled = step_method_is_pickled
|
@@ -263,8 +268,7 @@ def __init__(
|
263 | 268 | self._shared_point,
|
264 | 269 | draws,
|
265 | 270 | tune,
|
266 |
| - rng, |
267 |
| - rng.bit_generator.seed_seq, |
| 271 | + get_state_from_generator(rng), |
268 | 272 | blas_cores,
|
269 | 273 | ),
|
270 | 274 | )
|
|
0 commit comments