Skip to content

Commit 3f3aeb9

Browse files
authored
Register the overloads added by CustomDist in worker processes (#7241)
1 parent 7c369c8 commit 3f3aeb9

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

pymc/smc/sampling.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535

3636
from pymc.backends.arviz import dict_to_dataset, to_inference_data
3737
from pymc.backends.base import MultiTrace
38+
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
39+
from pymc.distributions.distribution import _support_point
40+
from pymc.logprob.abstract import _icdf, _logcdf, _logprob
3841
from pymc.model import Model, modelcontext
3942
from pymc.sampling.parallel import _cpu_count
4043
from pymc.smc.kernels import IMH
@@ -346,11 +349,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
346349
# main process and our worker functions
347350
_progress = manager.dict()
348351

352+
# check if model contains CustomDistributions defined without dist argument
353+
custom_methods = _find_custom_dist_dispatch_methods(params[3])
354+
349355
# "manually" (de)serialize params before/after multiprocessing
350356
params = tuple(cloudpickle.dumps(p) for p in params)
351357
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
352358

353-
with ProcessPoolExecutor(max_workers=cores) as executor:
359+
with ProcessPoolExecutor(
360+
max_workers=cores,
361+
initializer=_register_custom_methods,
362+
initargs=(custom_methods,),
363+
) as executor:
354364
for c in range(chains): # iterate over the jobs we need to run
355365
# set visible false so we don't have a lot of bars all at once:
356366
task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
@@ -383,3 +393,32 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
383393
)
384394

385395
return tuple(cloudpickle.loads(r.result()) for r in done)
396+
397+
398+
def _find_custom_dist_dispatch_methods(model):
    """Collect pickled dispatch implementations for CustomDist variables.

    Worker processes spawned by ``ProcessPoolExecutor`` do not inherit the
    dispatch registrations created dynamically by ``CustomDist``, so we
    gather them here in the main process for re-registration in the workers.

    Returns a dict mapping each pickled CustomDist op class to a tuple of
    its pickled (logprob, logcdf, icdf, support_point) implementations.
    Each entry is the pickled registry lookup result, which may be a
    pickled ``None`` when no implementation was registered.
    """
    dispatchers = (_logprob, _logcdf, _icdf, _support_point)
    methods = {}
    for rv in model.basic_RVs:
        op = rv.owner.op
        if not isinstance(op, CustomDistRV | CustomSymbolicDistRV):
            continue
        op_class = type(op)
        methods[cloudpickle.dumps(op_class)] = tuple(
            cloudpickle.dumps(dispatcher.registry.get(op_class, None))
            for dispatcher in dispatchers
        )
    return methods
412+
413+
414+
def _register_custom_methods(custom_methods):
    """Re-register CustomDist dispatch implementations in a worker process.

    Runs as the ``ProcessPoolExecutor`` initializer so that logprob/logcdf/
    icdf/support_point dispatching works for dynamically created CustomDist
    op classes inside SMC workers.

    Parameters
    ----------
    custom_methods : dict
        Mapping of pickled op class to a tuple of pickled
        (logprob, logcdf, icdf, support_point) implementations, as produced
        by ``_find_custom_dist_dispatch_methods``. Absent implementations
        arrive as a *pickled* ``None``.
    """
    for cls, (logprob, logcdf, icdf, support_point) in custom_methods.items():
        cls = cloudpickle.loads(cls)
        # BUGFIX: unpickle before the None-check. The stored values are the
        # pickled bytes of the registry lookup, and ``cloudpickle.dumps(None)``
        # is a non-None bytes object — checking the bytes directly would
        # always pass and register ``None`` as the implementation for any
        # method missing from the registry.
        logprob = cloudpickle.loads(logprob)
        if logprob is not None:
            _logprob.register(cls, logprob)
        logcdf = cloudpickle.loads(logcdf)
        if logcdf is not None:
            _logcdf.register(cls, logcdf)
        icdf = cloudpickle.loads(icdf)
        if icdf is not None:
            _icdf.register(cls, icdf)
        support_point = cloudpickle.loads(support_point)
        if support_point is not None:
            _support_point.register(cls, support_point)

tests/smc/test_smc.py

+15
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,21 @@ def test_unobserved_categorical(self):
134134

135135
assert np.all(np.median(trace["mu"], axis=0) == [1, 2])
136136

137+
def test_parallel_custom(self):
    """Smoke-test that SMC multiprocessing handles CustomDist variables.

    Regression test for #7241: dispatch methods created by CustomDist must
    be re-registered inside worker processes, both for the ``dist=`` form
    and the ``random=`` form.
    """

    def quadratic_logp(value, mu):
        return -((value - mu) ** 2)

    def normal_draw(mu, rng=None, size=None):
        return rng.normal(loc=mu, scale=1, size=size)

    def normal_dist(mu, size=None):
        return pm.Normal.dist(mu, 1, size=size)

    with pm.Model():
        mu = pm.CustomDist("mu", 0, logp=quadratic_logp, dist=normal_dist)
        pm.CustomDist("y", mu, logp=quadratic_logp, class_name="", random=normal_draw, observed=[1, 2])
        pm.sample_smc(draws=6, cores=2)
151+
137152
def test_marginal_likelihood(self):
138153
"""
139154
Verifies that the log marginal likelihood function

0 commit comments

Comments
 (0)