Skip to content

Commit 51da293

Browse files
committed
Add Bernoulli-gamma likelihood
1 parent d01e487 commit 51da293

File tree

8 files changed

+118
-20
lines changed

8 files changed

+118
-20
lines changed

neuralprocesses/architectures/agnp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def construct_agnp(*args, nps=nps, num_heads=8, **kw_args):
3131
width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
3232
nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
3333
as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
34-
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
35-
Defaults to `"lowrank"`.
34+
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
35+
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
3636
num_basis_functions (int, optional): Number of basis functions for the
3737
low-rank likelihood. Defaults to 512.
3838
dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
3939
lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
40-
`"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
40+
`"het"` or `"dense"`. Defaults to `"het"`.
4141
transform (str or tuple[float, float]): Bijection applied to the
4242
output of the model. This can help deal with positive or bounded data.
4343
Must be either `"positive"`, `"exp"`, `"softplus"`, or

neuralprocesses/architectures/climate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def construct_climate_convgnp_mlp(
2727
to 128.
2828
lr_deg (float, optional): Resolution of the low-resolution grid. Defaults to
2929
0.75.
30-
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank".
31-
Defaults to `"lowrank"`.
30+
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
31+
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
3232
dtype (dtype, optional): Data type.
3333
"""
3434
mlp_width = 128

neuralprocesses/architectures/convgnp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def construct_convgnp(
150150
margin (float, optional): Margin of the internal discretisation. Defaults to
151151
0.1.
152152
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
153-
or `"spikes-beta"`. Defaults to `"lowrank"`.
153+
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
154154
conv_arch (str, optional): Convolutional architecture to use. Must be one of
155155
`"unet[-res][-sep]"` or `"conv[-res][-sep]"`. Defaults to `"unet"`.
156156
unet_channels (tuple[int], optional): Channels of every layer of the UNet.

neuralprocesses/architectures/gnp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ def construct_gnp(
5757
width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
5858
nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
5959
as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
60-
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
61-
Defaults to `"lowrank"`.
60+
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
61+
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
6262
num_basis_functions (int, optional): Number of basis functions for the
6363
low-rank likelihood. Defaults to 512.
6464
dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
6565
lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
66-
`"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
66+
`"het"` or `"dense"`. Defaults to `"het"`.
6767
transform (str or tuple[float, float]): Bijection applied to the
6868
output of the model. This can help deal with positive or bounded data.
6969
Must be either `"positive"`, `"exp"`, `"softplus"`, or

neuralprocesses/architectures/util.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
1212
Args:
1313
nps (module): Appropriate backend-specific module.
1414
spec (str, optional): Specification. Must be one of `"het"`, `"lowrank"`,
15-
`"dense"`, or `"spikes-beta"`. Defaults to `"lowrank"`. Must be given as
16-
a keyword argument.
15+
`"dense"`, `"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
16+
Must be given as a keyword argument.
1717
dim_y (int): Dimensionality of the outputs. Must be given as a keyword argument.
1818
num_basis_functions (int): Number of basis functions for the low-rank
1919
likelihood. Must be given as a keyword argument.
@@ -52,6 +52,10 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
5252
num_channels = (2 + 3) * dim_y # Alpha, beta, and three log-probabilities
5353
selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y, dim_y)
5454
lik = nps.SpikesBetaLikelihood()
55+
elif spec == "bernoulli-gamma":
56+
num_channels = (2 + 2) * dim_y # Shape, scale, and two log-probabilities
57+
selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y)
58+
lik = nps.BernoulliGammaLikelihood()
5559

5660
else:
5761
raise ValueError(f'Incorrect likelihood specification "{spec}".')

neuralprocesses/dist/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .beta import *
22
from .dirac import *
33
from .dist import *
4+
from .gamma import *
45
from .geom import *
56
from .normal import *
67
from .spikeslab import *

neuralprocesses/likelihood.py

+95-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from . import _dispatch
88
from .aggregate import Aggregate, AggregateInput
99
from .datadims import data_dims
10-
from .dist import Beta, Dirac, MultiOutputNormal, SpikesSlab
10+
from .dist import Beta, Dirac, Gamma, MultiOutputNormal, SpikesSlab
1111
from .parallel import Parallel
1212
from .util import register_module, split, split_dimension
1313

@@ -16,6 +16,8 @@
1616
"HeterogeneousGaussianLikelihood",
1717
"LowRankGaussianLikelihood",
1818
"DenseGaussianLikelihood",
19+
"SpikesBetaLikelihood",
20+
"BernoulliGammaLikelihood",
1921
]
2022

2123

@@ -359,7 +361,7 @@ def _dense_var(coder: DenseGaussianLikelihood, xz, z: B.Numeric):
359361

360362
@register_module
361363
class SpikesBetaLikelihood(AbstractLikelihood):
362-
"""Gaussian likelihood with heterogeneous noise.
364+
"""Mixture of a beta distribution, a Dirac delta at zero, and a Dirac delta at one.
363365
364366
Args:
365367
epsilon (float, optional): Tolerance for equality checking. Defaults to `1e-6`.
@@ -451,3 +453,94 @@ def _spikesbeta(coder: SpikesBetaLikelihood, xz, z: B.Numeric):
451453
logps = z_logps
452454

453455
return alpha, beta, logp0, logp1, logps, d + 1
456+
457+
458+
@register_module
459+
class BernoulliGammaLikelihood(AbstractLikelihood):
460+
"""Mixture of a gamma distribution and a Dirac delta at zero.
461+
462+
Args:
463+
epsilon (float, optional): Tolerance for equality checking. Defaults to `1e-6`.
464+
465+
Attributes:
466+
epsilon (float): Tolerance for equality checking.
467+
"""
468+
469+
@_dispatch
470+
def __init__(self, epsilon: float = 1e-6):
471+
self.epsilon = epsilon
472+
473+
def __str__(self):
474+
return repr(self)
475+
476+
def __repr__(self):
477+
return f"BernoulliGammaLikelihood(epsilon={self.epsilon!r})"
478+
479+
480+
@_dispatch
481+
def code(
482+
coder: BernoulliGammaLikelihood,
483+
xz,
484+
z,
485+
x,
486+
*,
487+
dtype_lik=None,
488+
**kw_args,
489+
):
490+
k, scale, logp0, logps, d = _bernoulligamma(coder, xz, z)
491+
492+
# Cast parameters to the right data type.
493+
if dtype_lik:
494+
k = B.cast(dtype_lik, k)
495+
scale = B.cast(dtype_lik, scale)
496+
logp0 = B.cast(dtype_lik, logp0)
497+
logps = B.cast(dtype_lik, logps)
498+
499+
# Create the spikes vector.
500+
with B.on_device(z):
501+
dtype = dtype_lik or B.dtype(z)
502+
spikes = B.stack(B.zero(dtype))
503+
504+
return xz, SpikesSlab(
505+
spikes,
506+
Gamma(k, scale, d),
507+
B.stack(logp0, logps, axis=-1),
508+
d,
509+
epsilon=coder.epsilon,
510+
)
511+
512+
513+
@_dispatch
514+
def _bernoulligamma(
515+
coder: BernoulliGammaLikelihood,
516+
xz: AggregateInput,
517+
z: Aggregate,
518+
):
519+
ks, scales, logp0s, logpss, ds = zip(
520+
*[_bernoulligamma(coder, xzi, zi) for (xzi, _), zi in zip(xz, z)]
521+
)
522+
523+
# Concatenate into one big distribution.
524+
k = Aggregate(*ks)
525+
scale = Aggregate(*scales)
526+
logp0 = Aggregate(*logp0s)
527+
logps = Aggregate(*logpss)
528+
d = Aggregate(*ds)
529+
530+
return k, scale, logp0, logps, d
531+
532+
533+
@_dispatch
534+
def _bernoulligamma(coder: BernoulliGammaLikelihood, xz, z: B.Numeric):
535+
d = data_dims(xz)
536+
dim_y = B.shape(z, -d - 1) // 4
537+
538+
z_k, z_scale, z_logp0, z_logps = split(z, (dim_y, dim_y, dim_y, dim_y), -d - 1)
539+
540+
# Transform into parameters.
541+
k = 1e-3 + B.softplus(z_k)
542+
scale = 1e-3 + B.softplus(z_scale)
543+
logp0 = z_logp0
544+
logps = z_logps
545+
546+
return k, scale, logp0, logps, d + 1

tests/test_architectures.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def product_kw_args(config, **kw_args):
6666
},
6767
dim_x=[1, 2],
6868
dim_y=[1, 2],
69-
likelihood=["het", "lowrank", "spikes-beta"],
69+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
7070
)
7171
# NP:
7272
+ product_kw_args(
@@ -79,7 +79,7 @@ def product_kw_args(config, **kw_args):
7979
},
8080
dim_x=[1, 2],
8181
dim_y=[1, 2],
82-
likelihood=["het", "lowrank", "spikes-beta"],
82+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
8383
lv_likelihood=["het", "dense"],
8484
)
8585
# ACNP:
@@ -94,7 +94,7 @@ def product_kw_args(config, **kw_args):
9494
},
9595
dim_x=[1, 2],
9696
dim_y=[1, 2],
97-
likelihood=["het", "lowrank", "spikes-beta"],
97+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
9898
)
9999
# ANP:
100100
+ product_kw_args(
@@ -108,7 +108,7 @@ def product_kw_args(config, **kw_args):
108108
},
109109
dim_x=[1, 2],
110110
dim_y=[1, 2],
111-
likelihood=["het", "lowrank", "spikes-beta"],
111+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
112112
lv_likelihood=["het", "dense"],
113113
)
114114
# ConvCNP and ConvGNP:
@@ -122,7 +122,7 @@ def product_kw_args(config, **kw_args):
122122
},
123123
dim_x=[1, 2],
124124
dim_y=[1, 2],
125-
likelihood=["het", "lowrank", "spikes-beta"],
125+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
126126
encoder_scales_learnable=[True, False],
127127
decoder_scale_learnable=[True, False],
128128
)
@@ -138,7 +138,7 @@ def product_kw_args(config, **kw_args):
138138
},
139139
dim_x=[1, 2],
140140
dim_y=[1, 2],
141-
likelihood=["het", "lowrank", "spikes-beta"],
141+
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
142142
lv_likelihood=["het", "lowrank"],
143143
)
144144
)
@@ -219,7 +219,7 @@ def construct_model():
219219

220220
def sample():
221221
if "likelihood" in config:
222-
binary = config["likelihood"] == "spikes-beta"
222+
binary = config["likelihood"] in {"spikes-beta", "bernoulli-gamma"}
223223
else:
224224
binary = False
225225
return generate_data(

0 commit comments

Comments
 (0)