Skip to content

Commit bbf9850

Browse files
committed
Do not rely on model variable ordering for initival replacements
It is not part of the API that variables must be registered in topological order
1 parent 36cca5b commit bbf9850

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pymc/initial_point.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
from pytensor.tensor.variable import TensorVariable
2626

2727
from pymc.logprob.transforms import Transform
28-
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
28+
from pymc.pytensorf import (
29+
compile_pymc,
30+
find_rng_nodes,
31+
replace_rng_nodes,
32+
reseed_rngs,
33+
toposort_replace,
34+
)
2935
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
3036

3137
StartDict = dict[Variable | str, np.ndarray | Variable | str]
@@ -288,8 +294,7 @@ def make_initial_point_expression(
288294
# order, so that later nodes do not reintroduce expressions with earlier
289295
# rvs that would need to once again be replaced by their initial_points
290296
graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
291-
replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
292-
graph.replace_all(replacements, import_missing=True)
297+
toposort_replace(graph, tuple(zip(free_rvs_clone, initial_values_clone)), reverse=True)
293298

294299
if not return_transformed:
295300
return graph.outputs

tests/test_initial_point.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,17 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
4747
)
4848
pass
4949

50-
def test_dependent_initvals(self):
50+
@pytest.mark.parametrize("reverse_rvs", [False, True])
51+
def test_dependent_initvals(self, reverse_rvs):
5152
with pm.Model() as pmodel:
5253
L = pm.Uniform("L", 0, 1, initval=0.5)
5354
U = pm.Uniform("U", lower=9, upper=10, initval=9.5)
5455
B1 = pm.Uniform("B1", lower=L, upper=U, initval=5)
5556
B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2)
57+
58+
if reverse_rvs:
59+
pmodel.free_RVs = pmodel.free_RVs[::-1]
60+
5661
ip = pmodel.initial_point(random_seed=0)
5762
assert ip["L_interval__"] == 0
5863
assert ip["U_interval__"] == 0

0 commit comments

Comments
 (0)