Skip to content

Commit 62335ac

Browse files
committed
Fix bug when reusing jax logp for initial point generation
1 parent 355b475 commit 62335ac

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

pymc/sampling/jax.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,10 @@ def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
240240
Wraps jaxified logp function to accept a dict of
241241
{model_variable: np.array} key:value pairs.
242242
"""
243-
return logp_fn(point.values())
243+
# Because logp_fn is not jitted, we need to convert inputs to jax arrays,
244+
# or some methods that are only available for jax arrays will fail
245+
# such as x.at[indices].set(y)
246+
return logp_fn([jax.numpy.asarray(v) for v in point.values()])
244247

245248
initial_points = _init_jitter(
246249
model,

tests/sampling/test_jax.py

+21
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,27 @@ def test_get_batched_jittered_initial_points():
352352
assert np.all(ips[0][0] != ips[0][1])
353353

354354

355+
def test_get_batched_jittered_initial_points_set_subtensor():
356+
"""Regression bug for issue described in
357+
https://discourse.pymc.io/t/attributeerror-numpy-ndarray-object-has-no-attribute-at-when-sampling-lkj-cholesky-covariance-priors-for-multivariate-normal-models-example-with-numpyro-or-blackjax/16598/3
358+
359+
Which was caused by passing numpy arrays to a non-jitted logp function
360+
"""
361+
with pm.Model() as model:
362+
# Set operation will use `x.at[1].set(100)` which is only available in JAX
363+
x = pm.Normal("x", mu=[-100, -100])
364+
mu_y = x[1].set(100)
365+
y = pm.Normal("y", mu=mu_y)
366+
367+
logp_fn = get_jaxified_logp(model)
368+
[x_ips, y_ips] = _get_batched_jittered_initial_points(
369+
model, chains=3, initvals=None, logp_fn=logp_fn, jitter=True, random_seed=0
370+
)
371+
assert np.all(x_ips < -10)
372+
assert np.all(y_ips[..., 0] < -10)
373+
assert np.all(y_ips[..., 1] > 10)
374+
375+
355376
@pytest.mark.parametrize(
356377
"sampler",
357378
[

0 commit comments

Comments
 (0)