Skip to content

Commit d01e487

Browse files
committed
Add gamma distribution and test gamma and beta dists
1 parent 450477e commit d01e487

File tree

4 files changed

+161
-2
lines changed

4 files changed

+161
-2
lines changed

neuralprocesses/dist/beta.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: Masked):
9090
def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1):
    """Evaluate the beta log-pdf at `x`, zeroing out masked entries and
    summing over the trailing `d` data dimensions (no sum when `d == 0`)."""
    log_norm = B.logbeta(self.alpha, self.beta)
    lp = (self.alpha - 1) * B.log(x) + (self.beta - 1) * B.log(1 - x) - log_norm
    lp = lp * mask
    if self.d == 0:
        # Scalar case: every entry is its own log-pdf.
        return lp
    data_axes = tuple(range(B.rank(lp)))[-self.d :]
    return B.sum(lp, axis=data_axes)
9498

9599
def __str__(self):
    # Concise human-readable form showing both concentration parameters.
    return f"Beta({self.alpha}, {self.beta})"

neuralprocesses/dist/dist.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def sample(self, state: B.RandomState, dtype: B.DType, *shape):
3131
state (random state, optional): Random state.
3232
tensor: Samples of shape `(*shape, *d)` where typically `d = (*b, c, n)`.
3333
"""
34+
print(type(self), type(state), type(dtype), *shape)
3435
raise NotImplementedError(f"{self} cannot be sampled.")
3536

3637
@_dispatch

neuralprocesses/dist/gamma.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import lab as B
2+
from matrix.shape import broadcast
3+
from plum import parametric
4+
5+
from .. import _dispatch
6+
from ..aggregate import Aggregate
7+
from ..mask import Masked
8+
from .dist import AbstractDistribution, shape_batch
9+
10+
__all__ = ["Gamma"]
11+
12+
13+
@parametric
class Gamma(AbstractDistribution):
    """Gamma distribution.

    Args:
        k (tensor): Shape parameter.
        scale (tensor): Scale parameter.
        d (int): Dimensionality of the data.

    Attributes:
        k (tensor): Shape parameter.
        scale (tensor): Scale parameter.
        d (int): Dimensionality of the data.
    """

    def __init__(self, k, scale, d):
        self.k = k
        self.scale = scale
        self.d = d

    @property
    def mean(self):
        # Mean of a gamma distribution: `k * scale`.
        return B.multiply(self.k, self.scale)

    @property
    def var(self):
        # Variance of a gamma distribution: `k * scale ** 2`.
        return B.multiply(B.multiply(self.k, self.scale), self.scale)

    @_dispatch
    def sample(
        self: "Gamma[Aggregate, Aggregate, Aggregate]",
        state: B.RandomState,
        dtype: B.DType,
        *shape,
    ):
        # Sample every element of the aggregate independently, threading the
        # random state through the elements.
        samples = []
        for ki, si, di in zip(self.k, self.scale, self.d):
            state, sample = Gamma(ki, si, di).sample(state, dtype, *shape)
            samples.append(sample)
        return state, Aggregate(*samples)

    @_dispatch
    def sample(
        self: "Gamma[B.Numeric, B.Numeric, B.Int]",
        state: B.RandomState,
        dtype: B.DType,
        *shape,
    ):
        # `B.randgamma` returns a `(state, sample)` pair, matching the
        # contract of `sample`.
        return B.randgamma(state, dtype, *shape, alpha=self.k, scale=self.scale)

    @_dispatch
    def logpdf(self: "Gamma[Aggregate, Aggregate, Aggregate]", x: Aggregate):
        # The log-pdf of an aggregate is the sum of the elementwise log-pdfs.
        return sum(
            [
                Gamma(ki, si, di).logpdf(xi)
                for ki, si, di, xi in zip(self.k, self.scale, self.d, x)
            ],
            0,
        )

    @_dispatch
    def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: Masked):
        x, mask = x.y, x.mask
        with B.on_device(self.k):
            safe = B.to_active_device(B.one(B.dtype(self)))
        # Make inputs safe: replace masked-out entries with one, a point where
        # the gamma log-pdf is finite, so no NaNs/infs are produced.
        x = mask * x + (1 - mask) * safe
        # Run with safe inputs, and filter out the right logpdfs.
        return self.logpdf(x, mask=mask)

    @_dispatch
    def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1):
        # Log-normaliser of the gamma density: `log Gamma(k) + k * log(scale)`.
        logz = B.loggamma(self.k) + self.k * B.log(self.scale)
        logpdf = (self.k - 1) * B.log(x) - x / self.scale - logz
        # Zero out contributions from masked-out entries.
        logpdf = logpdf * mask
        if self.d == 0:
            # Scalar case: nothing to sum over.
            return logpdf
        else:
            # Sum over the trailing `d` data dimensions.
            return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :])

    def __str__(self):
        return f"Gamma({self.k}, {self.scale})"

    def __repr__(self):
        return f"Gamma({self.k!r}, {self.scale!r})"
98+
99+
100+
@B.dtype.dispatch
def dtype(dist: Gamma):
    # The dtype of the distribution is the common dtype of its parameters.
    return B.dtype(dist.k, dist.scale)
103+
104+
105+
@shape_batch.dispatch
def shape_batch(dist: "Gamma[B.Numeric, B.Numeric, B.Int]"):
    """Batch shape: the broadcast parameter shape minus the trailing `d`
    data dimensions."""
    shape = B.shape_broadcast(dist.k, dist.scale)
    # `shape[: -0]` would slice to the empty shape, so handle `d == 0`
    # explicitly: with no data dimensions, the whole shape is batch. This
    # mirrors the `d == 0` special case in `logpdf`.
    return shape if dist.d == 0 else shape[: -dist.d]
108+
109+
110+
@shape_batch.dispatch
def shape_batch(dist: "Gamma[Aggregate, Aggregate, Aggregate]"):
    """Batch shape of an aggregate: the broadcast of the batch shapes of
    every element."""
    per_element = [
        shape_batch(Gamma(ki, si, di))
        for ki, si, di in zip(dist.k, dist.scale, dist.d)
    ]
    return broadcast(*per_element)

tests/test_distribution.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import lab as B
2+
import scipy.stats as stats
3+
4+
import torch
5+
from neuralprocesses.dist.beta import Beta
6+
from neuralprocesses.dist.gamma import Gamma
27

38
from .test_architectures import check_prediction, generate_data
4-
from .util import nps # noqa
9+
from .util import approx, nps # noqa
510

611

712
def test_transform_positive(nps):
@@ -42,3 +47,35 @@ def test_transform_bounded(nps):
4247
# Check that predictions and samples satisfy the constraint.
4348
assert B.all(pred.mean > 10) and B.all(pred.mean < 11)
4449
assert B.all(pred.sample() > 10) and B.all(pred.sample() < 11)
50+
51+
52+
def test_beta_correctness():
    """Test the correctness of the beta distribution."""
    dist = Beta(B.cast(torch.float64, 0.2), B.cast(torch.float64, 0.8), 0)
    reference = stats.beta(0.2, 0.8)

    draw = dist.sample()
    approx(dist.logpdf(draw), reference.logpdf(draw))
    approx(dist.mean, reference.mean())
    approx(dist.var, reference.var())

    # `d` must sum the log-pdf over exactly the trailing `d` dimensions.
    for d in range(4):
        dist = Beta(dist.alpha, dist.beta, d)
        assert dist.logpdf(dist.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]
66+
67+
68+
def test_gamma():
    """Test the correctness of the gamma distribution."""
    dist = Gamma(B.cast(torch.float64, 2), B.cast(torch.float64, 0.8), 0)
    reference = stats.gamma(2, scale=0.8)

    draw = dist.sample()
    approx(dist.logpdf(draw), reference.logpdf(draw))
    approx(dist.mean, reference.mean())
    approx(dist.var, reference.var())

    # `d` must sum the log-pdf over exactly the trailing `d` dimensions.
    for d in range(4):
        dist = Gamma(dist.k, dist.scale, d)
        assert dist.logpdf(dist.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]

0 commit comments

Comments
 (0)