diff --git a/bgflow/nn/flow/inverted.py b/bgflow/nn/flow/inverted.py index c73a6463..f22997db 100644 --- a/bgflow/nn/flow/inverted.py +++ b/bgflow/nn/flow/inverted.py @@ -1,4 +1,3 @@ - from .base import Flow __all__ = ["InverseFlow"] @@ -12,12 +11,13 @@ class InverseFlow(Flow): delegate : Flow The flow to invert. """ + def __init__(self, delegate): super().__init__() self._delegate = delegate def _forward(self, *xs, **kwargs): - return self._delegate._inverse(*xs, **kwargs) - + return self._delegate(*xs, inverse=True, **kwargs) + def _inverse(self, *xs, **kwargs): - return self._delegate._forward(*xs, **kwargs) \ No newline at end of file + return self._delegate(*xs, inverse=False, **kwargs) diff --git a/bgflow/nn/flow/transformer/spline.py b/bgflow/nn/flow/transformer/spline.py index c5606ee0..ea7d4d7c 100644 --- a/bgflow/nn/flow/transformer/spline.py +++ b/bgflow/nn/flow/transformer/spline.py @@ -1,12 +1,19 @@ import torch +from typing import NamedTuple from .base import Transformer __all__ = [ "ConditionalSplineTransformer", + "DomainExtension", ] +class DomainExtension(NamedTuple): + tails: str = "linear" + tail_bound: float = 1.0 + + class ConditionalSplineTransformer(Transformer): def __init__( self, @@ -16,6 +23,7 @@ def __init__( right: float = 1.0, bottom: float = 0.0, top: float = 1.0, + domain_extension: DomainExtension = None, ): """ Spline transformer transforming variables in [left, right) into variables in [bottom, top). @@ -61,6 +69,7 @@ def __init__( self._right = right self._bottom = bottom self._top = top + self._domain_extension = domain_extension def _compute_params(self, x, y_dim): """Compute widths, heights, and slopes from x through the params_net. @@ -90,8 +99,13 @@ def _compute_params(self, x, y_dim): n_bins = params.shape[-1] // (y_dim * 3) widths, heights, slopes, noncircular_slopes = torch.split( params, - [n_bins * y_dim, n_bins * y_dim, n_bins * y_dim, self._n_noncircular(y_dim)], - dim=-1 + [ + n_bins * y_dim, + n_bins * y_dim, + n_bins * y_dim, + self._n_noncircular(y_dim), + ], + dim=-1, ) widths = widths.reshape(*batch_shape, y_dim, n_bins) heights = heights.reshape(*batch_shape, y_dim, n_bins) @@ -103,38 +117,38 @@ def _compute_params(self, x, y_dim): slopes[..., self._noncircular_indices(y_dim), -1] = noncircular_slopes return widths, heights, slopes - def _forward(self, x, y, *args, **kwargs): - from nflows.transforms.splines import rational_quadratic_spline + def forward(self, x, y, *args, inverse=False, **ignored_kwargs): + from nflows.transforms.splines import ( + rational_quadratic_spline, + unconstrained_rational_quadratic_spline, + ) widths, heights, slopes = self._compute_params(x, y.shape[-1]) - z, dlogp = rational_quadratic_spline( - y, - widths, - heights, - slopes, - inverse=True, - left=self._left, - right=self._right, - top=self._top, - bottom=self._bottom, - ) - return z, dlogp.sum(dim=-1, keepdim=True) - def _inverse(self, x, y, *args, **kwargs): - from nflows.transforms.splines import rational_quadratic_spline + kwargs = { + "unnormalized_widths": widths, + "unnormalized_heights": heights, + "unnormalized_derivatives": slopes, + "inverse": inverse, + } + + if self._domain_extension is None: + kwargs = { + **kwargs, + "left": self._left, + "right": self._right, + "top": self._top, + "bottom": self._bottom, + } + z, dlogp = rational_quadratic_spline(y, **kwargs) + else: + kwargs = { + **kwargs, + "tails": self._domain_extension.tails, + "tail_bound": self._domain_extension.tail_bound, + } + z, dlogp = unconstrained_rational_quadratic_spline(y, **kwargs) - widths, heights, slopes = self._compute_params(x, y.shape[-1]) - z, dlogp = rational_quadratic_spline( - y, - widths, - heights, - slopes, - inverse=False, - left=self._left, - right=self._right, - top=self._top, - bottom=self._bottom, - ) return z, dlogp.sum(dim=-1, keepdim=True) def _n_noncircular(self, y_dim):