Skip to content

Commit 7c369c8

Browse files
committed
Allow passing compile_kwargs to step inner functions
1 parent 5c63df4 commit 7c369c8

File tree

6 files changed

+55
-16
lines changed

6 files changed

+55
-16
lines changed

pymc/sampling/mcmc.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def instantiate_steppers(
104104
*,
105105
step_kwargs: dict[str, dict] | None = None,
106106
initial_point: PointType | None = None,
107+
compile_kwargs: dict | None = None,
107108
) -> Step | list[Step]:
108109
"""Instantiate steppers assigned to the model variables.
109110
@@ -146,6 +147,7 @@ def instantiate_steppers(
146147
vars=vars,
147148
model=model,
148149
initial_point=initial_point,
150+
compile_kwargs=compile_kwargs,
149151
**kwargs,
150152
)
151153
steps.append(step)
@@ -434,6 +436,7 @@ def sample(
434436
callback=None,
435437
mp_ctx=None,
436438
blas_cores: int | None | Literal["auto"] = "auto",
439+
compile_kwargs: dict | None = None,
437440
**kwargs,
438441
) -> InferenceData: ...
439442

@@ -466,6 +469,7 @@ def sample(
466469
mp_ctx=None,
467470
model: Model | None = None,
468471
blas_cores: int | None | Literal["auto"] = "auto",
472+
compile_kwargs: dict | None = None,
469473
**kwargs,
470474
) -> MultiTrace: ...
471475

@@ -497,6 +501,7 @@ def sample(
497501
mp_ctx=None,
498502
blas_cores: int | None | Literal["auto"] = "auto",
499503
model: Model | None = None,
504+
compile_kwargs: dict | None = None,
500505
**kwargs,
501506
) -> InferenceData | MultiTrace:
502507
r"""Draw samples from the posterior using the given step methods.
@@ -598,6 +603,9 @@ def sample(
598603
See multiprocessing documentation for details.
599604
model : Model (optional if in ``with`` context)
600605
Model to sample from. The model needs to have free random variables.
606+
compile_kwargs: dict, optional
607+
Dictionary with keyword argument to pass to the functions compiled by the step methods.
608+
601609
602610
Returns
603611
-------
@@ -795,6 +803,7 @@ def joined_blas_limiter():
795803
jitter_max_retries=jitter_max_retries,
796804
tune=tune,
797805
initvals=initvals,
806+
compile_kwargs=compile_kwargs,
798807
**kwargs,
799808
)
800809
else:
@@ -814,6 +823,7 @@ def joined_blas_limiter():
814823
selected_steps=selected_steps,
815824
step_kwargs=kwargs,
816825
initial_point=initial_points[0],
826+
compile_kwargs=compile_kwargs,
817827
)
818828
if isinstance(step, list):
819829
step = CompoundStep(step)
@@ -1390,6 +1400,7 @@ def init_nuts(
13901400
jitter_max_retries: int = 10,
13911401
tune: int | None = None,
13921402
initvals: StartDict | Sequence[StartDict | None] | None = None,
1403+
compile_kwargs: dict | None = None,
13931404
**kwargs,
13941405
) -> tuple[Sequence[PointType], NUTS]:
13951406
"""Set up the mass matrix initialization for NUTS.
@@ -1466,6 +1477,9 @@ def init_nuts(
14661477
if init == "auto":
14671478
init = "jitter+adapt_diag"
14681479

1480+
if compile_kwargs is None:
1481+
compile_kwargs = {}
1482+
14691483
random_seed_list = _get_seeds_per_chain(random_seed, chains)
14701484

14711485
_log.info(f"Initializing NUTS using {init}...")
@@ -1477,7 +1491,7 @@ def init_nuts(
14771491
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
14781492
]
14791493

1480-
logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
1494+
logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
14811495
logp_dlogp_func.trust_input = True
14821496
initial_points = _init_jitter(
14831497
model,

pymc/step_methods/arraystep.py

+4
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,20 @@ def __init__(
182182
logp_dlogp_func=None,
183183
rng: RandomGenerator = None,
184184
initial_point: PointType | None = None,
185+
compile_kwargs: dict | None = None,
185186
**pytensor_kwargs,
186187
):
187188
model = modelcontext(model)
188189

189190
if logp_dlogp_func is None:
191+
if compile_kwargs is None:
192+
compile_kwargs = {}
190193
logp_dlogp_func = model.logp_dlogp_function(
191194
vars,
192195
dtype=dtype,
193196
ravel_inputs=True,
194197
initial_point=initial_point,
198+
**compile_kwargs,
195199
**pytensor_kwargs,
196200
)
197201
logp_dlogp_func.trust_input = True

pymc/step_methods/metropolis.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(
162162
mode=None,
163163
rng=None,
164164
initial_point: PointType | None = None,
165+
compile_kwargs: dict | None = None,
165166
blocked: bool = False,
166167
):
167168
"""Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
254255
self.mode = mode
255256

256257
shared = pm.make_shared_replacements(initial_point, vars, model)
257-
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
258+
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
258259
super().__init__(vars, shared, blocked=blocked, rng=rng)
259260

260261
def reset_tuning(self):
@@ -432,6 +433,7 @@ def __init__(
432433
model=None,
433434
rng=None,
434435
initial_point: PointType | None = None,
436+
compile_kwargs: dict | None = None,
435437
blocked: bool = True,
436438
):
437439
model = pm.modelcontext(model)
@@ -447,7 +449,9 @@ def __init__(
447449
if not all(v.dtype in pm.discrete_types for v in vars):
448450
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
449451

450-
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
452+
if compile_kwargs is None:
453+
compile_kwargs = {}
454+
super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
451455

452456
def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
453457
logp = args[0]
@@ -554,6 +558,7 @@ def __init__(
554558
model=None,
555559
rng=None,
556560
initial_point: PointType | None = None,
561+
compile_kwargs: dict | None = None,
557562
blocked: bool = True,
558563
):
559564
model = pm.modelcontext(model)
@@ -582,7 +587,10 @@ def __init__(
582587
if not all(v.dtype in pm.discrete_types for v in vars):
583588
raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
584589

585-
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
590+
if compile_kwargs is None:
591+
compile_kwargs = {}
592+
593+
super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
586594

587595
def reset_tuning(self):
588596
# There are no tuning parameters in this step method.
@@ -672,6 +680,7 @@ def __init__(
672680
model=None,
673681
rng: RandomGenerator = None,
674682
initial_point: PointType | None = None,
683+
compile_kwargs: dict | None = None,
675684
blocked: bool = True,
676685
):
677686
model = pm.modelcontext(model)
@@ -728,7 +737,9 @@ def __init__(
728737
# that indicates whether a draw was done in a tuning phase.
729738
self.tune = True
730739

731-
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
740+
if compile_kwargs is None:
741+
compile_kwargs = {}
742+
super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
732743

733744
def reset_tuning(self):
734745
# There are no tuning parameters in this step method.
@@ -904,6 +915,7 @@ def __init__(
904915
mode=None,
905916
rng=None,
906917
initial_point: PointType | None = None,
918+
compile_kwargs: dict | None = None,
907919
blocked: bool = True,
908920
):
909921
model = pm.modelcontext(model)
@@ -939,7 +951,7 @@ def __init__(
939951
self.mode = mode
940952

941953
shared = pm.make_shared_replacements(initial_point, vars, model)
942-
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
954+
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
943955
super().__init__(vars, shared, blocked=blocked, rng=rng)
944956

945957
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1073,6 +1085,7 @@ def __init__(
10731085
tune_drop_fraction: float = 0.9,
10741086
model=None,
10751087
initial_point: PointType | None = None,
1088+
compile_kwargs: dict | None = None,
10761089
mode=None,
10771090
rng=None,
10781091
blocked: bool = True,
@@ -1122,7 +1135,7 @@ def __init__(
11221135
self.mode = mode
11231136

11241137
shared = pm.make_shared_replacements(initial_point, vars, model)
1125-
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
1138+
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
11261139
super().__init__(vars, shared, blocked=blocked, rng=rng)
11271140

11281141
def reset_tuning(self):
@@ -1213,6 +1226,7 @@ def delta_logp(
12131226
logp: pt.TensorVariable,
12141227
vars: list[pt.TensorVariable],
12151228
shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
1229+
compile_kwargs: dict | None,
12161230
) -> pytensor.compile.Function:
12171231
[logp0], inarray0 = join_nonshared_inputs(
12181232
point=point, outputs=[logp], inputs=vars, shared_inputs=shared
@@ -1225,6 +1239,8 @@ def delta_logp(
12251239
# Replace any potential duplicated RNG nodes
12261240
(logp1,) = replace_rng_nodes((logp1,))
12271241

1228-
f = compile_pymc([inarray1, inarray0], logp1 - logp0)
1242+
if compile_kwargs is None:
1243+
compile_kwargs = {}
1244+
f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
12291245
f.trust_input = True
12301246
return f

pymc/step_methods/slicer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
iter_limit=np.inf,
8787
rng=None,
8888
initial_point: PointType | None = None,
89+
compile_kwargs: dict | None = None,
8990
blocked: bool = False, # Could be true since tuning is independent across dims?
9091
):
9192
model = modelcontext(model)
@@ -106,7 +107,9 @@ def __init__(
106107
[logp], raveled_inp = join_nonshared_inputs(
107108
point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
108109
)
109-
self.logp = compile_pymc([raveled_inp], logp)
110+
if compile_kwargs is None:
111+
compile_kwargs = {}
112+
self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
110113
self.logp.trust_input = True
111114

112115
super().__init__(vars, shared, blocked=blocked, rng=rng)

tests/helpers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pytensor
2727

2828
from numpy.testing import assert_array_less
29+
from pytensor.compile.mode import Mode
2930
from pytensor.gradient import verify_grad as at_verify_grad
3031

3132
import pymc as pm
@@ -198,10 +199,11 @@ def continuous_steps(self, step, step_kwargs):
198199
c1 = pm.HalfNormal("c1")
199200
c2 = pm.HalfNormal("c2")
200201

201-
# Test methods can handle initial_point
202+
# Test methods can handle initial_point and compile_kwargs
202203
step_kwargs.setdefault(
203204
"initial_point", {"c1_log__": np.array(0.5), "c2_log__": np.array(0.9)}
204205
)
206+
step_kwargs.setdefault("compile_kwargs", {"mode": Mode(linker="py", optimizer=None)})
205207
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
206208
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
207209
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(

tests/step_methods/test_metropolis.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import pytensor
2323
import pytest
2424

25+
from pytensor.compile.mode import Mode
26+
2527
import pymc as pm
2628

2729
from pymc.step_methods.metropolis import (
@@ -368,18 +370,16 @@ def test_discrete_steps(self, step):
368370
d1 = pm.Bernoulli("d1", p=0.5)
369371
d2 = pm.Bernoulli("d2", p=0.5)
370372

371-
# Test it can take initial_point as a kwarg
373+
# Test it can take initial_point, and compile_kwargs as a kwarg
372374
step_kwargs = {
373375
"initial_point": {
374376
"d1": np.array(0, dtype="int64"),
375377
"d2": np.array(1, dtype="int64"),
376378
},
379+
"compile_kwargs": {"mode": Mode(linker="py", optimizer=None)},
377380
}
378-
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
379-
assert [m.rvs_to_values[d1]] == step([d1]).vars
380-
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
381-
step([d1, d2]).vars
382-
)
381+
assert [m.rvs_to_values[d1]] == step([d1]).vars
382+
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars)
383383

384384
@pytest.mark.parametrize(
385385
"step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]

0 commit comments

Comments
 (0)