@@ -302,7 +302,19 @@ def logp(value, mu, cov):
302
302
)
303
303
304
304
305
- class PrecisionMvNormalRV (SymbolicRandomVariable ):
305
+ class SymbolicMVNormalUsedInternally (SymbolicRandomVariable ):
306
+ """Helper subclass that handles the forwarding / caching of method to `MvNormal` used internally."""
307
+
308
+ def __init__ (self , * args , method : str , ** kwargs ):
309
+ super ().__init__ (* args , ** kwargs )
310
+ self .method = method
311
+
312
+ def rebuild_rv (self , * args , ** kwargs ):
313
+ # rv_op is a classmethod, so it doesn't have access to the instance method
314
+ return self .rv_op (* args , method = self .method , ** kwargs )
315
+
316
+
317
+ class PrecisionMvNormalRV (SymbolicMVNormalUsedInternally ):
306
318
r"""A specialized multivariate normal random variable defined in terms of precision.
307
319
308
320
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -313,14 +325,17 @@ class PrecisionMvNormalRV(SymbolicRandomVariable):
313
325
_print_name = ("PrecisionMultivariateNormal" , "\\ operatorname{PrecisionMultivariateNormal}" )
314
326
315
327
@classmethod
316
- def rv_op (cls , mean , tau , * , rng = None , size = None ):
328
+ def rv_op (cls , mean , tau , * , method : str = "cholesky" , rng = None , size = None ):
317
329
rng = normalize_rng_param (rng )
318
330
size = normalize_size_param (size )
319
331
cov = pt .linalg .inv (tau )
320
- next_rng , draws = multivariate_normal (mean , cov , size = size , rng = rng ).owner .outputs
332
+ next_rng , draws = multivariate_normal (
333
+ mean , cov , size = size , rng = rng , method = method
334
+ ).owner .outputs
321
335
return cls (
322
336
inputs = [rng , size , mean , tau ],
323
337
outputs = [next_rng , draws ],
338
+ method = method ,
324
339
)(rng , size , mean , tau )
325
340
326
341
@@ -354,7 +369,9 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
354
369
rng , size , mu , cov = node .inputs
355
370
if cov .owner and cov .owner .op == matrix_inverse :
356
371
tau = cov .owner .inputs [0 ]
357
- return PrecisionMvNormalRV .rv_op (mu , tau , size = size , rng = rng ).owner .outputs
372
+ return PrecisionMvNormalRV .rv_op (
373
+ mu , tau , size = size , rng = rng , method = node .op .method
374
+ ).owner .outputs
358
375
return None
359
376
360
377
@@ -365,7 +382,7 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
365
382
)
366
383
367
384
368
- class MvStudentTRV (SymbolicRandomVariable ):
385
+ class MvStudentTRV (SymbolicMVNormalUsedInternally ):
369
386
r"""A specialized multivariate normal random variable defined in terms of precision.
370
387
371
388
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -376,7 +393,7 @@ class MvStudentTRV(SymbolicRandomVariable):
376
393
_print_name = ("MvStudentT" , "\\ operatorname{MvStudentT}" )
377
394
378
395
@classmethod
379
- def rv_op (cls , nu , mean , scale , * , rng = None , size = None ):
396
+ def rv_op (cls , nu , mean , scale , * , method : str = "cholesky" , rng = None , size = None ):
380
397
nu = pt .as_tensor (nu )
381
398
mean = pt .as_tensor (mean )
382
399
scale = pt .as_tensor (scale )
@@ -387,14 +404,15 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
387
404
size = implicit_size_from_params (nu , mean , scale , ndims_params = cls .ndims_params )
388
405
389
406
next_rng , mv_draws = multivariate_normal (
390
- mean .zeros_like (), scale , size = size , rng = rng
407
+ mean .zeros_like (), scale , size = size , rng = rng , method = method
391
408
).owner .outputs
392
409
next_rng , chi2_draws = chisquare (nu , size = size , rng = next_rng ).owner .outputs
393
410
draws = mean + (mv_draws / pt .sqrt (chi2_draws / nu )[..., None ])
394
411
395
412
return cls (
396
413
inputs = [rng , size , nu , mean , scale ],
397
414
outputs = [next_rng , draws ],
415
+ method = method ,
398
416
)(rng , size , nu , mean , scale )
399
417
400
418
@@ -1923,12 +1941,12 @@ def logp(value, mu, rowchol, colchol):
1923
1941
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
1924
1942
1925
1943
1926
- class KroneckerNormalRV (SymbolicRandomVariable ):
1944
+ class KroneckerNormalRV (SymbolicMVNormalUsedInternally ):
1927
1945
ndim_supp = 1
1928
1946
_print_name = ("KroneckerNormal" , "\\ operatorname{KroneckerNormal}" )
1929
1947
1930
1948
@classmethod
1931
- def rv_op (cls , mu , sigma , * covs , size = None , rng = None ):
1949
+ def rv_op (cls , mu , sigma , * covs , method : str = "cholesky" , size = None , rng = None ):
1932
1950
mu = pt .as_tensor (mu )
1933
1951
sigma = pt .as_tensor (sigma )
1934
1952
covs = [pt .as_tensor (cov ) for cov in covs ]
@@ -1937,7 +1955,9 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
1937
1955
1938
1956
cov = reduce (pt .linalg .kron , covs )
1939
1957
cov = cov + sigma ** 2 * pt .eye (cov .shape [- 2 ])
1940
- next_rng , draws = multivariate_normal (mean = mu , cov = cov , size = size , rng = rng ).owner .outputs
1958
+ next_rng , draws = multivariate_normal (
1959
+ mean = mu , cov = cov , size = size , rng = rng , method = method
1960
+ ).owner .outputs
1941
1961
1942
1962
covs_sig = "," .join (f"(a{ i } ,b{ i } )" for i in range (len (covs )))
1943
1963
extended_signature = f"[rng],[size],(m),(),{ covs_sig } ->[rng],(m)"
@@ -1946,6 +1966,7 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
1946
1966
inputs = [rng , size , mu , sigma , * covs ],
1947
1967
outputs = [next_rng , draws ],
1948
1968
extended_signature = extended_signature ,
1969
+ method = method ,
1949
1970
)(rng , size , mu , sigma , * covs )
1950
1971
1951
1972
0 commit comments