Skip to content

Commit c17491d

Browse files
committed
Allow ordering of multi-output OpFromGraph variables in toposort_replace
1 parent bbf9850 commit c17491d

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

pymc/pytensorf.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1127,10 +1127,40 @@ def toposort_replace(
11271127
fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False
11281128
) -> None:
11291129
"""Replace multiple variables in place in topological order."""
1130-
toposort = fgraph.toposort()
1130+
fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())}
1131+
_inner_fgraph_toposorts = {} # Cache inner toposorts
1132+
1133+
def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]:
1134+
"""Compute position of variable in fgraph toposort.
1135+
1136+
When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s).
1137+
1138+
This allows ordering variables that come from the same OpFromGraph.
1139+
"""
1140+
if not var.owner:
1141+
return (-1,)
1142+
1143+
index = fgraph_toposort[var.owner]
1144+
1145+
# Recurse into OpFromGraphs
1146+
# TODO: Could also recurse into Scans
1147+
if isinstance(var.owner.op, OpFromGraph):
1148+
inner_fgraph = var.owner.op.fgraph
1149+
1150+
if inner_fgraph not in _inner_fgraph_toposorts:
1151+
_inner_fgraph_toposorts[inner_fgraph] = {
1152+
node: i for i, node in enumerate(inner_fgraph.toposort())
1153+
}
1154+
1155+
inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph]
1156+
inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)]
1157+
return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort))
1158+
else:
1159+
return (index,)
1160+
11311161
sorted_replacements = sorted(
11321162
replacements,
1133-
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
1163+
key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort),
11341164
reverse=reverse,
11351165
)
11361166
fgraph.replace_all(sorted_replacements, import_missing=True)

tests/test_initial_point.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import pytensor.tensor as pt
1818
import pytest
1919

20+
from pytensor.compile.builders import OpFromGraph
2021
from pytensor.tensor.random.op import RandomVariable
2122

2223
import pymc as pm
2324

24-
from pymc.distributions.distribution import support_point
25+
from pymc.distributions.distribution import _support_point, support_point
2526
from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain
2627

2728

@@ -192,6 +193,40 @@ def test_string_overrides_work(self):
192193
assert np.isclose(iv["B_log__"], 0)
193194
assert iv["C_log__"] == 0
194195

196+
@pytest.mark.parametrize("reverse_rvs", [False, True])
197+
def test_dependent_initval_from_OFG(self, reverse_rvs):
198+
class MyTestOp(OpFromGraph):
199+
pass
200+
201+
@_support_point.register(MyTestOp)
202+
def my_test_op_support_point(op, out):
203+
out1, out2 = out.owner.outputs
204+
if out is out1:
205+
return out1
206+
else:
207+
return out1 * 4
208+
209+
out1 = pt.zeros(())
210+
out2 = out1 * 2
211+
rv_op = MyTestOp([], [out1, out2])
212+
213+
with pm.Model() as model:
214+
A, B = rv_op()
215+
if reverse_rvs:
216+
model.register_rv(B, "B")
217+
model.register_rv(A, "A")
218+
else:
219+
model.register_rv(A, "A")
220+
model.register_rv(B, "B")
221+
222+
assert model.initial_point() == {"A": 0, "B": 0}
223+
224+
model.set_initval(A, 1)
225+
assert model.initial_point() == {"A": 1, "B": 4}
226+
227+
model.set_initval(B, 3)
228+
assert model.initial_point() == {"A": 1, "B": 3}
229+
195230

196231
class TestSupportPoint:
197232
def test_basic(self):

0 commit comments

Comments
 (0)