45
45
from pytensor.tensor.random.op import RandomVariable
46
46
from pytensor.tensor.random.type import RandomType
47
47
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
48
+ from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
48
49
from pytensor.tensor.rewriting.shape import ShapeFeature
49
50
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
50
51
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -1057,7 +1058,7 @@ def compile_pymc(
1057
1058
1058
1059
def constant_fold(
1059
1060
xs: Sequence[TensorVariable], raise_not_constant: bool = True
1060
- ) -> tuple[np.ndarray, ...]:
1061
+ ) -> tuple[np.ndarray | Variable, ...]:
1061
1062
"""Use constant folding to get constant values of a graph.
1062
1063
1063
1064
Parameters
@@ -1072,8 +1073,12 @@ def constant_fold(
1072
1073
"""
1073
1074
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)
1074
1075
1075
- # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
1076
- folded_xs = rewrite_graph(fg).outputs
1076
+ # The default rewrite_graph includes a constant_folding rewrite that is not always applied.
1077
+ # We use an unconditional constant_folding as the last pass to ensure thorough constant folding.
1078
+ rewrite_graph(fg)
1079
+ topo_unconditional_constant_folding.apply(fg)
1080
+
1081
+ folded_xs = fg.outputs
1077
1082
1078
1083
if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
1079
1084
raise NotConstantValueError
0 commit comments