diff --git a/tensorflow_probability/python/experimental/autobnn/BUILD b/tensorflow_probability/python/experimental/autobnn/BUILD index db2c285237..d74d59673f 100644 --- a/tensorflow_probability/python/experimental/autobnn/BUILD +++ b/tensorflow_probability/python/experimental/autobnn/BUILD @@ -99,6 +99,8 @@ py_test( srcs = ["estimators_test.py"], deps = [ ":estimators", + ":kernels", + ":operators", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/experimental/autobnn/estimators.py b/tensorflow_probability/python/experimental/autobnn/estimators.py index 17b2dadce9..d1d91b5637 100644 --- a/tensorflow_probability/python/experimental/autobnn/estimators.py +++ b/tensorflow_probability/python/experimental/autobnn/estimators.py @@ -14,7 +14,7 @@ # ============================================================================ """Estimator classes for training BNN models using Bayeux.""" -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Mapping, Optional, Sequence, Union import jax import jax.numpy as jnp @@ -33,14 +33,14 @@ class _AutoBnnEstimator: def __init__( self, - model_name: str, + model_or_name: Union[str, bnn.BNN], likelihood_model: str, seed: jax.Array, width: int = 50, periods: Sequence[ArrayLike] = (12.0,), likelihood_kwargs: Optional[Mapping[str, Any]] = None, ): - self.model_name = model_name + self.model_or_name = model_or_name self.likelihood_model = likelihood_model self.width = width self.periods = periods @@ -91,7 +91,7 @@ def fit(self, X: jax.Array, y: jax.Array) -> '_AutoBnnEstimator': # pylint: dis self.likelihood_model, self.likelihood_kwargs ) self.net_ = models.make_model( - model_name=self.model_name, + model_name=self.model_or_name, likelihood_model=self.likelihood_, time_series_xs=X, width=self.width, @@ -172,11 +172,34 @@ def predict_quantiles( class AutoBnnMapEstimator(_AutoBnnEstimator): - """Implementation of a MAP estimator for the BNN.""" + """Implementation of a MAP estimator for the BNN. + + Example usage: + + estimator = estimators.AutoBnnMapEstimator( + model_or_name='linear_plus_periodic', + likelihood_model='normal_likelihood_logistic_noise', + seed=jax.random.PRNGKey(42), + width=25, + num_particles=32, + num_iters=1000, + ) + estimator.fit(x_train, y_train) + low, mid, high = estimator.predict_quantiles(x_train) + + Or: + + estimator = estimators.AutoBnnMapEstimator( + model_or_name=operators.Add( + bnns=(kernels.LinearBNN(width=50), + kernels.PeriodicBNN(width=50, period=12))), + likelihood_model='normal_likelihood_lognormal_noise', + seed=jax.random.PRNGKey(123)) + """ def __init__( self, - model_name: str, + model_or_name: Union[str, bnn.BNN], likelihood_model: str, seed: jax.Array, width: int = 50, @@ -188,7 +211,7 @@ def __init__( **unused_kwargs, ): super().__init__( - model_name=model_name, + model_or_name=model_or_name, likelihood_model=likelihood_model, seed=seed, width=width, @@ -213,7 +236,7 @@ class AutoBnnMCMCEstimator(_AutoBnnEstimator): def __init__( self, - model_name: str, + model_or_name: Union[str, bnn.BNN], likelihood_model: str, seed: jax.Array, width: int = 50, @@ -224,7 +247,7 @@ def __init__( **unused_kwargs, ): super().__init__( - model_name=model_name, + model_or_name=model_or_name, likelihood_model=likelihood_model, seed=seed, width=width, @@ -244,7 +267,7 @@ class AutoBnnVIEstimator(_AutoBnnEstimator): def __init__( self, - model_name: str, + model_or_name: Union[str, bnn.BNN], likelihood_model: str, seed: jax.Array, width: int = 50, @@ -255,7 +278,7 @@ def __init__( **unused_kwargs, ): super().__init__( - model_name=model_name, + model_or_name=model_or_name, likelihood_model=likelihood_model, seed=seed, width=width, diff --git a/tensorflow_probability/python/experimental/autobnn/estimators_test.py b/tensorflow_probability/python/experimental/autobnn/estimators_test.py index c03faf4682..c1d9d28ef3 100644 --- a/tensorflow_probability/python/experimental/autobnn/estimators_test.py +++ b/tensorflow_probability/python/experimental/autobnn/estimators_test.py @@ -17,6 +17,8 @@ import jax import numpy as np from tensorflow_probability.python.experimental.autobnn import estimators +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import operators from tensorflow_probability.python.experimental.autobnn import util from tensorflow_probability.python.internal import test_util @@ -28,7 +30,7 @@ def test_train_map(self): x_train, y_train = util.load_fake_dataset() autobnn = estimators.AutoBnnMapEstimator( - model_name='linear_plus_periodic', + model_or_name='linear_plus_periodic', likelihood_model='normal_likelihood_logistic_noise', seed=seed, width=5, @@ -51,7 +53,7 @@ def test_train_mcmc(self): x_train, y_train = util.load_fake_dataset() autobnn = estimators.AutoBnnMCMCEstimator( - model_name='linear_plus_periodic', + model_or_name='linear_plus_periodic', likelihood_model='normal_likelihood_logistic_noise', seed=seed, width=5, @@ -68,12 +70,41 @@ def test_train_mcmc(self): # TODO(colcarroll): Add test for AutoBnnVIEstimator. + def test_custom_model(self): + seed = jax.random.PRNGKey(20231018) + x_train, y_train = util.load_fake_dataset() + + model = operators.Add( + bnns=(kernels.PeriodicBNN(width=50, period=12), + kernels.LinearBNN(width=50), + kernels.MaternBNN(width=50))) + + autobnn = estimators.AutoBnnMapEstimator( + model_or_name=model, + likelihood_model='normal_likelihood_logistic_noise', + seed=seed, + width=5, + num_particles=8, + num_iters=100, + ) + self.assertFalse(autobnn.check_is_fitted()) + autobnn.fit(x_train, y_train) + self.assertTrue(autobnn.check_is_fitted()) + self.assertEqual(autobnn.diagnostics_['loss'].shape, (8, 100)) + lo, mid, hi = autobnn.predict_quantiles(x_train) + np.testing.assert_array_less(lo, mid) + np.testing.assert_array_less(mid, hi) + self.assertEqual( + autobnn.summary(), + '\n'.join(['(Periodic(period=12.00)#Linear#Matern(2.5))'] * 8) + ) + def test_summary(self): seed = jax.random.PRNGKey(20231018) x_train, y_train = util.load_fake_dataset() autobnn = estimators.AutoBnnMapEstimator( - model_name='sum_of_stumps', + model_or_name='sum_of_stumps', likelihood_model='normal_likelihood_logistic_noise', seed=seed, width=5, diff --git a/tensorflow_probability/python/experimental/autobnn/models.py b/tensorflow_probability/python/experimental/autobnn/models.py index 63c7165988..31d32f37bb 100644 --- a/tensorflow_probability/python/experimental/autobnn/models.py +++ b/tensorflow_probability/python/experimental/autobnn/models.py @@ -20,7 +20,7 @@ linear components. """ import functools -from typing import Sequence +from typing import Sequence, Union import jax.numpy as jnp from tensorflow_probability.python.experimental.autobnn import bnn from tensorflow_probability.python.experimental.autobnn import bnn_tree @@ -292,18 +292,22 @@ def make_multilayer( def make_model( - model_name: str, + model_name: Union[str, bnn.BNN], likelihood_model: likelihoods.LikelihoodModel, time_series_xs: Array, width: int = 5, periods: Sequence[float] = (0.1,), ) -> bnn.BNN: """Create a BNN model by name.""" - m = MODEL_NAME_TO_MAKE_FUNCTION[model_name]( - time_series_xs=time_series_xs, - width=width, - periods=periods, - num_outputs=likelihood_model.num_outputs(), - ) + if isinstance(model_name, str): + m = MODEL_NAME_TO_MAKE_FUNCTION[model_name]( + time_series_xs=time_series_xs, + width=width, + periods=periods, + num_outputs=likelihood_model.num_outputs(), + ) + else: + m = model_name + m.set_likelihood_model(likelihood_model) return m