27
27
from pytensor .tensor import TensorVariable
28
28
from pytensor .tensor .blockwise import Blockwise
29
29
from pytensor .tensor .nlinalg import MatrixInverse
30
+ from pytensor .tensor .random .basic import multivariate_normal
30
31
from pytensor .tensor .random .utils import broadcast_params
31
32
from pytensor .tensor .slinalg import Cholesky
32
33
@@ -1392,6 +1393,11 @@ def test_dirichlet_multinomial_support_point(self, a, n, size, expected):
1392
1393
1393
1394
1394
1395
class TestMvNormalCov (BaseTestDistributionRandom ):
1396
+ def mvnormal_rng_fn (self , size , mean , cov , rng ):
1397
+ if isinstance (size , int ):
1398
+ size = (size ,)
1399
+ return multivariate_normal .rng_fn (rng , mean , cov , size = size )
1400
+
1395
1401
pymc_dist = pm .MvNormal
1396
1402
pymc_dist_params = {
1397
1403
"mu" : np .array ([1.0 , 2.0 ]),
@@ -1407,7 +1413,8 @@ class TestMvNormalCov(BaseTestDistributionRandom):
1407
1413
"mean" : np .array ([1.0 , 2.0 ]),
1408
1414
"cov" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1409
1415
}
1410
- reference_dist = seeded_numpy_distribution_builder ("multivariate_normal" )
1416
+ reference_dist = lambda self : ft .partial (self .mvnormal_rng_fn , rng = self .get_random_state ()) # noqa: E731
1417
+
1411
1418
checks_to_run = [
1412
1419
"check_pymc_params_match_rv_op" ,
1413
1420
"check_pymc_draws_match_reference" ,
@@ -1531,12 +1538,12 @@ class TestZeroSumNormal:
1531
1538
def assert_zerosum_axes (self , random_samples , axes_to_check , check_zerosum_axes = True ):
1532
1539
if check_zerosum_axes :
1533
1540
for ax in axes_to_check :
1534
- assert np .isclose (random_samples .mean (axis = ax ), 0 ). all ( ), (
1541
+ assert np .allclose (random_samples .mean (axis = ax ), 0 ), (
1535
1542
f"{ ax } is a zerosum_axis but is not summing to 0 across all samples."
1536
1543
)
1537
1544
else :
1538
1545
for ax in axes_to_check :
1539
- assert not np .isclose (random_samples .mean (axis = ax ), 0 ). all ( ), (
1546
+ assert not np .allclose (random_samples .mean (axis = ax ), 0 ), (
1540
1547
f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1541
1548
)
1542
1549
@@ -1775,7 +1782,9 @@ def test_batched_sigma(self):
1775
1782
1776
1783
class TestMvStudentTCov (BaseTestDistributionRandom ):
1777
1784
def mvstudentt_rng_fn (self , size , nu , mu , scale , rng ):
1778
- mv_samples = rng .multivariate_normal (np .zeros_like (mu ), scale , size = size )
1785
+ if isinstance (size , int ):
1786
+ size = (size ,)
1787
+ mv_samples = multivariate_normal .rng_fn (rng , np .zeros_like (mu ), scale , size = size )
1779
1788
chi2_samples = rng .chisquare (nu , size = size )
1780
1789
return (mv_samples / np .sqrt (chi2_samples [:, None ] / nu )) + mu
1781
1790
@@ -2111,9 +2120,11 @@ def check_random_variable_prior(self):
2111
2120
2112
2121
class TestKroneckerNormal (BaseTestDistributionRandom ):
2113
2122
def kronecker_rng_fn (self , size , mu , covs = None , sigma = None , rng = None ):
2114
- cov = pm .math .kronecker (covs [0 ], covs [1 ]).eval ()
2123
+ if isinstance (size , int ):
2124
+ size = (size ,)
2125
+ cov = np .kron (covs [0 ], covs [1 ])
2115
2126
cov += sigma ** 2 * np .identity (cov .shape [0 ])
2116
- return st . multivariate_normal .rvs ( mean = mu , cov = cov , size = size , random_state = rng )
2127
+ return multivariate_normal .rng_fn ( rng , mean = mu , cov = cov , size = size )
2117
2128
2118
2129
pymc_dist = pm .KroneckerNormal
2119
2130
0 commit comments