Skip to content

Commit 5c63df4

Browse files
committed
Avoid recompiling initial_point and logp functions in sample
Also removes default `model.check_start_vals()`
1 parent 76d7d85 commit 5c63df4

File tree

10 files changed

+210
-130
lines changed

10 files changed

+210
-130
lines changed

pymc/backends/__init__.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@
7272
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
7373
from pymc.backends.base import BaseTrace, IBaseTrace
7474
from pymc.backends.ndarray import NDArray
75+
from pymc.blocking import PointType
7576
from pymc.model import Model
7677
from pymc.step_methods.compound import BlockedStep, CompoundStep
7778

@@ -100,11 +101,12 @@ def _init_trace(
100101
trace: BaseTrace | None,
101102
model: Model,
102103
trace_vars: list[TensorVariable] | None = None,
104+
initial_point: PointType | None = None,
103105
) -> BaseTrace:
104106
"""Initialize a trace backend for a chain."""
105107
strace: BaseTrace
106108
if trace is None:
107-
strace = NDArray(model=model, vars=trace_vars)
109+
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
108110
elif isinstance(trace, BaseTrace):
109111
if len(trace) > 0:
110112
raise ValueError("Continuation of traces is no longer supported.")
@@ -122,7 +124,7 @@ def init_traces(
122124
chains: int,
123125
expected_length: int,
124126
step: BlockedStep | CompoundStep,
125-
initial_point: Mapping[str, np.ndarray],
127+
initial_point: PointType,
126128
model: Model,
127129
trace_vars: list[TensorVariable] | None = None,
128130
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
@@ -145,6 +147,7 @@ def init_traces(
145147
trace=backend,
146148
model=model,
147149
trace_vars=trace_vars,
150+
initial_point=initial_point,
148151
)
149152
for chain_number in range(chains)
150153
]

pymc/model/core.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@
4848
ShapeError,
4949
ShapeWarning,
5050
)
51-
from pymc.initial_point import make_initial_point_fn
51+
from pymc.initial_point import PointType, make_initial_point_fn
5252
from pymc.logprob.basic import transformed_conditional_logp
5353
from pymc.logprob.transforms import Transform
5454
from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
@@ -174,7 +174,7 @@ def __init__(
174174
casting="no",
175175
compute_grads=True,
176176
model=None,
177-
initial_point=None,
177+
initial_point: PointType | None = None,
178178
ravel_inputs: bool | None = None,
179179
**kwargs,
180180
):
@@ -533,7 +533,7 @@ def logp_dlogp_function(
533533
self,
534534
grad_vars=None,
535535
tempered=False,
536-
initial_point=None,
536+
initial_point: PointType | None = None,
537537
ravel_inputs: bool | None = None,
538538
**kwargs,
539539
):

pymc/sampling/mcmc.py

+108-64
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,9 @@ def instantiate_steppers(
101101
model: Model,
102102
steps: list[Step],
103103
selected_steps: Mapping[type[BlockedStep], list[Any]],
104+
*,
104105
step_kwargs: dict[str, dict] | None = None,
106+
initial_point: PointType | None = None,
105107
) -> Step | list[Step]:
106108
"""Instantiate steppers assigned to the model variables.
107109
@@ -131,13 +133,22 @@ def instantiate_steppers(
131133
step_kwargs = {}
132134

133135
used_keys = set()
134-
for step_class, vars in selected_steps.items():
135-
if vars:
136-
name = getattr(step_class, "name")
137-
args = step_kwargs.get(name, {})
138-
used_keys.add(name)
139-
step = step_class(vars=vars, model=model, **args)
140-
steps.append(step)
136+
if selected_steps:
137+
if initial_point is None:
138+
initial_point = model.initial_point()
139+
140+
for step_class, vars in selected_steps.items():
141+
if vars:
142+
name = getattr(step_class, "name")
143+
kwargs = step_kwargs.get(name, {})
144+
used_keys.add(name)
145+
step = step_class(
146+
vars=vars,
147+
model=model,
148+
initial_point=initial_point,
149+
**kwargs,
150+
)
151+
steps.append(step)
141152

142153
unused_args = set(step_kwargs).difference(used_keys)
143154
if unused_args:
@@ -161,18 +172,22 @@ def assign_step_methods(
161172
model: Model,
162173
step: Step | Sequence[Step] | None = None,
163174
methods: Sequence[type[BlockedStep]] | None = None,
164-
step_kwargs: dict[str, Any] | None = None,
165-
) -> Step | list[Step]:
175+
) -> tuple[list[Step], dict[type[BlockedStep], list[Variable]]]:
166176
"""Assign model variables to appropriate step methods.
167177
168-
Passing a specified model will auto-assign its constituent stochastic
169-
variables to step methods based on the characteristics of the variables.
178+
Passing a specified model will auto-assign its constituent value
179+
variables to step methods based on the characteristics of the respective
180+
random variables, and whether the logp can be differentiated with respect to it.
181+
170182
This function is intended to be called automatically from ``sample()``, but
171183
may be called manually. Each step method passed should have a
172184
``competence()`` method that returns an ordinal competence value
173185
corresponding to the variable passed to it. This value quantifies the
174186
appropriateness of the step method for sampling the variable.
175187
188+
The outputs of this function can then be passed to `instantiate_steppers()`
189+
to initialize the assigned step samplers.
190+
176191
Parameters
177192
----------
178193
model : Model object
@@ -183,24 +198,32 @@ def assign_step_methods(
183198
methods : iterable of step method classes, optional
184199
The set of step methods from which the function may choose. Defaults
185200
to the main step methods provided by PyMC.
186-
step_kwargs : dict, optional
187-
Parameters for the samplers. Keys are the lower case names of
188-
the step method, values a dict of arguments.
189201
190202
Returns
191203
-------
192-
methods : list
193-
List of step methods associated with the model's variables.
204+
provided_steps: list of Step instances
205+
List of user provided instantiated step(s)
206+
assigned_steps: dict of Step class to Variable
207+
Dictionary with automatically selected step classes as keys and associated value variables as values
194208
"""
195-
steps: list[Step] = []
209+
provided_steps: list[Step] = []
196210
assigned_vars: set[Variable] = set()
197211

198212
if step is not None:
199213
if isinstance(step, BlockedStep | CompoundStep):
200-
steps.append(step)
214+
provided_steps = [step]
215+
elif isinstance(step, Sequence):
216+
provided_steps = list(step)
201217
else:
202-
steps.extend(step)
203-
for step in steps:
218+
raise ValueError(f"Step should be a Step or a sequence of Steps, got {step}")
219+
220+
for step in provided_steps:
221+
if not isinstance(step, BlockedStep | CompoundStep):
222+
if issubclass(step, BlockedStep | CompoundStep):
223+
raise ValueError(f"Provided {step} was not initialized")
224+
else:
225+
raise ValueError(f"{step} is not a Step instance")
226+
204227
for var in step.vars:
205228
if var not in model.value_vars:
206229
raise ValueError(
@@ -235,7 +258,7 @@ def assign_step_methods(
235258
)
236259
selected_steps.setdefault(selected, []).append(var)
237260

238-
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
261+
return provided_steps, selected_steps
239262

240263

241264
def _print_step_hierarchy(s: Step, level: int = 0) -> None:
@@ -719,22 +742,23 @@ def joined_blas_limiter():
719742
msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
720743
_log.warning(msg)
721744

722-
auto_nuts_init = True
723-
if step is not None:
724-
if isinstance(step, CompoundStep):
725-
for method in step.methods:
726-
if isinstance(method, NUTS):
727-
auto_nuts_init = False
728-
elif isinstance(step, NUTS):
729-
auto_nuts_init = False
730-
731-
initial_points = None
732-
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
745+
provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
746+
exclusive_nuts = (
747+
# User provided an instantiated NUTS step, and nothing else is needed
748+
(not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
749+
or
750+
# Only automatically selected NUTS step is needed
751+
(
752+
not provided_steps
753+
and len(selected_steps) == 1
754+
and issubclass(next(iter(selected_steps)), NUTS)
755+
)
756+
)
733757

734758
if nuts_sampler != "pymc":
735-
if not isinstance(step, NUTS):
759+
if not exclusive_nuts:
736760
raise ValueError(
737-
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
761+
"Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
738762
)
739763

740764
with joined_blas_limiter():
@@ -755,13 +779,11 @@ def joined_blas_limiter():
755779
**kwargs,
756780
)
757781

758-
if isinstance(step, list):
759-
step = CompoundStep(step)
760-
elif isinstance(step, NUTS) and auto_nuts_init:
782+
if exclusive_nuts and not provided_steps:
783+
# Special path for NUTS initialization
761784
if "nuts" in kwargs:
762785
nuts_kwargs = kwargs.pop("nuts")
763786
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
764-
_log.info("Auto-assigning NUTS sampler...")
765787
with joined_blas_limiter():
766788
initial_points, step = init_nuts(
767789
init=init,
@@ -775,9 +797,8 @@ def joined_blas_limiter():
775797
initvals=initvals,
776798
**kwargs,
777799
)
778-
779-
if initial_points is None:
780-
# Time to draw/evaluate numeric start points for each chain.
800+
else:
801+
# Get initial points
781802
ipfns = make_initial_point_fns_per_chain(
782803
model=model,
783804
overrides=initvals,
@@ -786,11 +807,16 @@ def joined_blas_limiter():
786807
)
787808
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]
788809

789-
# One final check that shapes and logps at the starting points are okay.
790-
ip: dict[str, np.ndarray]
791-
for ip in initial_points:
792-
model.check_start_vals(ip)
793-
_check_start_shape(model, ip)
810+
# Instantiate automatically selected steps
811+
step = instantiate_steppers(
812+
model,
813+
steps=provided_steps,
814+
selected_steps=selected_steps,
815+
step_kwargs=kwargs,
816+
initial_point=initial_points[0],
817+
)
818+
if isinstance(step, list):
819+
step = CompoundStep(step)
794820

795821
if var_names is not None:
796822
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
@@ -806,7 +832,7 @@ def joined_blas_limiter():
806832
expected_length=draws + tune,
807833
step=step,
808834
trace_vars=trace_vars,
809-
initial_point=ip,
835+
initial_point=initial_points[0],
810836
model=model,
811837
)
812838

@@ -954,7 +980,6 @@ def _sample_return(
954980
f"took {t_sampling:.0f} seconds."
955981
)
956982

957-
idata = None
958983
if compute_convergence_checks or return_inferencedata:
959984
ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples}
960985
ikwargs.update(idata_kwargs)
@@ -1159,7 +1184,6 @@ def _iter_sample(
11591184
diverging : bool
11601185
Indicates if the draw is divergent. Only available with some samplers.
11611186
"""
1162-
model = modelcontext(model)
11631187
draws = int(draws)
11641188

11651189
if draws < 1:
@@ -1174,8 +1198,6 @@ def _iter_sample(
11741198
if hasattr(step, "reset_tuning"):
11751199
step.reset_tuning()
11761200
for i in range(draws):
1177-
diverging = False
1178-
11791201
if i == 0 and hasattr(step, "iter_count"):
11801202
step.iter_count = 0
11811203
if i == tune:
@@ -1298,6 +1320,7 @@ def _init_jitter(
12981320
seeds: Sequence[int] | np.ndarray,
12991321
jitter: bool,
13001322
jitter_max_retries: int,
1323+
logp_dlogp_func=None,
13011324
) -> list[PointType]:
13021325
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
13031326
@@ -1328,19 +1351,30 @@ def _init_jitter(
13281351
if not jitter:
13291352
return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
13301353

1354+
model_logp_fn: Callable
1355+
if logp_dlogp_func is None:
1356+
model_logp_fn = model.compile_logp()
1357+
else:
1358+
1359+
def model_logp_fn(ip):
1360+
q, _ = DictToArrayBijection.map(ip)
1361+
return logp_dlogp_func([q], extra_vars={})[0]
1362+
13311363
initial_points = []
13321364
for ipfn, seed in zip(ipfns, seeds):
1333-
rng = np.random.RandomState(seed)
1365+
rng = np.random.default_rng(seed)
13341366
for i in range(jitter_max_retries + 1):
13351367
point = ipfn(seed)
1336-
if i < jitter_max_retries:
1337-
try:
1368+
point_logp = model_logp_fn(point)
1369+
if not np.isfinite(point_logp):
1370+
if i == jitter_max_retries:
1371+
# Print informative message on last attempted point
13381372
model.check_start_vals(point)
1339-
except SamplingError:
1340-
# Retry with a new seed
1341-
seed = rng.randint(2**30, dtype=np.int64)
1342-
else:
1343-
break
1373+
# Retry with a new seed
1374+
seed = rng.integers(2**30, dtype=np.int64)
1375+
else:
1376+
break
1377+
13441378
initial_points.append(point)
13451379
return initial_points
13461380

@@ -1436,10 +1470,12 @@ def init_nuts(
14361470

14371471
_log.info(f"Initializing NUTS using {init}...")
14381472

1439-
cb = [
1440-
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
1441-
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
1442-
]
1473+
cb = []
1474+
if "advi" in init:
1475+
cb = [
1476+
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
1477+
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
1478+
]
14431479

14441480
logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
14451481
logp_dlogp_func.trust_input = True
@@ -1449,6 +1485,7 @@ def init_nuts(
14491485
seeds=random_seed_list,
14501486
jitter="jitter" in init,
14511487
jitter_max_retries=jitter_max_retries,
1488+
logp_dlogp_func=logp_dlogp_func,
14521489
)
14531490

14541491
apoints = [DictToArrayBijection.map(point) for point in initial_points]
@@ -1562,7 +1599,14 @@ def init_nuts(
15621599
else:
15631600
raise ValueError(f"Unknown initializer: {init}.")
15641601

1565-
step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)
1602+
step = pm.NUTS(
1603+
potential=potential,
1604+
model=model,
1605+
rng=random_seed_list[0],
1606+
initial_point=initial_points[0],
1607+
logp_dlogp_func=logp_dlogp_func,
1608+
**kwargs,
1609+
)
15661610

15671611
# Filter deterministics from initial_points
15681612
value_var_names = [var.name for var in model.value_vars]

0 commit comments

Comments
 (0)