@@ -101,7 +101,9 @@ def instantiate_steppers(
101
101
model : Model ,
102
102
steps : list [Step ],
103
103
selected_steps : Mapping [type [BlockedStep ], list [Any ]],
104
+ * ,
104
105
step_kwargs : dict [str , dict ] | None = None ,
106
+ initial_point : PointType | None = None ,
105
107
) -> Step | list [Step ]:
106
108
"""Instantiate steppers assigned to the model variables.
107
109
@@ -131,13 +133,22 @@ def instantiate_steppers(
131
133
step_kwargs = {}
132
134
133
135
used_keys = set ()
134
- for step_class , vars in selected_steps .items ():
135
- if vars :
136
- name = getattr (step_class , "name" )
137
- args = step_kwargs .get (name , {})
138
- used_keys .add (name )
139
- step = step_class (vars = vars , model = model , ** args )
140
- steps .append (step )
136
+ if selected_steps :
137
+ if initial_point is None :
138
+ initial_point = model .initial_point ()
139
+
140
+ for step_class , vars in selected_steps .items ():
141
+ if vars :
142
+ name = getattr (step_class , "name" )
143
+ kwargs = step_kwargs .get (name , {})
144
+ used_keys .add (name )
145
+ step = step_class (
146
+ vars = vars ,
147
+ model = model ,
148
+ initial_point = initial_point ,
149
+ ** kwargs ,
150
+ )
151
+ steps .append (step )
141
152
142
153
unused_args = set (step_kwargs ).difference (used_keys )
143
154
if unused_args :
@@ -161,18 +172,22 @@ def assign_step_methods(
161
172
model : Model ,
162
173
step : Step | Sequence [Step ] | None = None ,
163
174
methods : Sequence [type [BlockedStep ]] | None = None ,
164
- step_kwargs : dict [str , Any ] | None = None ,
165
- ) -> Step | list [Step ]:
175
+ ) -> tuple [list [Step ], dict [type [BlockedStep ], list [Variable ]]]:
166
176
"""Assign model variables to appropriate step methods.
167
177
168
- Passing a specified model will auto-assign its constituent stochastic
169
- variables to step methods based on the characteristics of the variables.
178
+ Passing a specified model will auto-assign its constituent value
179
+ variables to step methods based on the characteristics of the respective
180
+ random variables, and whether the logp can be differentiated with respect to it.
181
+
170
182
This function is intended to be called automatically from ``sample()``, but
171
183
may be called manually. Each step method passed should have a
172
184
``competence()`` method that returns an ordinal competence value
173
185
corresponding to the variable passed to it. This value quantifies the
174
186
appropriateness of the step method for sampling the variable.
175
187
188
+ The outputs of this function can then be passed to `instantiate_steppers()`
189
+ to initialize the assigned step samplers.
190
+
176
191
Parameters
177
192
----------
178
193
model : Model object
@@ -183,24 +198,32 @@ def assign_step_methods(
183
198
methods : iterable of step method classes, optional
184
199
The set of step methods from which the function may choose. Defaults
185
200
to the main step methods provided by PyMC.
186
- step_kwargs : dict, optional
187
- Parameters for the samplers. Keys are the lower case names of
188
- the step method, values a dict of arguments.
189
201
190
202
Returns
191
203
-------
192
- methods : list
193
- List of step methods associated with the model's variables.
204
+ provided_steps: list of Step instances
205
+ List of user provided instantiated step(s)
206
+ assigned_steps: dict of Step class to Variable
207
+ Dictionary with automatically selected step classes as keys and associated value variables as values
194
208
"""
195
- steps : list [Step ] = []
209
+ provided_steps : list [Step ] = []
196
210
assigned_vars : set [Variable ] = set ()
197
211
198
212
if step is not None :
199
213
if isinstance (step , BlockedStep | CompoundStep ):
200
- steps .append (step )
214
+ provided_steps = [step ]
215
+ elif isinstance (step , Sequence ):
216
+ provided_steps = list (step )
201
217
else :
202
- steps .extend (step )
203
- for step in steps :
218
+ raise ValueError (f"Step should be a Step or a sequence of Steps, got { step } " )
219
+
220
+ for step in provided_steps :
221
+ if not isinstance (step , BlockedStep | CompoundStep ):
222
+ if issubclass (step , BlockedStep | CompoundStep ):
223
+ raise ValueError (f"Provided { step } was not initialized" )
224
+ else :
225
+ raise ValueError (f"{ step } is not a Step instance" )
226
+
204
227
for var in step .vars :
205
228
if var not in model .value_vars :
206
229
raise ValueError (
@@ -235,7 +258,7 @@ def assign_step_methods(
235
258
)
236
259
selected_steps .setdefault (selected , []).append (var )
237
260
238
- return instantiate_steppers ( model , steps , selected_steps , step_kwargs )
261
+ return provided_steps , selected_steps
239
262
240
263
241
264
def _print_step_hierarchy (s : Step , level : int = 0 ) -> None :
@@ -719,22 +742,23 @@ def joined_blas_limiter():
719
742
msg = f"Only { draws } samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
720
743
_log .warning (msg )
721
744
722
- auto_nuts_init = True
723
- if step is not None :
724
- if isinstance (step , CompoundStep ):
725
- for method in step .methods :
726
- if isinstance (method , NUTS ):
727
- auto_nuts_init = False
728
- elif isinstance (step , NUTS ):
729
- auto_nuts_init = False
730
-
731
- initial_points = None
732
- step = assign_step_methods (model , step , methods = pm .STEP_METHODS , step_kwargs = kwargs )
745
+ provided_steps , selected_steps = assign_step_methods (model , step , methods = pm .STEP_METHODS )
746
+ exclusive_nuts = (
747
+ # User provided an instantiated NUTS step, and nothing else is needed
748
+ (not selected_steps and len (provided_steps ) == 1 and isinstance (provided_steps [0 ], NUTS ))
749
+ or
750
+ # Only automatically selected NUTS step is needed
751
+ (
752
+ not provided_steps
753
+ and len (selected_steps ) == 1
754
+ and issubclass (next (iter (selected_steps )), NUTS )
755
+ )
756
+ )
733
757
734
758
if nuts_sampler != "pymc" :
735
- if not isinstance ( step , NUTS ) :
759
+ if not exclusive_nuts :
736
760
raise ValueError (
737
- "Model can not be sampled with NUTS alone. Your model is probably not continuous ."
761
+ "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability ."
738
762
)
739
763
740
764
with joined_blas_limiter ():
@@ -755,13 +779,11 @@ def joined_blas_limiter():
755
779
** kwargs ,
756
780
)
757
781
758
- if isinstance (step , list ):
759
- step = CompoundStep (step )
760
- elif isinstance (step , NUTS ) and auto_nuts_init :
782
+ if exclusive_nuts and not provided_steps :
783
+ # Special path for NUTS initialization
761
784
if "nuts" in kwargs :
762
785
nuts_kwargs = kwargs .pop ("nuts" )
763
786
[kwargs .setdefault (k , v ) for k , v in nuts_kwargs .items ()]
764
- _log .info ("Auto-assigning NUTS sampler..." )
765
787
with joined_blas_limiter ():
766
788
initial_points , step = init_nuts (
767
789
init = init ,
@@ -775,9 +797,8 @@ def joined_blas_limiter():
775
797
initvals = initvals ,
776
798
** kwargs ,
777
799
)
778
-
779
- if initial_points is None :
780
- # Time to draw/evaluate numeric start points for each chain.
800
+ else :
801
+ # Get initial points
781
802
ipfns = make_initial_point_fns_per_chain (
782
803
model = model ,
783
804
overrides = initvals ,
@@ -786,11 +807,16 @@ def joined_blas_limiter():
786
807
)
787
808
initial_points = [ipfn (seed ) for ipfn , seed in zip (ipfns , random_seed_list )]
788
809
789
- # One final check that shapes and logps at the starting points are okay.
790
- ip : dict [str , np .ndarray ]
791
- for ip in initial_points :
792
- model .check_start_vals (ip )
793
- _check_start_shape (model , ip )
810
+ # Instantiate automatically selected steps
811
+ step = instantiate_steppers (
812
+ model ,
813
+ steps = provided_steps ,
814
+ selected_steps = selected_steps ,
815
+ step_kwargs = kwargs ,
816
+ initial_point = initial_points [0 ],
817
+ )
818
+ if isinstance (step , list ):
819
+ step = CompoundStep (step )
794
820
795
821
if var_names is not None :
796
822
trace_vars = [v for v in model .unobserved_RVs if v .name in var_names ]
@@ -806,7 +832,7 @@ def joined_blas_limiter():
806
832
expected_length = draws + tune ,
807
833
step = step ,
808
834
trace_vars = trace_vars ,
809
- initial_point = ip ,
835
+ initial_point = initial_points [ 0 ] ,
810
836
model = model ,
811
837
)
812
838
@@ -954,7 +980,6 @@ def _sample_return(
954
980
f"took { t_sampling :.0f} seconds."
955
981
)
956
982
957
- idata = None
958
983
if compute_convergence_checks or return_inferencedata :
959
984
ikwargs : dict [str , Any ] = {"model" : model , "save_warmup" : not discard_tuned_samples }
960
985
ikwargs .update (idata_kwargs )
@@ -1159,7 +1184,6 @@ def _iter_sample(
1159
1184
diverging : bool
1160
1185
Indicates if the draw is divergent. Only available with some samplers.
1161
1186
"""
1162
- model = modelcontext (model )
1163
1187
draws = int (draws )
1164
1188
1165
1189
if draws < 1 :
@@ -1174,8 +1198,6 @@ def _iter_sample(
1174
1198
if hasattr (step , "reset_tuning" ):
1175
1199
step .reset_tuning ()
1176
1200
for i in range (draws ):
1177
- diverging = False
1178
-
1179
1201
if i == 0 and hasattr (step , "iter_count" ):
1180
1202
step .iter_count = 0
1181
1203
if i == tune :
@@ -1298,6 +1320,7 @@ def _init_jitter(
1298
1320
seeds : Sequence [int ] | np .ndarray ,
1299
1321
jitter : bool ,
1300
1322
jitter_max_retries : int ,
1323
+ logp_dlogp_func = None ,
1301
1324
) -> list [PointType ]:
1302
1325
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
1303
1326
@@ -1328,19 +1351,30 @@ def _init_jitter(
1328
1351
if not jitter :
1329
1352
return [ipfn (seed ) for ipfn , seed in zip (ipfns , seeds )]
1330
1353
1354
+ model_logp_fn : Callable
1355
+ if logp_dlogp_func is None :
1356
+ model_logp_fn = model .compile_logp ()
1357
+ else :
1358
+
1359
+ def model_logp_fn (ip ):
1360
+ q , _ = DictToArrayBijection .map (ip )
1361
+ return logp_dlogp_func ([q ], extra_vars = {})[0 ]
1362
+
1331
1363
initial_points = []
1332
1364
for ipfn , seed in zip (ipfns , seeds ):
1333
- rng = np .random .RandomState (seed )
1365
+ rng = np .random .default_rng (seed )
1334
1366
for i in range (jitter_max_retries + 1 ):
1335
1367
point = ipfn (seed )
1336
- if i < jitter_max_retries :
1337
- try :
1368
+ point_logp = model_logp_fn (point )
1369
+ if not np .isfinite (point_logp ):
1370
+ if i == jitter_max_retries :
1371
+ # Print informative message on last attempted point
1338
1372
model .check_start_vals (point )
1339
- except SamplingError :
1340
- # Retry with a new seed
1341
- seed = rng . randint ( 2 ** 30 , dtype = np . int64 )
1342
- else :
1343
- break
1373
+ # Retry with a new seed
1374
+ seed = rng . integers ( 2 ** 30 , dtype = np . int64 )
1375
+ else :
1376
+ break
1377
+
1344
1378
initial_points .append (point )
1345
1379
return initial_points
1346
1380
@@ -1436,10 +1470,12 @@ def init_nuts(
1436
1470
1437
1471
_log .info (f"Initializing NUTS using { init } ..." )
1438
1472
1439
- cb = [
1440
- pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "absolute" ),
1441
- pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "relative" ),
1442
- ]
1473
+ cb = []
1474
+ if "advi" in init :
1475
+ cb = [
1476
+ pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "absolute" ),
1477
+ pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "relative" ),
1478
+ ]
1443
1479
1444
1480
logp_dlogp_func = model .logp_dlogp_function (ravel_inputs = True )
1445
1481
logp_dlogp_func .trust_input = True
@@ -1449,6 +1485,7 @@ def init_nuts(
1449
1485
seeds = random_seed_list ,
1450
1486
jitter = "jitter" in init ,
1451
1487
jitter_max_retries = jitter_max_retries ,
1488
+ logp_dlogp_func = logp_dlogp_func ,
1452
1489
)
1453
1490
1454
1491
apoints = [DictToArrayBijection .map (point ) for point in initial_points ]
@@ -1562,7 +1599,14 @@ def init_nuts(
1562
1599
else :
1563
1600
raise ValueError (f"Unknown initializer: { init } ." )
1564
1601
1565
- step = pm .NUTS (potential = potential , model = model , rng = random_seed_list [0 ], ** kwargs )
1602
+ step = pm .NUTS (
1603
+ potential = potential ,
1604
+ model = model ,
1605
+ rng = random_seed_list [0 ],
1606
+ initial_point = initial_points [0 ],
1607
+ logp_dlogp_func = logp_dlogp_func ,
1608
+ ** kwargs ,
1609
+ )
1566
1610
1567
1611
# Filter deterministics from initial_points
1568
1612
value_var_names = [var .name for var in model .value_vars ]
0 commit comments