|
35 | 35 |
|
36 | 36 | from pymc.backends.arviz import dict_to_dataset, to_inference_data
|
37 | 37 | from pymc.backends.base import MultiTrace
|
| 38 | +from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV |
| 39 | +from pymc.distributions.distribution import _support_point |
| 40 | +from pymc.logprob.abstract import _icdf, _logcdf, _logprob |
38 | 41 | from pymc.model import Model, modelcontext
|
39 | 42 | from pymc.sampling.parallel import _cpu_count
|
40 | 43 | from pymc.smc.kernels import IMH
|
@@ -346,11 +349,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
|
346 | 349 | # main process and our worker functions
|
347 | 350 | _progress = manager.dict()
|
348 | 351 |
|
| 352 | + # check if model contains CustomDistributions defined without dist argument |
| 353 | + custom_methods = _find_custom_dist_dispatch_methods(params[3]) |
| 354 | + |
349 | 355 | # "manually" (de)serialize params before/after multiprocessing
|
350 | 356 | params = tuple(cloudpickle.dumps(p) for p in params)
|
351 | 357 | kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
|
352 | 358 |
|
353 |
| - with ProcessPoolExecutor(max_workers=cores) as executor: |
| 359 | + with ProcessPoolExecutor( |
| 360 | + max_workers=cores, |
| 361 | + initializer=_register_custom_methods, |
| 362 | + initargs=(custom_methods,), |
| 363 | + ) as executor: |
354 | 364 | for c in range(chains): # iterate over the jobs we need to run
|
355 | 365 | # set visible false so we don't have a lot of bars all at once:
|
356 | 366 | task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
|
@@ -383,3 +393,32 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
|
383 | 393 | )
|
384 | 394 |
|
385 | 395 | return tuple(cloudpickle.loads(r.result()) for r in done)
|
| 396 | + |
| 397 | + |
def _find_custom_dist_dispatch_methods(model):
    """Collect pickled dispatch methods for CustomDist variables in ``model``.

    CustomDist variables created without a ``dist`` argument register their
    ``logprob``/``logcdf``/``icdf``/``support_point`` implementations on
    dynamically created RV classes. Those registrations exist only in the
    main process, so they are pickled here and re-registered inside each
    worker process (see ``_register_custom_methods``).

    Parameters
    ----------
    model : Model
        Model whose ``basic_RVs`` are scanned for CustomDist ops.

    Returns
    -------
    dict
        Maps each cloudpickled RV class to a 4-tuple of its cloudpickled
        ``(logprob, logcdf, icdf, support_point)`` methods. Methods missing
        from a dispatch registry are stored as plain ``None`` — NOT pickled —
        so consumers can skip them with a simple ``is not None`` check.
        (Pickling ``None`` would yield a non-``None`` bytes object, defeating
        that check and registering ``None`` in the worker.)
    """
    custom_methods = {}
    for rv in model.basic_RVs:
        rv_type = rv.owner.op
        if not isinstance(rv_type, CustomDistRV | CustomSymbolicDistRV):
            continue
        cls = type(rv_type)
        custom_methods[cloudpickle.dumps(cls)] = tuple(
            cloudpickle.dumps(method) if method is not None else None
            for method in (
                _logprob.registry.get(cls, None),
                _logcdf.registry.get(cls, None),
                _icdf.registry.get(cls, None),
                _support_point.registry.get(cls, None),
            )
        )

    return custom_methods
| 412 | + |
| 413 | + |
def _register_custom_methods(custom_methods):
    """Register pickled CustomDist dispatch methods in the current process.

    Intended as a ``ProcessPoolExecutor`` initializer: each worker process
    unpickles the RV classes and their ``logprob``/``logcdf``/``icdf``/
    ``support_point`` implementations collected in the main process by
    ``_find_custom_dist_dispatch_methods`` and registers them with the
    corresponding dispatchers, so CustomDist models can be evaluated there.

    Parameters
    ----------
    custom_methods : dict
        Maps a cloudpickled RV class to a 4-tuple of its cloudpickled
        ``(logprob, logcdf, icdf, support_point)`` methods.
    """
    # Dispatchers in the same order as the method tuples built by
    # _find_custom_dist_dispatch_methods.
    dispatchers = (_logprob, _logcdf, _icdf, _support_point)
    for pickled_cls, pickled_methods in custom_methods.items():
        rv_cls = cloudpickle.loads(pickled_cls)
        for dispatcher, pickled_method in zip(dispatchers, pickled_methods):
            if pickled_method is not None:
                dispatcher.register(rv_cls, cloudpickle.loads(pickled_method))
0 commit comments