@@ -162,6 +162,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = False,
     ):
         """Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def reset_tuning(self):
@@ -432,6 +433,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -447,7 +449,9 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
@@ -554,6 +558,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -582,7 +587,10 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -672,6 +680,7 @@ def __init__(
         model=None,
         rng: RandomGenerator = None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -728,7 +737,9 @@ def __init__(
         # that indicates whether a draw was done in a tuning phase.
         self.tune = True
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -904,6 +915,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -939,7 +951,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1073,6 +1085,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         mode=None,
         rng=None,
         blocked: bool = True,
@@ -1122,7 +1135,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def reset_tuning(self):
@@ -1213,6 +1226,7 @@ def delta_logp(
     logp: pt.TensorVariable,
     vars: list[pt.TensorVariable],
     shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
+    compile_kwargs: dict | None,
 ) -> pytensor.compile.Function:
     [logp0], inarray0 = join_nonshared_inputs(
         point=point, outputs=[logp], inputs=vars, shared_inputs=shared
@@ -1225,6 +1239,8 @@ def delta_logp(
     # Replace any potential duplicated RNG nodes
     (logp1,) = replace_rng_nodes((logp1,))
 
-    f = compile_pymc([inarray1, inarray0], logp1 - logp0)
+    if compile_kwargs is None:
+        compile_kwargs = {}
+    f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
     f.trust_input = True
     return f
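
A minimal usage sketch (not part of the diff): after this change, a compile_kwargs dict passed to a Metropolis-family stepper is forwarded to compile_pymc when the stepper's logp (or delta-logp) function is built, so PyTensor compilation options such as mode can be set per step method. The model and the chosen mode below are illustrative assumptions, not taken from the diff.

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    # compile_kwargs is forwarded to the PyTensor function compilation of the
    # stepper's delta-logp; "FAST_COMPILE" is just an example compilation mode.
    step = pm.Metropolis([x], compile_kwargs={"mode": "FAST_COMPILE"})
    idata = pm.sample(draws=200, tune=200, step=step, chains=1)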