
Commit e75cd73

Include unconditional constant_fold rewrite
1 parent a50bfae commit e75cd73

2 files changed (+13, -3 lines)

pymc/pytensorf.py (+8, -3)
@@ -45,6 +45,7 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
+from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -1057,7 +1058,7 @@ def compile_pymc(

 def constant_fold(
     xs: Sequence[TensorVariable], raise_not_constant: bool = True
-) -> tuple[np.ndarray, ...]:
+) -> tuple[np.ndarray | Variable, ...]:
     """Use constant folding to get constant values of a graph.

     Parameters
@@ -1072,8 +1073,12 @@ def constant_fold(
     """
     fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)

-    # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
-    folded_xs = rewrite_graph(fg).outputs
+    # The default rewrite_graph includes a constant_folding pass that is not always applied.
+    # We use an unconditional constant_folding as the last pass to ensure thorough constant folding.
+    rewrite_graph(fg)
+    topo_unconditional_constant_folding.apply(fg)
+
+    folded_xs = fg.outputs

     if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
         raise NotConstantValueError
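
For context, a minimal sketch (not part of the commit, and assuming current PyTensor import paths) of the rewrite pattern introduced above: run the default rewrites first, then force the unconditional folding pass so constant-only nodes such as Alloc get materialized.

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph, rewrite_graph
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding

# Alloc over constant inputs; the default (conditional) folder may skip it.
x = pt.alloc(pt.arange(5), 2, 5)

fg = FunctionGraph(outputs=[x], clone=True)
rewrite_graph(fg)                              # standard rewrite passes
topo_unconditional_constant_folding.apply(fg)  # force-fold remaining constant-only nodes

print(fg.outputs)  # expected: a single Constant holding the broadcasted array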

tests/test_pytensorf.py (+5)
@@ -696,6 +696,11 @@ def test_inputs_preserved(self):
         (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
         assert out_shape is a

+    def test_constant_fold_alloc(self):
+        # By default, Alloc outputs cannot be constant folded
+        x = pt.alloc(pt.arange(5), 2, 5)
+        np.testing.assert_allclose(constant_fold([x])[0], np.broadcast_to(np.arange(5), (2, 5)))
+

 def test_replace_vars_in_graphs():
     inp = shared(0.0, name="inp")
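
A hedged usage sketch of the updated helper (not from the commit; the variable names are illustrative): with the extra pass, constant_fold now folds the Alloc graph exercised by the new test, and with raise_not_constant=False a non-constant output is returned as a symbolic Variable, which is what the widened return annotation (np.ndarray | Variable) reflects.

import numpy as np
import pytensor.tensor as pt
from pymc.pytensorf import constant_fold

# Constant-only graph: folded all the way down to a NumPy array.
x = pt.alloc(pt.arange(5), 2, 5)
(folded,) = constant_fold([x])
np.testing.assert_allclose(folded, np.broadcast_to(np.arange(5), (2, 5)))

# Graph with a free input: not constant, so with raise_not_constant=False the
# (possibly rewritten) symbolic output is returned instead of raising
# NotConstantValueError.
v = pt.vector("v")
(maybe_folded,) = constant_fold([v.shape[0]], raise_not_constant=False)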
