Skip to content

Commit

Permalink
Add to_pymc method to distributions (#523)
Browse files Browse the repository at this point in the history
* add to_pymc method

* add to_pymc method
  • Loading branch information
aloctavodia authored Aug 15, 2024
1 parent b39b18a commit 87b22e0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
31 changes: 31 additions & 0 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Parent classes for all families.
"""
# pylint: disable=no-member
# pylint: disable=import-outside-toplevel
from collections import namedtuple
from copy import copy

Expand Down Expand Up @@ -257,6 +258,36 @@ def hdi(self, mass=0.94, fmt=".2f"):
else:
return None

def to_pymc(self, name=None, **kwargs):
"""
Convert the PreliZ distribution to a PyMC distribution.
name : str
Name of PyMC distribution. Needed if inside Model context
kwargs : PyMC distributions properties
kwargs are used to specify properties such as shape or dims
Returns
-------
PyMC distribution
"""
try:
import pymc.distributions as pm_dists
from pymc.model import Model

model = Model.get_context(error_if_none=False)

if model is None:
return getattr(pm_dists, self.__class__.__name__).dist(**self.params_dict, **kwargs)
else:
return getattr(pm_dists, self.__class__.__name__)(
name, **self.params_dict, **kwargs
)
except ImportError:
pass

return None

def _check_endpoints(self, lower, upper, raise_error=True):
"""
Evaluate if the lower and upper values are in the support of the distribution
Expand Down
10 changes: 10 additions & 0 deletions preliz/tests/test_distributions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=redefined-outer-name

import pytest
from pymc import Model
from numpy.testing import assert_almost_equal
import numpy as np
from test_helper import run_notebook
Expand Down Expand Up @@ -293,3 +294,12 @@ def test_plot_interactive(capsys, a_few_poissons):
captured = capsys.readouterr()
assert "RuntimeError" in captured.out
run_notebook("plot_interactive.ipynb")


def test_to_pymc():
with Model() as model:
Gamma(1, 1).to_pymc("a", shape=(2, 2))

assert model.basic_RVs[0].name == "a"
assert model.basic_RVs[0].ndim == 2
assert Normal(0, 1).to_pymc(shape=2).ndim == 1

0 comments on commit 87b22e0

Please sign in to comment.