|
7 | 7 | from . import _dispatch
|
8 | 8 | from .aggregate import Aggregate, AggregateInput
|
9 | 9 | from .datadims import data_dims
|
10 |
| -from .dist import Beta, Dirac, MultiOutputNormal, SpikesSlab |
| 10 | +from .dist import Beta, Dirac, Gamma, MultiOutputNormal, SpikesSlab |
11 | 11 | from .parallel import Parallel
|
12 | 12 | from .util import register_module, split, split_dimension
|
13 | 13 |
|
|
16 | 16 | "HeterogeneousGaussianLikelihood",
|
17 | 17 | "LowRankGaussianLikelihood",
|
18 | 18 | "DenseGaussianLikelihood",
|
| 19 | + "SpikesBetaLikelihood", |
| 20 | + "BernoulliGammaLikelihood", |
19 | 21 | ]
|
20 | 22 |
|
21 | 23 |
|
@@ -359,7 +361,7 @@ def _dense_var(coder: DenseGaussianLikelihood, xz, z: B.Numeric):
|
359 | 361 |
|
360 | 362 | @register_module
|
361 | 363 | class SpikesBetaLikelihood(AbstractLikelihood):
|
362 |
| - """Gaussian likelihood with heterogeneous noise. |
| 364 | + """Mixture of a beta distribution, a Dirac delta at zero, and a Dirac delta at one. |
363 | 365 |
|
364 | 366 | Args:
|
365 | 367 | epsilon (float, optional): Tolerance for equality checking. Defaults to `1e-6`.
|
@@ -451,3 +453,94 @@ def _spikesbeta(coder: SpikesBetaLikelihood, xz, z: B.Numeric):
|
451 | 453 | logps = z_logps
|
452 | 454 |
|
453 | 455 | return alpha, beta, logp0, logp1, logps, d + 1
|
| 456 | + |
| 457 | + |
@register_module
class BernoulliGammaLikelihood(AbstractLikelihood):
    """Mixture of a gamma distribution and a Dirac delta at zero.

    Args:
        epsilon (float, optional): Tolerance for equality checking. Defaults
            to `1e-6`.

    Attributes:
        epsilon (float): Tolerance for equality checking.
    """

    @_dispatch
    def __init__(self, epsilon: float = 1e-6):
        self.epsilon = epsilon

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"BernoulliGammaLikelihood(epsilon={self.epsilon!r})"
| 478 | + |
| 479 | + |
@_dispatch
def code(
    coder: BernoulliGammaLikelihood,
    xz,
    z,
    x,
    *,
    dtype_lik=None,
    **kw_args,
):
    """Turn the encoding `z` into a zero-inflated gamma distribution.

    Args:
        coder (:class:`.BernoulliGammaLikelihood`): Coder.
        xz: Inputs of the encoding.
        z: Encoding.
        x: Target inputs.
        dtype_lik (dtype, optional): Data type to cast the parameters of the
            likelihood to.

    Returns:
        tuple: The inputs and the constructed distribution.
    """
    k, scale, logp0, logps, d = _bernoulligamma(coder, xz, z)

    # If a likelihood data type was specified, cast all parameters to it.
    if dtype_lik:
        k, scale, logp0, logps = (
            B.cast(dtype_lik, p) for p in (k, scale, logp0, logps)
        )

    # The only spike in this mixture is the point mass at zero.
    with B.on_device(z):
        spikes = B.stack(B.zero(dtype_lik or B.dtype(z)))

    slab = Gamma(k, scale, d)
    logprobs = B.stack(logp0, logps, axis=-1)
    return xz, SpikesSlab(spikes, slab, logprobs, d, epsilon=coder.epsilon)
| 511 | + |
| 512 | + |
@_dispatch
def _bernoulligamma(
    coder: BernoulliGammaLikelihood,
    xz: AggregateInput,
    z: Aggregate,
):
    # Compute the parameters for every output separately.
    per_output = [_bernoulligamma(coder, xzi, zi) for (xzi, _), zi in zip(xz, z)]
    ks, scales, logp0s, logpss, ds = zip(*per_output)

    # Bundle the per-output parameters into one big distribution.
    return (
        Aggregate(*ks),
        Aggregate(*scales),
        Aggregate(*logp0s),
        Aggregate(*logpss),
        Aggregate(*ds),
    )
| 531 | + |
| 532 | + |
@_dispatch
def _bernoulligamma(coder: BernoulliGammaLikelihood, xz, z: B.Numeric):
    d = data_dims(xz)
    # The channel dimension packs four parameter sets of equal width.
    dim_y = B.shape(z, -d - 1) // 4

    z_k, z_scale, z_logp0, z_logps = split(z, (dim_y,) * 4, -d - 1)

    # Softplus enforces positivity; the small floor aids numerical stability.
    k = 1e-3 + B.softplus(z_k)
    scale = 1e-3 + B.softplus(z_scale)

    return k, scale, z_logp0, z_logps, d + 1
0 commit comments