Skip to content

Commit 205c193

Browse files
committed
Make more distributions symbolic so they work in different backends
1 parent 630bbaf commit 205c193

File tree

8 files changed

+216
-189
lines changed

8 files changed

+216
-189
lines changed

.github/workflows/tests.yml

+11-7
Original file line numberDiff line numberDiff line change
@@ -281,16 +281,20 @@ jobs:
281281
name: ${{ matrix.os }} ${{ matrix.floatx }}
282282
fail_ci_if_error: false
283283

284-
external_samplers:
284+
alternative_backends:
285285
needs: changes
286286
if: ${{ needs.changes.outputs.changes == 'true' }}
287287
strategy:
288288
matrix:
289289
os: [ubuntu-20.04]
290290
floatx: [float64]
291-
python-version: ["3.13"]
291+
python-version: ["3.12"]
292292
test-subset:
293-
- tests/sampling/test_jax.py tests/sampling/test_mcmc_external.py
293+
- |
294+
tests/distributions/test_random_alternative_backends.py
295+
tests/sampling/test_jax.py
296+
tests/sampling/test_mcmc_external.py
297+
294298
fail-fast: false
295299
runs-on: ${{ matrix.os }}
296300
env:
@@ -305,7 +309,7 @@ jobs:
305309
persist-credentials: false
306310
- uses: mamba-org/setup-micromamba@v2
307311
with:
308-
environment-file: conda-envs/environment-jax.yml
312+
environment-file: conda-envs/environment-alternative-backends.yml
309313
create-args: >-
310314
python=${{matrix.python-version}}
311315
environment-name: pymc-test
@@ -324,7 +328,7 @@ jobs:
324328
with:
325329
token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
326330
env_vars: TEST_SUBSET
327-
name: JAX tests - ${{ matrix.os }} ${{ matrix.floatx }}
331+
name: Alternative backend tests - ${{ matrix.os }} ${{ matrix.floatx }}
328332
fail_ci_if_error: false
329333

330334
float32:
@@ -378,13 +382,13 @@ jobs:
378382
all_tests:
379383
if: ${{ always() }}
380384
runs-on: ubuntu-latest
381-
needs: [ changes, ubuntu, windows, macos, external_samplers, float32 ]
385+
needs: [ changes, ubuntu, windows, macos, alternative_backends, float32 ]
382386
steps:
383387
- name: Check build matrix status
384388
if: ${{ needs.changes.outputs.changes == 'true' &&
385389
( needs.ubuntu.result != 'success' ||
386390
needs.windows.result != 'success' ||
387391
needs.macos.result != 'success' ||
388-
needs.external_samplers.result != 'success' ||
392+
needs.alternative_backends.result != 'success' ||
389393
needs.float32.result != 'success' ) }}
390394
run: exit 1

conda-envs/environment-jax.yml conda-envs/environment-alternative-backends.yml

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ dependencies:
1010
- cachetools>=4.2.1
1111
- cloudpickle
1212
- zarr>=2.5.0,<3
13+
- numba
14+
- nutpie >= 0.13.4
1315
# Jaxlib version must not be greater than jax version!
1416
- blackjax>=1.2.2
1517
- jax>=0.4.28

pymc/distributions/continuous.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -2595,23 +2595,27 @@ def dist(cls, nu, **kwargs):
25952595
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)
25962596

25972597

2598-
class WeibullBetaRV(RandomVariable):
2598+
class WeibullBetaRV(SymbolicRandomVariable):
25992599
name = "weibull"
2600-
signature = "(),()->()"
2601-
dtype = "floatX"
2600+
extended_signature = "[rng],[size],(),()->[rng],()"
26022601
_print_name = ("Weibull", "\\operatorname{Weibull}")
26032602

2604-
def __call__(self, alpha, beta, size=None, **kwargs):
2605-
return super().__call__(alpha, beta, size=size, **kwargs)
2606-
26072603
@classmethod
2608-
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
2609-
if size is None:
2610-
size = np.broadcast_shapes(alpha.shape, beta.shape)
2611-
return np.asarray(beta * rng.weibull(alpha, size=size))
2604+
def rv_op(cls, alpha, beta, *, rng=None, size=None) -> np.ndarray:
2605+
alpha = pt.as_tensor(alpha)
2606+
beta = pt.as_tensor(beta)
2607+
rng = normalize_rng_param(rng)
2608+
size = normalize_size_param(size)
26122609

2610+
if rv_size_is_none(size):
2611+
size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params)
26132612

2614-
weibull_beta = WeibullBetaRV()
2613+
next_rng, raw_weibull = pt.random.weibull(alpha, size=size, rng=rng).owner.outputs
2614+
draws = beta * raw_weibull
2615+
return cls(
2616+
inputs=[rng, size, alpha, beta],
2617+
outputs=[next_rng, draws],
2618+
)(rng, size, alpha, beta)
26152619

26162620

26172621
class Weibull(PositiveContinuous):
@@ -2660,7 +2664,8 @@ class Weibull(PositiveContinuous):
26602664
Scale parameter (beta > 0).
26612665
"""
26622666

2663-
rv_op = weibull_beta
2667+
rv_type = WeibullBetaRV
2668+
rv_op = WeibullBetaRV.rv_op
26642669

26652670
@classmethod
26662671
def dist(cls, alpha, beta, *args, **kwargs):

0 commit comments

Comments
 (0)