diff --git a/preliz/ppls/pymc_io.py b/preliz/ppls/pymc_io.py index 007bf4ab..484e7841 100644 --- a/preliz/ppls/pymc_io.py +++ b/preliz/ppls/pymc_io.py @@ -7,12 +7,15 @@ try: from pytensor.tensor import vector, TensorConstant + from pytensor.graph.basic import ancestors from pymc import logp, compile_pymc from pymc.util import is_transformed_name, get_untransformed_name except ModuleNotFoundError: pass from preliz.internal.optimization import get_distributions +from preliz.distributions import Gamma, Normal, HalfNormal +from preliz.unidimensional.mle import mle def backfitting(prior, p_model, var_info2): @@ -117,15 +120,11 @@ def get_model_information(model): # pylint: disable=too-many-locals pymc_to_preliz = get_pymc_to_preliz() rvs_to_values = model.rvs_to_values - for r_v in model.free_RVs: - if not non_constant_parents(r_v): - free_rvs.append(r_v) - for r_v in model.free_RVs: r_v_eval = r_v.eval() size = r_v_eval.size shape = r_v_eval.shape - nc_parents = non_constant_parents(r_v) + nc_parents = non_constant_parents(r_v, model.free_RVs) name = r_v.owner.op.name dist = pymc_to_preliz[name] @@ -201,11 +200,55 @@ def reshape_params(model, var_info, p_model, params): return value -def non_constant_parents(var_): +def non_constant_parents(var_, free_rvs): """Find the parents of a variable that are not constant.""" parents = [] for variable in var_.get_parents()[0].inputs[2:]: if not isinstance(variable, TensorConstant): - parents.append(variable.owner.inputs[0]) - + for free_rv in free_rvs: + if free_rv in list(ancestors([variable])) and free_rv not in parents: + parents.append(free_rv) return parents + + +def posterior_to_prior(model, posterior, alternative=None): + """ + Updates the priors of a probabilistic model by fitting them to posterior data, using either predefined or + user-specified alternative distributions. It selects the best-fitting distribution for each variable based + on maximum likelihood estimation (MLE). The result is a model with priors better aligned to the observed data. + + Parameters + ---------- + model : A PyMC model + A probabilistic model + + posterior : Posterior samples + InferenceData from with the posterior group + + alternative : "auto", list, dict, defaults to None + Users can add the model variables to consider alternative distributions while fitting samples + + """ + + model_info = get_model_information(model)[2] + parsed_info = [(dist, var) for var, dist in model_info.items()] + new_priors = [] + + for dist, var in parsed_info: + dists = [model_info[var]] + + if alternative == "auto": + dists += [Normal(), HalfNormal(), Gamma()] + elif isinstance(alternative, list): + dists += alternative + elif isinstance(alternative, dict): + dists += alternative.get(var, []) + if len(dists) == 1: + dists[0]._fit_mle(posterior[var].values) + new_priors.append((dists[0], var)) + else: + idx = mle(dists, posterior[var].values, plot=False)[0] + new_priors.append((dists[idx[0]], var)) + + new_model = "\n".join(f"{var} = {new_prior}" for new_prior, var in new_priors) + return new_model