Skip to content

Commit

Permalink
Extend "to_pymc" to transformed variables (#544)
Browse files Browse the repository at this point in the history
* Extend to_pymc to transformed variables

* Extend to_pymc to transformed variables
  • Loading branch information
aloctavodia authored Sep 28, 2024
1 parent b6250d3 commit f8ec5cd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
37 changes: 31 additions & 6 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def eti(self, mass=0.94, fmt=".2f"):
if valid_scalar_params(self):
lower_tail, upper_tail = self.ppf([(1 - mass) / 2, 1 - (1 - mass) / 2])
if self.kind == "continuos" and fmt != "none":
print("hi!")
lower_tail = float(f"{lower_tail:{fmt}}")
upper_tail = float(f"{upper_tail:{fmt}}")
elif self.kind == "discrete":
Expand Down Expand Up @@ -271,22 +270,48 @@ def to_pymc(self, name=None, **kwargs):
-------
PyMC distribution
"""
pymc_dist = None

try:
import pymc.distributions as pm_dists
from pymc.model import Model

model = Model.get_context(error_if_none=False)

if self.__class__.__name__ == "Hurdle":
preliz_name = self.__class__.__name__ + self.dist.__class__.__name__
else:
preliz_name = self.__class__.__name__
pymc_class = getattr(pm_dists, preliz_name)

if model is None:
return getattr(pm_dists, self.__class__.__name__).dist(**self.params_dict, **kwargs)
if self.__class__.__name__ in ["Truncated", "Censored"]:
pymc_dist = pymc_class.dist(
self.dist.to_pymc(),
lower=self.params_dict["lower"],
upper=self.params_dict["upper"],
**kwargs,
)
else:
pymc_dist = pymc_class.dist(**self.params_dict, **kwargs)
else:
return getattr(pm_dists, self.__class__.__name__)(
name, **self.params_dict, **kwargs
)
if self.__class__.__name__ in ["Truncated", "Censored"]:
pymc_dist = pymc_class(
name,
getattr(pm_dists, self.dist.__class__.__name__).dist(
**self.dist.params_dict
),
lower=self.params_dict["lower"],
upper=self.params_dict["upper"],
**kwargs,
)
else:
pymc_dist = pymc_class(name, **self.params_dict, **kwargs)

except ImportError:
pass

return None
return pymc_dist

def _check_endpoints(self, lower, upper, raise_error=True):
"""
Expand Down
10 changes: 10 additions & 0 deletions preliz/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
ZeroInflatedPoisson,
Dirichlet,
MvNormal,
Truncated,
Censored,
Hurdle,
)


Expand Down Expand Up @@ -299,7 +302,14 @@ def test_plot_interactive(capsys, a_few_poissons):
def test_to_pymc():
with Model() as model:
Gamma(1, 1).to_pymc("a", shape=(2, 2))
Hurdle(Gamma(1, 1), psi=0.5).to_pymc("b", shape=1)
Truncated(Gamma(1, 1), lower=1).to_pymc("c")

assert model.basic_RVs[0].name == "a"
assert model.basic_RVs[0].ndim == 2
assert model.basic_RVs[1].name == "b"
assert model.basic_RVs[1].ndim == 1
assert model.basic_RVs[2].name == "c"
assert model.basic_RVs[2].ndim == 0
assert Normal(0, 1).to_pymc(shape=2).ndim == 1
assert Censored(Normal(0, 1), lower=0).to_pymc().ndim == 0

0 comments on commit f8ec5cd

Please sign in to comment.