Skip to content

Commit

Permalink
provide more options roulette
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 23, 2023
1 parent e30b903 commit 0e2bd1c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 45 deletions.
12 changes: 8 additions & 4 deletions preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def func(params, dist, x_vals):
init_vals = np.array(dist.params)[none_idx]
bounds = np.array(dist.params_support)[none_idx]
bounds = list(zip(*bounds))

opt = least_squares(func, x0=init_vals, args=(dist, x_vals), bounds=bounds)
params = get_params(dist, opt["x"], none_idx, fixed)
dist._parametrization(**params)
Expand All @@ -88,7 +89,8 @@ def func(params, dist, x_vals, ecdf):
bounds = list(zip(*bounds))

opt = least_squares(func, x0=init_vals, args=(dist, x_vals, ecdf), bounds=bounds)
dist._update(*opt["x"])
params = get_params(dist, opt["x"], none_idx, fixed)
dist._parametrization(**params)
loss = opt["cost"]
return loss

Expand Down Expand Up @@ -203,7 +205,7 @@ def get_distributions(dist_names):
return dists


def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max):
def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max, extra_pros):
"""
Minimize the difference between the cdf and the ecdf over a grid of values
defined by x_min and x_max
Expand All @@ -212,8 +214,10 @@ def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max):
"""
fitted = Loss(len(selected_distributions))
for dist in selected_distributions:
if dist.__class__.__name__ == "BetaScaled":
update_bounds_beta_scaled(dist, x_min, x_max)
if dist.__class__.__name__ in extra_pros:
dist._parametrization(**extra_pros[dist.__class__.__name__])
if dist.__class__.__name__ == "BetaScaled":
update_bounds_beta_scaled(dist, x_min, x_max)

if dist._check_endpoints(x_min, x_max, raise_error=False):
none_idx, fixed = get_fixed_params(dist)
Expand Down
5 changes: 5 additions & 0 deletions preliz/tests/test_quartile_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from test_helper import run_notebook


def test_roulette():
run_notebook("quartile_int.ipynb")
2 changes: 1 addition & 1 deletion preliz/unidimensional/quartile_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_widgets(q1, q2, q3, dist_names=None):

if dist_names is None:

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal"]
default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
Expand Down
107 changes: 67 additions & 40 deletions preliz/unidimensional/roulette.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
pass
from ..internal.optimization import fit_to_ecdf, get_distributions
from ..internal.plot_helper import check_inside_notebook, representations
from ..internal.distribution_helper import process_extra


def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):
def roulette(x_min=0, x_max=10, nrows=10, ncols=11, dist_names=None, figsize=None):
"""
Prior elicitation for 1D distribution using the roulette method.
Expand All @@ -29,6 +30,10 @@ def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):
Number of rows for the grid. Defaults to 10.
ncols: Optional[int]
Number of columns for the grid. Defaults to 11.
dist_names: list
List of distributions names to be used in the elicitation. If None, almost all 1D
distributions available in PreliZ will be used. Some distributions like Uniform or
Cauchy are omitted by default.
figsize: Optional[Tuple[int, int]]
Figure size. If None it will be defined automatically.
Expand All @@ -44,8 +49,12 @@ def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):

check_inside_notebook(need_widget=True)

w_x_min, w_x_max, w_ncols, w_nrows, w_repr, w_distributions = get_widgets(
x_min, x_max, nrows, ncols
w_x_min, w_x_max, w_ncols, w_nrows, w_extra, w_repr, w_distributions = get_widgets(
x_min,
x_max,
nrows,
ncols,
dist_names,
)

output = widgets.Output()
Expand Down Expand Up @@ -90,6 +99,7 @@ def on_leave_fig_(_):
w_x_min.value,
w_x_max.value,
w_ncols.value,
w_extra.value,
ax_fit,
)

Expand All @@ -113,11 +123,12 @@ def on_value_change(change):
w_x_min.value,
w_x_max.value,
w_ncols.value,
w_extra.value,
ax_fit,
),
)

controls = widgets.VBox([w_x_min, w_x_max, w_nrows, w_ncols])
controls = widgets.VBox([w_x_min, w_x_max, w_nrows, w_ncols, w_extra])

display(widgets.HBox([controls, w_repr, w_distributions])) # pylint:disable=undefined-variable

Expand Down Expand Up @@ -200,11 +211,12 @@ def __call__(self, event):
self.fig.canvas.draw()


def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, ax):
def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, extra, ax):
x_min = float(x_min)
x_max = float(x_max)
ncols = float(ncols)
x_range = x_max - x_min
extra_pros = process_extra(extra)

x_vals, ecdf, mean, std, filled_columns = weights_to_ecdf(grid.weights, x_min, x_range, ncols)

Expand All @@ -222,6 +234,7 @@ def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, ax):
std,
x_min,
x_max,
extra_pros,
)

if fitted_dist is None:
Expand Down Expand Up @@ -280,9 +293,10 @@ def reset_dist_panel(x_min, x_max, ax, yticks):
ax.autoscale_view()


def get_widgets(x_min, x_max, nrows, ncols):
def get_widgets(x_min, x_max, nrows, ncols, dist_names):

width_entry_text = widgets.Layout(width="150px")
width_repr_text = widgets.Layout(width="250px")
width_distribution_text = widgets.Layout(width="150px", height="125px")

w_x_min = widgets.FloatText(
Expand Down Expand Up @@ -319,6 +333,14 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_entry_text,
)

w_extra = widgets.Textarea(
value="",
placeholder="Pass extra parameters",
description="params:",
disabled=False,
layout=width_repr_text,
)

w_repr = widgets.RadioButtons(
options=["pdf", "cdf", "ppf"],
value="pdf",
Expand All @@ -327,39 +349,44 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_entry_text,
)

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
"BetaScaled",
"ChiSquared",
"ExGaussian",
"Exponential",
"Gamma",
"Gumbel",
"HalfNormal",
"HalfStudentT",
"InverseGamma",
"Laplace",
"LogNormal",
"Logistic",
# "LogitNormal", # fails if we add chips at x_value= 1
"Moyal",
"Normal",
"Pareto",
"Rice",
"SkewNormal",
"StudentT",
"Triangular",
"VonMises",
"Wald",
"Weibull",
"BetaBinomial",
"DiscreteWeibull",
"Geometric",
"NegativeBinomial",
"Poisson",
]
if dist_names is None:

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
"BetaScaled",
"ChiSquared",
"ExGaussian",
"Exponential",
"Gamma",
"Gumbel",
"HalfNormal",
"HalfStudentT",
"InverseGamma",
"Laplace",
"LogNormal",
"Logistic",
# "LogitNormal", # fails if we add chips at x_value= 1
"Moyal",
"Normal",
"Pareto",
"Rice",
"SkewNormal",
"StudentT",
"Triangular",
"VonMises",
"Wald",
"Weibull",
"BetaBinomial",
"DiscreteWeibull",
"Geometric",
"NegativeBinomial",
"Poisson",
]

else:
default_dist = dist_names

w_distributions = widgets.SelectMultiple(
options=dist_names,
Expand All @@ -369,4 +396,4 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_distribution_text,
)

return w_x_min, w_x_max, w_ncols, w_nrows, w_repr, w_distributions
return w_x_min, w_x_max, w_ncols, w_nrows, w_extra, w_repr, w_distributions

0 comments on commit 0e2bd1c

Please sign in to comment.