Annotate potentially expensive methods with @jax.named_call to enable identification of hotspots.

PiperOrigin-RevId: 604690173
ThomasColthurst authored and tensorflower-gardener committed Feb 6, 2024
1 parent c4e48fb commit 4b44fd8
Showing 6 changed files with 21 additions and 0 deletions.
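For context, a minimal sketch (not part of this commit) of the two annotation styles the diff below uses: the bare decorator and the functools.partial form with an explicit name. The function names, shapes, and the 'Demo::forward' label are made up for illustration; only jax.named_call and functools.partial are the real APIs involved.

import functools
import jax
import jax.numpy as jnp

# Bare decorator: the span is labeled with the function's own name.
@jax.named_call
def expensive_matmul(x, w):
  return jnp.tanh(x @ w)

# Named variant, as used in kernels.py and operators.py below, to label
# methods that would otherwise share a name (e.g. __call__) across classes.
@functools.partial(jax.named_call, name='Demo::forward')
def forward(x, w):
  return expensive_matmul(x, w) + 1.0

x = jnp.ones((8, 16))
w = jnp.ones((16, 4))
print(jax.jit(forward)(x, w).shape)  # (8, 4)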
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/autobnn/bnn.py
@@ -18,12 +18,14 @@

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib


@jax.named_call
def log_prior_of_parameters(params, distributions) -> Float:
"""Return the prior of the parameters according to the distributions."""
if 'params' in params:
@@ -130,6 +130,7 @@ def summary(self) -> str:
summaries = [self.net_.summarize(p) for p in params_per_particle]
return '\n'.join(summaries)

@jax.named_call
def predict_quantiles(
self, X: jax.Array, q=(2.5, 50.0, 97.5), axis: tuple[int, ...] = (0,) # pylint: disable=invalid-name
) -> jax.Array:
5 changes: 5 additions & 0 deletions tensorflow_probability/python/experimental/autobnn/kernels.py
@@ -14,6 +14,7 @@
# ============================================================================
"""`Leaf` BNNs, most of which correspond to some known GP kernel."""

import functools
from flax import linen as nn
from flax.linen import initializers
import jax
@@ -104,10 +105,12 @@ def distributions(self):
}
return super().distributions() | d

@functools.partial(jax.named_call, name='OneLayer::penultimate')
def penultimate(self, inputs):
y = self.input_warping(inputs)
return self.activation_function(self.dense1(y))

@functools.partial(jax.named_call, name='OneLayer::__call__')
def __call__(self, inputs, deterministic=True):
return self.dense2(self.penultimate(inputs))

@@ -142,6 +145,7 @@ def distributions(self):
},
}

@functools.partial(jax.named_call, name='RBF::__call__')
def __call__(self, inputs, deterministic=True):
return self.amplitude * self.dense2(self.penultimate(inputs))

@@ -214,6 +218,7 @@ def bias_init(seed, shape, dtype=jnp.float32):
for _ in range(self.degree)]
super().setup()

@functools.partial(jax.named_call, name='Polynomial::penultimate')
def penultimate(self, inputs):
x = inputs - self.shift
ys = jnp.stack([h(x) for h in self.hiddens], axis=-1)
@@ -49,6 +49,7 @@ def distributions(self):
"""Like BayesianModule::distributions but for the model's parameters."""
return {}

@jax.named_call
def log_likelihood(
self, params, nn_out: jax.Array, observations: jax.Array
) -> jax.Array:
@@ -14,8 +14,10 @@
# ============================================================================
"""Flax.linen modules for combining BNNs."""

import functools
from typing import Optional
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import likelihoods
@@ -54,6 +56,7 @@ def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel):
for b in self.bnns:
b.set_likelihood_model(dummy_ll_model)

@jax.named_call
def log_prior(self, params):
if 'params' in params:
params = params['params']
@@ -114,6 +117,7 @@ def penultimate(self, inputs):
class Add(MultipliableBnnOperator):
"""Add two or more BNNs."""

@functools.partial(jax.named_call, name='Add::penultimate')
def penultimate(self, inputs):
penultimates = [b.penultimate(inputs) for b in self.bnns]
return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1)
@@ -147,13 +151,15 @@ def distributions(self):
'bnn_weights': dirichlet_lib.Dirichlet(concentration=concentration)
}

@functools.partial(jax.named_call, name='WeightedSum::penultimate')
def penultimate(self, inputs):
penultimates = [
b.penultimate(inputs) * self.bnn_weights[0, i]
for i, b in enumerate(self.bnns)
]
return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1)

@functools.partial(jax.named_call, name='WeightedSum::__call__')
def __call__(self, inputs, deterministic=True):
return jnp.sum(
jnp.stack(
@@ -217,6 +223,7 @@ def distributions(self):
}
}

@functools.partial(jax.named_call, name='Multiply::__call__')
def __call__(self, inputs, deterministic=True):
penultimates = [b.penultimate(inputs) for b in self.bnns]
return self.dense(jnp.prod(jnp.stack(penultimates, axis=-1), axis=-1))
@@ -235,6 +242,7 @@ def setup(self):
assert len(self.bnns) == 2
super().setup()

@jax.named_call
def __call__(self, inputs, deterministic=True):
time = inputs[..., self.change_index, jnp.newaxis]
y = (time - self.change_point) / self.slope
@@ -274,6 +282,7 @@ def setup(self):
assert len(self.time_series_xs) >= 2
super().setup()

@functools.partial(jax.named_call, name='LearnableChangePoint::__call__')
def __call__(self, inputs, deterministic=True):
time = inputs[..., self.change_index, jnp.newaxis]
y = (time - self.change_point) / self.change_slope
@@ -75,6 +75,7 @@ def log_density(params, *, seed=None):
inverse_log_det_jacobian=ildj)


@jax.named_call
def fit_bnn_map(
net: bnn.BNN,
seed: jax.Array,
@@ -125,6 +126,7 @@ def _filter_stuck_chains(params):
return jax.tree_map(lambda x: x[best_two], params)


@jax.named_call
def fit_bnn_vi(
net: bnn.BNN,
seed: jax.Array,
@@ -148,6 +150,7 @@ def fit_bnn_vi(
return params, {'loss': loss}


@jax.named_call
def fit_bnn_mcmc(
net: bnn.BNN,
seed: jax.Array,
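A hedged sketch of how annotated entry points like the fit_bnn_* functions above might then be profiled; the log directory and toy workload are illustrative, not from this commit. Named calls appear as labeled spans in the captured trace (viewable in TensorBoard or Perfetto), which is what makes the hotspots identifiable.

import jax
import jax.numpy as jnp

@jax.named_call
def candidate_hotspot(x):
  return jnp.sin(x) @ x.T

with jax.profiler.trace('/tmp/jax-trace'):
  out = jax.jit(candidate_hotspot)(jnp.ones((1024, 1024)))
  out.block_until_ready()  # ensure the work finishes inside the trace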
