Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Bisect-k #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions bgflow/nn/flow/transformer/jax_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

try:
import jax
import jax.numpy as jnp
from jax import numpy as jnp, lax, vmap
import jax.dlpack
except ImportError:
jax = None
Expand Down Expand Up @@ -67,6 +67,24 @@ def _body(_, left_right):

return _inverted

def bisect_k(bijector, left_bound, right_bound, k=2, eps=1e-6):
    """Build a batched k-section root finder for `bijector`.

    Every iteration splits the current bracket into `k` equally wide bins,
    evaluates `bijector` at the `k - 1` interior grid points and keeps the
    bin in which `bijector(x) < target` flips from true to false.
    `k=2` is ordinary bisection.

    Parameters
    ----------
    bijector : callable
        Function applied (through `vmap`) to the interior grid points.
        The bin selection relies on `bijector(x) < target` being true
        towards `left_bound` and false towards `right_bound`.
    left_bound, right_bound
        Ends of the initial search bracket.
    k : int
        Number of bins per iteration; must be at least 2.
    eps : float
        Factor by which the initial bracket width is shrunk; the loop runs
        `ceil(-log(eps) / log(k))` iterations. Must be positive.

    Returns
    -------
    callable
        `_inverted(target)`, mapped with `vmap` over the leading axis of
        `target`, returning the left edge of the final bracket.

    Raises
    ------
    ValueError
        If `k < 2` or `eps <= 0`.
    """
    # Local import so this block does not depend on the file-level imports.
    import math

    # k == 1 never shrinks the bracket while log(k) == 0 makes the iteration
    # count infinite; k == 0 silently runs zero iterations. Reject both.
    if k < 2:
        raise ValueError(f"bisect_k requires k >= 2 bins, got k={k}")
    if eps <= 0:
        raise ValueError(f"bisect_k requires eps > 0, got eps={eps}")

    # Depends only on Python constants, so compute it once, outside any trace.
    n_iters = max(0, math.ceil(-math.log(eps) / math.log(k)))

    @vmap
    def _inverted(target):
        def _narrow(_, bracket):
            lo, hi = bracket
            grid = jnp.linspace(lo, hi, k + 1)
            # The outcome at both ends of the bracket is known by
            # construction, so only the interior points are evaluated.
            below = vmap(bijector)(grid[1:-1]) < target
            flags = jnp.concatenate(
                [jnp.array([True]), below, jnp.array([False])]
            )
            # First bin whose left edge is "below" and right edge is not.
            first_bin = jnp.logical_and(flags[:-1], ~flags[1:]).argmax()
            lo, hi = lax.dynamic_slice(grid, (first_bin,), (2,))
            return lo, hi

        return lax.fori_loop(0, n_iters, _narrow, (left_bound, right_bound))[0]

    return _inverted

def invert_bijector(bijector, root_finder):
"""Inverts a bijector with a root finder
Expand Down Expand Up @@ -133,12 +151,12 @@ def _call(x, *params):
return _call


def bijector_with_approx_inverse(bijector, domain=None, eps=1e-8):
def bijector_with_approx_inverse(bijector, domain=None, eps=1e-8, k=2):
"""Wraps bijector with approximate inverse."""
if domain is None:
domain = (0, 1)
root_finder = functools.partial(
bisect,
bisect if k==2 else functools.partial(bisect_k, k=k),
left_bound=domain[0],
right_bound=domain[1],
eps=eps)
Expand Down Expand Up @@ -234,11 +252,11 @@ def nested_vmap(fn, indices):
return fn


def jax_compile(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8):
def jax_compile(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8, k=2):
"""Wraps simple JAX bijector into a transformer,
that can be used within the bgflow eco-system.

The bijector is first passed through `nested_vmap` with `vmap_indices`,
then paired with an approximate inverse by `bijector_with_approx_inverse`
(which receives `domain`, `bisection_eps` and, in the new signature, `k`).
Both resulting functions are passed through `compile_bijector` and
returned as a `(fwd, bwd)` tuple.
"""
# NOTE(review): `backend` is accepted but not forwarded to `jax.jit` in the
# lines visible here -- confirm whether that is intentional.
compile_bijector = compose(functools.partial(jax.jit))
fwd, bwd = bijector_with_approx_inverse(nested_vmap(bijector, vmap_indices), domain, bisection_eps)
fwd, bwd = bijector_with_approx_inverse(nested_vmap(bijector, vmap_indices), domain, bisection_eps, k=k)
return tuple(map(compile_bijector, (fwd, bwd)))


Expand All @@ -249,14 +267,14 @@ def torch_to_jax_backend(backend):
return backend


def to_torch_impl_(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8):
def to_torch_impl_(bijector, vmap_indices, backend, domain=None, bisection_eps=1e-8, k=2):
"""Helper impl function that can be cached according
to `vmap_indices` and `backend`.

Calls `jax_compile` with the given arguments and passes both returned
functions through `wrap_jax_fun`, returning them as a `(fwd, bwd)` tuple.
"""
fwd, bwd = jax_compile(bijector, vmap_indices, backend, domain, bisection_eps)
fwd, bwd = jax_compile(bijector, vmap_indices, backend, domain, bisection_eps, k=k)
return tuple(map(wrap_jax_fun, (fwd, bwd)))


def to_torch(bijector, vmap_indices=None, domain=None, bisection_eps=1e-8):
def to_torch(bijector, vmap_indices=None, domain=None, bisection_eps=1e-8, k=2):
"""Converts a simple JAX bijector into a torch bijector with
- numerical inverses
- automatic computation of log det jac
Expand All @@ -270,7 +288,7 @@ def _cached(x):
if indices is None:
indices = tuple(range(len(x.shape)))
backend = torch_to_jax_backend(x.device.type)
return cached_compile(indices, backend, domain, bisection_eps)
return cached_compile(indices, backend, domain, bisection_eps, k=k)

def _fwd(x, *params):
assert_float32(x)
Expand All @@ -294,10 +312,10 @@ class JaxTransformer(Transformer):
bijector."""

def __init__(self, bijector, compute_params, reduce_jacobian=True,
domain=None, bisection_eps=1e-8):
domain=None, bisection_eps=1e-8, k=2):
super().__init__()
self._compute_params = compute_params
fwd, bwd = to_torch(bijector)
fwd, bwd = to_torch(bijector, domain=domain, bisection_eps=bisection_eps, k=k)
self.fwd = fwd
self.bwd = bwd
self.reduce_jacobian = reduce_jacobian
Expand Down