Skip to content

Commit

Permalink
Let the user supply their own custom BNN model to AutoBNN*Estimator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600520002
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Jan 22, 2024
1 parent 17cc66c commit 64a70d0
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 22 deletions.
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/autobnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ py_test(
srcs = ["estimators_test.py"],
deps = [
":estimators",
":kernels",
":operators",
"//tensorflow_probability/python/internal:test_util",
],
)
Expand Down
45 changes: 34 additions & 11 deletions tensorflow_probability/python/experimental/autobnn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ============================================================================
"""Estimator classes for training BNN models using Bayeux."""

from typing import Any, Mapping, Optional, Sequence
from typing import Any, Mapping, Optional, Sequence, Union

import jax
import jax.numpy as jnp
Expand All @@ -33,14 +33,14 @@ class _AutoBnnEstimator:

def __init__(
self,
model_name: str,
model_or_name: Union[str, bnn.BNN],
likelihood_model: str,
seed: jax.Array,
width: int = 50,
periods: Sequence[ArrayLike] = (12.0,),
likelihood_kwargs: Optional[Mapping[str, Any]] = None,
):
self.model_name = model_name
self.model_or_name = model_or_name
self.likelihood_model = likelihood_model
self.width = width
self.periods = periods
Expand Down Expand Up @@ -91,7 +91,7 @@ def fit(self, X: jax.Array, y: jax.Array) -> '_AutoBnnEstimator': # pylint: dis
self.likelihood_model, self.likelihood_kwargs
)
self.net_ = models.make_model(
model_name=self.model_name,
model_name=self.model_or_name,
likelihood_model=self.likelihood_,
time_series_xs=X,
width=self.width,
Expand Down Expand Up @@ -172,11 +172,34 @@ def predict_quantiles(


class AutoBnnMapEstimator(_AutoBnnEstimator):
"""Implementation of a MAP estimator for the BNN."""
"""Implementation of a MAP estimator for the BNN.
Example usage:
estimator = estimators.AutoBnnMapEstimator(
model_or_name='linear_plus_periodic',
likelihood_model='normal_likelihood_logistic_noise',
seed=jax.random.PRNGKey(42),
width=25,
num_particles=32,
num_iters=1000,
)
estimator.fit(x_train, y_train)
low, mid, high = estimator.predict_quantiles(x_train)
Or:
estimator = estimators.AutoBnnMapEstimator(
model_or_name=operators.Add(
bnns=(kernels.LinearBNN(width=50),
kernels.PeriodicBNN(width=50, period=12))),
likelihood_model='normal_likelihood_lognormal_noise',
seed=jax.random.PRNGKey(123))
"""

def __init__(
self,
model_name: str,
model_or_name: Union[str, bnn.BNN],
likelihood_model: str,
seed: jax.Array,
width: int = 50,
Expand All @@ -188,7 +211,7 @@ def __init__(
**unused_kwargs,
):
super().__init__(
model_name=model_name,
model_or_name=model_or_name,
likelihood_model=likelihood_model,
seed=seed,
width=width,
Expand All @@ -213,7 +236,7 @@ class AutoBnnMCMCEstimator(_AutoBnnEstimator):

def __init__(
self,
model_name: str,
model_or_name: Union[str, bnn.BNN],
likelihood_model: str,
seed: jax.Array,
width: int = 50,
Expand All @@ -224,7 +247,7 @@ def __init__(
**unused_kwargs,
):
super().__init__(
model_name=model_name,
model_or_name=model_or_name,
likelihood_model=likelihood_model,
seed=seed,
width=width,
Expand All @@ -244,7 +267,7 @@ class AutoBnnVIEstimator(_AutoBnnEstimator):

def __init__(
self,
model_name: str,
model_or_name: Union[str, bnn.BNN],
likelihood_model: str,
seed: jax.Array,
width: int = 50,
Expand All @@ -255,7 +278,7 @@ def __init__(
**unused_kwargs,
):
super().__init__(
model_name=model_name,
model_or_name=model_or_name,
likelihood_model=likelihood_model,
seed=seed,
width=width,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import jax
import numpy as np
from tensorflow_probability.python.experimental.autobnn import estimators
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import test_util

Expand All @@ -28,7 +30,7 @@ def test_train_map(self):
x_train, y_train = util.load_fake_dataset()

autobnn = estimators.AutoBnnMapEstimator(
model_name='linear_plus_periodic',
model_or_name='linear_plus_periodic',
likelihood_model='normal_likelihood_logistic_noise',
seed=seed,
width=5,
Expand All @@ -51,7 +53,7 @@ def test_train_mcmc(self):
x_train, y_train = util.load_fake_dataset()

autobnn = estimators.AutoBnnMCMCEstimator(
model_name='linear_plus_periodic',
model_or_name='linear_plus_periodic',
likelihood_model='normal_likelihood_logistic_noise',
seed=seed,
width=5,
Expand All @@ -68,12 +70,41 @@ def test_train_mcmc(self):

# TODO(colcarroll): Add test for AutoBnnVIEstimator.

def test_custom_model(self):
seed = jax.random.PRNGKey(20231018)
x_train, y_train = util.load_fake_dataset()

model = operators.Add(
bnns=(kernels.PeriodicBNN(width=50, period=12),
kernels.LinearBNN(width=50),
kernels.MaternBNN(width=50)))

autobnn = estimators.AutoBnnMapEstimator(
model_or_name=model,
likelihood_model='normal_likelihood_logistic_noise',
seed=seed,
width=5,
num_particles=8,
num_iters=100,
)
self.assertFalse(autobnn.check_is_fitted())
autobnn.fit(x_train, y_train)
self.assertTrue(autobnn.check_is_fitted())
self.assertEqual(autobnn.diagnostics_['loss'].shape, (8, 100))
lo, mid, hi = autobnn.predict_quantiles(x_train)
np.testing.assert_array_less(lo, mid)
np.testing.assert_array_less(mid, hi)
self.assertEqual(
autobnn.summary(),
'\n'.join(['(Periodic(period=12.00)#Linear#Matern(2.5))'] * 8)
)

def test_summary(self):
seed = jax.random.PRNGKey(20231018)
x_train, y_train = util.load_fake_dataset()

autobnn = estimators.AutoBnnMapEstimator(
model_name='sum_of_stumps',
model_or_name='sum_of_stumps',
likelihood_model='normal_likelihood_logistic_noise',
seed=seed,
width=5,
Expand Down
20 changes: 12 additions & 8 deletions tensorflow_probability/python/experimental/autobnn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
linear components.
"""
import functools
from typing import Sequence
from typing import Sequence, Union
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import bnn_tree
Expand Down Expand Up @@ -292,18 +292,22 @@ def make_multilayer(


def make_model(
model_name: str,
model_name: Union[str, bnn.BNN],
likelihood_model: likelihoods.LikelihoodModel,
time_series_xs: Array,
width: int = 5,
periods: Sequence[float] = (0.1,),
) -> bnn.BNN:
"""Create a BNN model by name."""
m = MODEL_NAME_TO_MAKE_FUNCTION[model_name](
time_series_xs=time_series_xs,
width=width,
periods=periods,
num_outputs=likelihood_model.num_outputs(),
)
if isinstance(model_name, str):
m = MODEL_NAME_TO_MAKE_FUNCTION[model_name](
time_series_xs=time_series_xs,
width=width,
periods=periods,
num_outputs=likelihood_model.num_outputs(),
)
else:
m = model_name

m.set_likelihood_model(likelihood_model)
return m

0 comments on commit 64a70d0

Please sign in to comment.