Skip to content

Commit ae67043

Browse files
committed
Allow forwarding of MvNormal method to SymbolicRandomVariables
1 parent 205c193 commit ae67043

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

pymc/distributions/distribution.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,14 @@ def batch_ndim(self, node: Apply) -> int:
383383
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
384384
return out_ndim - self.ndim_supp
385385

386+
def rebuild_rv(self, *args, **kwargs):
387+
"""Rebuild the RandomVariable with new inputs."""
388+
if not hasattr(self, "rv_op"):
389+
raise NotImplementedError(
390+
f"SymbolicRandomVariable {self} without `rv_op` method cannot be rebuilt automatically."
391+
)
392+
return self.rv_op(*args, **kwargs)
393+
386394

387395
@_change_dist_size.register(SymbolicRandomVariable)
388396
def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> TensorVariable:
@@ -403,7 +411,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
403411
if expand and not rv_size_is_none(size):
404412
new_size = tuple(new_size) + tuple(size)
405413

406-
return op.rv_op(*params, size=new_size)
414+
return op.rebuild_rv(*params, size=new_size)
407415

408416

409417
class Distribution(metaclass=DistributionMeta):

pymc/distributions/multivariate.py

+31-10
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,19 @@ def logp(value, mu, cov):
302302
)
303303

304304

305-
class PrecisionMvNormalRV(SymbolicRandomVariable):
305+
class SymbolicMVNormalUsedInternally(SymbolicRandomVariable):
306+
"""Helper subclass that handles the forwarding / caching of method to `MvNormal` used internally."""
307+
308+
def __init__(self, *args, method: str, **kwargs):
309+
super().__init__(*args, **kwargs)
310+
self.method = method
311+
312+
def rebuild_rv(self, *args, **kwargs):
313+
# rv_op is a classmethod, so it doesn't have access to the instance method
314+
return self.rv_op(*args, method=self.method, **kwargs)
315+
316+
317+
class PrecisionMvNormalRV(SymbolicMVNormalUsedInternally):
306318
r"""A specialized multivariate normal random variable defined in terms of precision.
307319
308320
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -313,14 +325,17 @@ class PrecisionMvNormalRV(SymbolicRandomVariable):
313325
_print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")
314326

315327
@classmethod
316-
def rv_op(cls, mean, tau, *, rng=None, size=None):
328+
def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None):
317329
rng = normalize_rng_param(rng)
318330
size = normalize_size_param(size)
319331
cov = pt.linalg.inv(tau)
320-
next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
332+
next_rng, draws = multivariate_normal(
333+
mean, cov, size=size, rng=rng, method=method
334+
).owner.outputs
321335
return cls(
322336
inputs=[rng, size, mean, tau],
323337
outputs=[next_rng, draws],
338+
method=method,
324339
)(rng, size, mean, tau)
325340

326341

@@ -354,7 +369,9 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
354369
rng, size, mu, cov = node.inputs
355370
if cov.owner and cov.owner.op == matrix_inverse:
356371
tau = cov.owner.inputs[0]
357-
return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
372+
return PrecisionMvNormalRV.rv_op(
373+
mu, tau, size=size, rng=rng, method=node.op.method
374+
).owner.outputs
358375
return None
359376

360377

@@ -365,7 +382,7 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
365382
)
366383

367384

368-
class MvStudentTRV(SymbolicRandomVariable):
385+
class MvStudentTRV(SymbolicMVNormalUsedInternally):
369386
r"""A specialized multivariate normal random variable defined in terms of precision.
370387
371388
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -376,7 +393,7 @@ class MvStudentTRV(SymbolicRandomVariable):
376393
_print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
377394

378395
@classmethod
379-
def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
396+
def rv_op(cls, nu, mean, scale, *, method: str = "cholesky", rng=None, size=None):
380397
nu = pt.as_tensor(nu)
381398
mean = pt.as_tensor(mean)
382399
scale = pt.as_tensor(scale)
@@ -387,14 +404,15 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
387404
size = implicit_size_from_params(nu, mean, scale, ndims_params=cls.ndims_params)
388405

389406
next_rng, mv_draws = multivariate_normal(
390-
mean.zeros_like(), scale, size=size, rng=rng
407+
mean.zeros_like(), scale, size=size, rng=rng, method=method
391408
).owner.outputs
392409
next_rng, chi2_draws = chisquare(nu, size=size, rng=next_rng).owner.outputs
393410
draws = mean + (mv_draws / pt.sqrt(chi2_draws / nu)[..., None])
394411

395412
return cls(
396413
inputs=[rng, size, nu, mean, scale],
397414
outputs=[next_rng, draws],
415+
method=method,
398416
)(rng, size, nu, mean, scale)
399417

400418

@@ -1923,12 +1941,12 @@ def logp(value, mu, rowchol, colchol):
19231941
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
19241942

19251943

1926-
class KroneckerNormalRV(SymbolicRandomVariable):
1944+
class KroneckerNormalRV(SymbolicMVNormalUsedInternally):
19271945
ndim_supp = 1
19281946
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
19291947

19301948
@classmethod
1931-
def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
1949+
def rv_op(cls, mu, sigma, *covs, method: str = "cholesky", size=None, rng=None):
19321950
mu = pt.as_tensor(mu)
19331951
sigma = pt.as_tensor(sigma)
19341952
covs = [pt.as_tensor(cov) for cov in covs]
@@ -1937,7 +1955,9 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
19371955

19381956
cov = reduce(pt.linalg.kron, covs)
19391957
cov = cov + sigma**2 * pt.eye(cov.shape[-2])
1940-
next_rng, draws = multivariate_normal(mean=mu, cov=cov, size=size, rng=rng).owner.outputs
1958+
next_rng, draws = multivariate_normal(
1959+
mean=mu, cov=cov, size=size, rng=rng, method=method
1960+
).owner.outputs
19411961

19421962
covs_sig = ",".join(f"(a{i},b{i})" for i in range(len(covs)))
19431963
extended_signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"
@@ -1946,6 +1966,7 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
19461966
inputs=[rng, size, mu, sigma, *covs],
19471967
outputs=[next_rng, draws],
19481968
extended_signature=extended_signature,
1969+
method=method,
19491970
)(rng, size, mu, sigma, *covs)
19501971

19511972

tests/distributions/test_multivariate.py

+20
Original file line numberDiff line numberDiff line change
@@ -2469,6 +2469,26 @@ def test_mvstudentt_mu_convenience():
24692469
np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))
24702470

24712471

2472+
def test_mvstudentt_method():
2473+
def all_svd_method(fgraph):
2474+
found_one = False
2475+
for node in fgraph.toposort():
2476+
if isinstance(node.op, pm.MvNormal):
2477+
found_one = True
2478+
if not node.op.method == "svd":
2479+
return False
2480+
return found_one # We want to fail if there were no MvNormal nodes
2481+
2482+
x = pm.MvStudentT.dist(nu=4, scale=np.eye(3), method="svd")
2483+
assert x.type.shape == (3,)
2484+
assert all_svd_method(x.owner.op.fgraph)
2485+
2486+
# Changing the size should preserve the method
2487+
resized_x = change_dist_size(x, (2,))
2488+
assert resized_x.type.shape == (2, 3)
2489+
assert all_svd_method(resized_x.owner.op.fgraph)
2490+
2491+
24722492
def test_precision_mv_normal_optimization():
24732493
rng = np.random.default_rng(sum(map(ord, "be precise")))
24742494

0 commit comments

Comments
 (0)