Skip to content

Commit 30ee47d

Browse files
committed
Do not mutate Scan inner graph when deriving logprob
1 parent e75cd73 commit 30ee47d

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

pymc/logprob/scan.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ def find_measurable_scans(fgraph, node):
463463
# We must also replace any lingering references to the old RVs by the new measurable RVS
464464
# For example if we had measurable out1 = exp(normal()) and out2 = out1 - x
465465
# We need to replace references of original out1 by the new MeasurableExp(normal())
466-
inner_outs = node.op.inner_outputs.copy()
466+
clone_fgraph = node.op.fgraph.clone()
467+
inner_inps = clone_fgraph.inputs
468+
inner_outs = clone_fgraph.outputs
467469
inner_rvs_replacements = []
468470
for idx, new_inner_rv in zip(valued_output_idxs, inner_rvs, strict=True):
469471
old_inner_rv = inner_outs[idx]
@@ -474,8 +476,7 @@ def find_measurable_scans(fgraph, node):
474476
clone=False,
475477
)
476478
toposort_replace(temp_fgraph, inner_rvs_replacements)
477-
inner_outs = temp_fgraph.outputs[: len(inner_outs)]
478-
op = MeasurableScan(node.op.inner_inputs, inner_outs, node.op.info, mode=copy(node.op.mode))
479+
op = MeasurableScan(inner_inps, inner_outs, node.op.info, mode=copy(node.op.mode))
479480
new_outs = op.make_node(*node.inputs).outputs
480481

481482
old_outs = node.outputs

tests/logprob/test_scan.py

+22
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,25 @@ def test_scan_multiple_output_types():
550550
test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])]
551551
),
552552
)
553+
554+
555+
def test_generative_graph_unchanged():
556+
# Regression test where creating the IR would overwrite the original Scan inner fgraph
557+
558+
def step(eps_tm1):
559+
x = pt.random.normal(0, eps_tm1)
560+
eps_t = x - 0
561+
return (x, eps_t), {x.owner.inputs[0]: x.owner.outputs[0]}
562+
563+
[xs, _], update = pytensor.scan(step, outputs_info=[None, pt.ones(())], n_steps=5)
564+
565+
before = xs.dprint(file="str")
566+
567+
xs_value = np.ones(5)
568+
expected_logp = stats.norm.logpdf(xs_value, 0, 1)
569+
for i in range(2):
570+
xs_logp = logp(xs, xs_value)
571+
np.testing.assert_allclose(xs_logp.eval(), expected_logp)
572+
573+
after = xs.dprint(file="str")
574+
assert before == after

0 commit comments

Comments
 (0)