Skip to content

Commit 3b6e351

Browse files
authored
Probability distributions guide update (#7671)
1 parent 268e13b commit 3b6e351

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

docs/source/guides/Probability_Distributions.rst

+24-13
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,32 @@ A variable requires at least a ``name`` argument, and zero or more model paramet
2929

3030
Probability distributions are all subclasses of ``Distribution``, which in turn has two major subclasses: ``Discrete`` and ``Continuous``. In terms of data types, a ``Continuous`` random variable is given whichever floating point type is defined by ``pytensor.config.floatX``, while ``Discrete`` variables are given ``int16`` types when ``pytensor.config.floatX`` is ``float32``, and ``int64`` otherwise.
3131

32-
All distributions in ``pm.distributions`` will have two important methods: ``random()`` and ``logp()`` with the following signatures:
32+
All distributions in ``pm.distributions`` are associated with two key functions:
33+
34+
1. ``logp(dist, value)`` - Calculates log-probability at given value
35+
2. ``draw(dist, size=...)`` - Generates random samples
36+
37+
For example, with a normal distribution:
3338

3439
::
3540

36-
class SomeDistribution(Continuous):
41+
with pm.Model():
42+
x = pm.Normal('x', mu=0, sigma=1)
43+
44+
# Calculate log-probability
45+
log_prob = pm.logp(x, 0.5)
46+
47+
# Generate samples
48+
samples = pm.draw(x, size=100)
3749

38-
def random(self, point=None, size=None):
39-
...
40-
return random_samples
50+
Custom distributions using ``CustomDist`` should provide logp via the ``dist`` parameter:
51+
52+
::
4153

42-
def logp(self, value):
43-
...
44-
return total_log_prob
54+
def custom_logp(value, mu):
55+
return -0.5 * (value - mu)**2
4556

46-
PyMC expects the ``logp()`` method to return a log-probability evaluated at the passed ``value`` argument. This method is used internally by all of the inference methods to calculate the model log-probability that is used for fitting models. The ``random()`` method is used to simulate values from the variable, and is used internally for posterior predictive checks.
57+
custom_dist = pm.CustomDist('custom', dist=custom_logp, mu=0)
4758

4859

4960
Custom distributions
@@ -58,7 +69,7 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv
5869
f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\
5970
\lambda \exp(-\lambda t), \text{if c=0} \end{array} \right.
6071
61-
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
72+
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``CustomDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
6273

6374
For the exponential survival function, this is:
6475

@@ -67,7 +78,7 @@ For the exponential survival function, this is:
6778
def logp(value, t, lam):
6879
return (value * log(lam) - lam * t).sum()
6980

70-
exp_surv = pm.DensityDist('exp_surv', t, lam, logp=logp, observed=failure)
81+
exp_surv = pm.CustomDist('exp_surv', dist=logp, t=t, lam=lam, observed=failure)
7182

7283
Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument.
7384

@@ -98,10 +109,10 @@ This allows for probabilities to be calculated and random numbers to be drawn.
98109

99110
::
100111

101-
>>> y.logp(4).eval()
112+
>>> pm.logp(y, 4).eval()
102113
array(-1.5843639373779297, dtype=float32)
103114

104-
>>> y.random(size=3)
115+
>>> pm.draw(y, size=3)
105116
array([5, 4, 3])
106117

107118

0 commit comments

Comments
 (0)