Volume-Preserving Flows #39

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion bgflow/bg.py
@@ -97,7 +97,7 @@ def sample(
results = list(x)

if with_latent:
- results.append(*z)
+ results.append(z)
if with_dlogp:
results.append(dlogp)
if with_energy:
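A note on the one-line fix above: `list.append` takes exactly one argument, so `results.append(*z)` only ever worked when the latent `z` held a single tensor. A minimal sketch of the difference (plain Python, illustration only; the names are stand-ins):

z = ("z_bonds", "z_angles")   # stand-ins for a multi-tensor latent
results = []

# results.append(*z)   # TypeError: append() takes exactly one argument (2 given)
results.append(z)      # appends the whole tuple as a single entry

assert results == [("z_bonds", "z_angles")]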
5 changes: 3 additions & 2 deletions bgflow/factory/conditioner_factory.py
@@ -1,4 +1,5 @@

from typing import Mapping
import torch
import bgflow as bg
from ..nn.periodic import WrapPeriodic
@@ -17,7 +18,7 @@ def make_conditioners(
transformer_kwargs={},
conditioner_type="dense",
**kwargs
- ):
+ ) -> Mapping[str, torch.nn.Module]:
"""Create coupling layer conditioners for a given transformer type,
taking care of circular and non-circular tensors.

@@ -43,7 +44,7 @@

Returns
-------
- transformer : bg.Transformer
+ conditioners : Mapping[str, torch.nn.Module]
"""
net_factory = CONDITIONER_FACTORIES[conditioner_type]
dim_out_factory = CONDITIONER_OUT_DIMS[transformer_type]
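For context on how the sharpened return annotation is consumed: `generator_builder.py` below passes the returned mapping straight into `VolumePreservingWrapFlow(..., **affine_conditioners)`, so for an `AffineTransformer` the keys are expected to line up with that constructor's `shift_transformation` and `scale_transformation` parameters. A hedged sketch of that contract (inferred from this diff, not verified against the full source; the top-level imports are assumptions):

from bgflow import ShapeDictionary, BONDS, ANGLES
from bgflow.factory.conditioner_factory import make_conditioners
from bgflow.nn.flow.transformer.affine import AffineTransformer

shape_info = ShapeDictionary()
shape_info[BONDS] = (10,)
shape_info[ANGLES] = (20,)

conditioners = make_conditioners(
    transformer_type=AffineTransformer,
    what=(BONDS,),
    on=(ANGLES,),
    shape_info=shape_info,
)
# Inferred: the keys match the keyword arguments consumed downstream.
assert set(conditioners) <= {"shift_transformation", "scale_transformation"}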
96 changes: 94 additions & 2 deletions bgflow/factory/generator_builder.py
@@ -1,5 +1,7 @@
"""High-level Builder API for Boltzmann generators."""

import contextlib
import copy
import warnings
from typing import Mapping, Sequence

@@ -9,7 +11,8 @@
from ..nn.flow.sequential import SequentialFlow
from ..nn.flow.coupling import SetConstantFlow
from ..nn.flow.transformer.spline import ConditionalSplineTransformer
- from ..nn.flow.coupling import CouplingFlow, SplitFlow, WrapFlow, MergeFlow
+ from ..nn.flow.transformer.affine import AffineTransformer
+ from ..nn.flow.coupling import CouplingFlow, SplitFlow, WrapFlow, MergeFlow, VolumePreservingWrapFlow
from ..nn.flow.crd_transform.ic import GlobalInternalCoordinateTransformation
from ..nn.flow.inverted import InverseFlow
from ..nn.flow.cdf import CDFTransform
@@ -20,7 +23,8 @@
from ..distribution.product import ProductDistribution, ProductEnergy
from ..bg import BoltzmannGenerator
from .tensor_info import (
- TensorInfo, BONDS, ANGLES, TORSIONS, FIXED, ORIGIN, ROTATION, AUGMENTED, TARGET
+ TensorInfo, BONDS, ANGLES, TORSIONS, FIXED, ORIGIN, ROTATION, AUGMENTED, TARGET,
+ ShapeDictionary
)
from .conditioner_factory import make_conditioners
from .transformer_factory import make_transformer
@@ -509,6 +513,94 @@ def add_constrain_chirality(self, halpha_torsion_indices, right_handed=False, to
affine = TorchTransform(torch.distributions.AffineTransform(loc=loc, scale=scale), 1)
return self.add_layer(affine, what=(torsions, ))

@contextlib.contextmanager
def volume_preserving_block(
self,
volume_sink: TensorInfo,
condition_on_dlogp: bool = True,
exclude_inputs_from_conditioner: Sequence[TensorInfo] = tuple(),
exclude_outputs_from_conditioner: Sequence[TensorInfo] = tuple(),
**conditioner_kwargs
):
"""Context manager for volume-preserving co-transforms.
A volume-preserving block can contain arbitrary (primary) transforms.
Every volume change (`dlogp != 0`) in this block is absorbed
by a volume sink tensor so that the total `dlogp` vanishes.

Parameters
----------
volume_sink
The field that acts as a volume sink.
condition_on_dlogp
Whether to condition the affine co-transform on the `dlogp` of the primary transform.
exclude_inputs_from_conditioner
Input tensors that are not passed to the conditioner of the affine co-layer.
exclude_outputs_from_conditioner
Output tensors that are not passed to the conditioner of the affine co-layer.

Notes
-----
It is paramount that the volume sink field is not used by any transform in the block.

Examples
--------
>>> from bgflow import BoltzmannGeneratorBuilder, AUGMENTED, BONDS, ANGLES, TORSIONS
>>> builder = BoltzmannGeneratorBuilder(...)
>>> with builder.volume_preserving_block(volume_sink=AUGMENTED):
...     builder.add_condition(BONDS, on=(ANGLES, TORSIONS))

No matter the transform used in the coupling layer, this block will
have vanishing `dlogp` in total.
"""
previous_layer = len(self.layers)
input_shape_dict = copy.copy(self.current_dims)
volume_sink_index_before = self.current_dims.index(volume_sink)
yield
# wrap layers that have been added in context
volume_sink_index_after = self.current_dims.index(volume_sink)
wrapped_flow = SequentialFlow(self.layers[previous_layer:])
self.layers = self.layers[:previous_layer]

# make conditioner inputs
cond_indices = []
cond_names = []
coflow_input_shapes = ShapeDictionary()

dlogp_info = TensorInfo("dlogp", is_circular=False)
coflow_input_shapes[dlogp_info] = (1,)
if condition_on_dlogp:
cond_indices.append(0)
cond_names.append(dlogp_info)
for i, (info, shape) in enumerate(input_shape_dict.items(), start=1):
coflow_input_shapes[info] = shape
if info not in (*exclude_inputs_from_conditioner, volume_sink):
cond_indices.append(i)
cond_names.append(info)
for i, (info, shape) in enumerate(self.current_dims.items(), start=1+len(input_shape_dict)):
info_out = info._replace(name=info.name+"_out")
coflow_input_shapes[info_out] = shape
if info not in (*exclude_outputs_from_conditioner, volume_sink):
cond_indices.append(i)
cond_names.append(info_out)

affine_conditioners = make_conditioners(
transformer_type=AffineTransformer,
what=_tuple(volume_sink),
on=cond_names,
shape_info=coflow_input_shapes,
**conditioner_kwargs
)
affine_conditioners = {name: net.to(**self.ctx) for name, net in affine_conditioners.items()}
volume_preserver = VolumePreservingWrapFlow(
flow=wrapped_flow,
volume_sink_index=volume_sink_index_before,
out_volume_sink_index=volume_sink_index_after,
cond_indices=cond_indices,
**affine_conditioners
)

self.add_layer(volume_preserver)
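To make the index bookkeeping above concrete: the co-transform's conditioner sees the flattened list `[dlogp, *inputs, *outputs]`, with the volume sink excluded on both sides. A worked trace for a hypothetical two-field builder (illustration only, not part of the diff):

# current_dims = {BONDS: (10,), ANGLES: (20,)}, volume_sink=ANGLES,
# condition_on_dlogp=True
#
#   coflow input    index   conditioner input?
#   dlogp           0       yes (condition_on_dlogp)
#   BONDS           1       yes
#   ANGLES          2       no  (volume sink)
#   BONDS_out       3       yes
#   ANGLES_out      4       no  (volume sink)
#
# => cond_indices == [0, 1, 3]; the co-flow's transformed index is
#    1 + volume_sink_index == 2, consistent with the assert in
#    VolumePreservingWrapFlow.__init__.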

def _add_to_param_groups(self, parameters, param_groups):
parameters = list(parameters)
for group in param_groups:
3 changes: 3 additions & 0 deletions bgflow/nn/flow/base.py
@@ -1,3 +1,5 @@


import torch


@@ -31,3 +33,4 @@ def forward(self, *xs, inverse=False, **kwargs):
return self._inverse(*xs, **kwargs)
else:
return self._forward(*xs, **kwargs)

105 changes: 104 additions & 1 deletion bgflow/nn/flow/coupling.py
@@ -6,8 +6,12 @@

from .base import Flow
from .inverted import InverseFlow
from .transformer.affine import AffineTransformer

__all__ = ["SplitFlow", "MergeFlow", "SwapFlow", "CouplingFlow", "WrapFlow", "SetConstantFlow"]
__all__ = [
"SplitFlow", "MergeFlow", "SwapFlow", "CouplingFlow",
"WrapFlow", "SetConstantFlow", "VolumePreservingWrapFlow"
]


class SplitFlow(Flow):
@@ -207,6 +211,7 @@ def _forward(self, *xs, **kwargs):
inp = (xs[i] for i in self._indices)
output = [xs[i] for i in range(len(xs)) if i not in self._indices]
*yi, dlogp = self._flow(*inp, **kwargs)
assert len(yi) == len(self._out_indices)
for i in self._argsort_out_indices:
index = self._out_indices[i]
output.insert(index, yi[i])
@@ -216,11 +221,109 @@ def _inverse(self, *xs, **kwargs):
inp = (xs[i] for i in self._out_indices)
output = [xs[i] for i in range(len(xs)) if i not in self._out_indices]
*yi, dlogp = self._flow(*inp, inverse=True, **kwargs)
assert len(yi) == len(self._indices)
for i in self._argsort_indices:
index = self._indices[i]
output.insert(index, yi[i])
return (*tuple(output), dlogp)

def output_index(self, input_index):
"""Output index of a non-transformed input."""
if input_index in self._indices:
raise ValueError("output_index is only defined for non-transformed inputs")
n_flow_non_inputs_before_sink = sum(i not in self._indices for i in range(input_index))
output_index = n_flow_non_inputs_before_sink
for i_out in sorted(self._out_indices):
if i_out <= output_index:
# "insert before"
output_index += 1
else:
# everything else is inserted after
break
return output_index
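A concrete trace of `output_index` for a hypothetical wrap (illustration only, not from the diff):

# WrapFlow with _indices=(1,) and _out_indices=(0,):
# forward inputs (a, b, c): b goes through the flow, a and c pass through.
# The non-transformed outputs start as [a, c]; the flow output b' is then
# inserted at position 0, giving (b', a, c).
#
# output_index(0) == 1   # a shifts right, past the inserted b'
# output_index(2) == 2   # c keeps its slot
# output_index(1)        # raises ValueError: b is a transformed input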


class VolumePreservingWrapFlow(Flow):
def __init__(
self,
flow: Flow,
volume_sink_index: int,
out_volume_sink_index: int,
cond_indices: Sequence[int],
shift_transformation: torch.nn.Module = None,
scale_transformation: torch.nn.Module = None
):
"""Volume-preserving wrap layer.

This layer operates on two or more input tensors.

One of these tensors (as indexed by `volume_sink_index` and `out_volume_sink_index`)
acts as a volume sink, while the others are transformed by a flow.
Concretely, after applying the flow, an affine layer is applied to the volume sink
in such a way that the volume change of this affine "co-transform" (`co_dlogp`) counteracts
the volume change (`dlogp`) of the primary flow, `dlogp + co_dlogp = 0`.

The parameters of the co-transform (shift and scale)
can be conditioned on dlogp as well as the inputs and outputs
of the primary flow.

It is important that the primary transform does not use the volume sink in any way,
neither transform it nor condition on it.

Parameters
----------
flow
The primary transform.
volume_sink_index
Input index of the volume sink tensor in the forward pass.
out_volume_sink_index
Input index of the volume sink tensor in the inverse pass.
cond_indices : Sequence[int]
Indices of the conditioner inputs. These refer to elements of the
concatenated list `[dlogp, *inputs, *outputs]`.
shift_transformation : torch.nn.Module, optional
    Conditioner network producing the shift of the affine co-transform.
scale_transformation : torch.nn.Module, optional
    Conditioner network producing the log-scale of the affine co-transform.
"""
super().__init__()
self.flow = flow
self.volume_sink_index = volume_sink_index
self.out_volume_sink_index = out_volume_sink_index
co_transform = AffineTransformer(
shift_transformation=shift_transformation,
scale_transformation=scale_transformation,
preserve_volume=True,
)
self.co_flow = CouplingFlow(
transformer=co_transform,
transformed_indices=(1 + self.volume_sink_index, ),
cond_indices=cond_indices,
cat_dim=-1
)
assert all(i != 1 + self.volume_sink_index for i in cond_indices)

def _forward(self, *xs, **kwargs):
*ys, dlogp = self.flow.forward(*xs, **kwargs)
co_out, co_dlogp = self._apply_coflow(dlogp, xs, ys, inverse=False)
ys[self.out_volume_sink_index] = co_out
return (*ys, dlogp + co_dlogp)

def _inverse(self, *ys, **kwargs):
*xs, dlogp = self.flow.forward(*ys, inverse=True, **kwargs)
co_out, co_dlogp = self._apply_coflow(forward_dlogp=-dlogp, xs=xs, ys=ys, inverse=True)
xs[self.volume_sink_index] = co_out
return (*xs, dlogp + co_dlogp)

def _apply_coflow(self, forward_dlogp, xs, ys, inverse):
assert torch.allclose(xs[self.volume_sink_index], ys[self.out_volume_sink_index])
coflow_in = [forward_dlogp, *xs, *ys]
target_dlogp = forward_dlogp if inverse else -forward_dlogp
*co_out, co_dlogp = self.co_flow.forward(*coflow_in, target_dlogp=target_dlogp, inverse=inverse)
return co_out[1 + self.volume_sink_index], co_dlogp
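The invariant this class enforces can be checked without any bgflow machinery. A minimal sketch in plain torch, assuming a toy primary transform and a uniform rescaling of the sink (illustration only, not the PR's API):

import math
import torch

x = torch.randn(8, 5)      # primary field
sink = torch.randn(8, 3)   # volume sink field

y = 2.0 * x                            # toy primary transform
dlogp = x.shape[-1] * math.log(2.0)    # its log|det J| per sample

# volume-preserving affine co-transform targeting -dlogp:
log_sigma = torch.full_like(sink, -dlogp / sink.shape[-1])
sink_out = sink * log_sigma.exp()
co_dlogp = log_sigma.sum(dim=-1)       # equals -dlogp for every sample

total = torch.as_tensor(dlogp) + co_dlogp
assert torch.allclose(total, torch.zeros_like(total))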


class SetConstantFlow(Flow):
"""A flow that sets some inputs constant in the forward direction and removes them in the inverse.
26 changes: 20 additions & 6 deletions bgflow/nn/flow/transformer/affine.py
@@ -1,4 +1,9 @@

from typing import Union

import warnings
import torch
import numpy as np

from .base import Transformer

@@ -32,7 +37,7 @@ def __init__(
self._preserve_volume = preserve_volume
self._is_circular = is_circular

- def _get_mu_and_log_sigma(self, x, y, *cond):
+ def _get_mu_and_log_sigma(self, x, y, *cond, target_dlogp: Union[float, torch.Tensor] = None):
if self._shift_transformation is not None:
mu = self._shift_transformation(x, *cond)
else:
@@ -42,13 +47,22 @@ def _get_mu_and_log_sigma(self, x, y, *cond):
log_sigma = torch.tanh(self._scale_transformation(x, *cond))
log_sigma = log_sigma * alpha
if self._preserve_volume:
- log_sigma = log_sigma - log_sigma.mean(dim=-1, keepdim=True)
+ target_dlogp = 0.0 if target_dlogp is None else target_dlogp
+ target_scale = target_dlogp / np.prod(log_sigma[0].shape)
+ log_sigma = (
+     log_sigma
+     - log_sigma.mean(dim=-1, keepdim=True)
+     + target_scale * torch.ones_like(log_sigma)
+ )
+ else:
+     if target_dlogp is not None:
+         warnings.warn("target_dlogp is only effective if preserve_volume is enabled.")
else:
log_sigma = torch.zeros_like(y).to(x)
return mu, log_sigma

- def _forward(self, x, y, *cond, **kwargs):
-     mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
+ def _forward(self, x, y, *cond, target_dlogp=None, **kwargs):
+     mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond, target_dlogp=target_dlogp)
assert mu.shape[-1] == y.shape[-1]
assert log_sigma.shape[-1] == y.shape[-1]
sigma = torch.exp(log_sigma)
@@ -58,8 +72,8 @@ def _forward(self, x, y, *cond, **kwargs):
y = y % 1.0
return y, dlogp

- def _inverse(self, x, y, *cond, **kwargs):
-     mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
+ def _inverse(self, x, y, *cond, target_dlogp=None, **kwargs):
+     mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond, target_dlogp=None if target_dlogp is None else -target_dlogp)
assert mu.shape[-1] == y.shape[-1]
assert log_sigma.shape[-1] == y.shape[-1]
sigma_inv = torch.exp(-log_sigma)
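The arithmetic behind the new `target_dlogp` handling: centering makes `log_sigma` sum to zero over the event dimension, and adding `target_dlogp / D` to each of the `D` entries makes the sum, and hence the affine transform's `dlogp`, exactly `target_dlogp`. A quick check (plain torch, illustration only):

import torch

raw = torch.randn(4, 6)              # unconstrained scale logits, D = 6
target_dlogp = torch.tensor(-0.7)
D = raw.shape[-1]

log_sigma = raw - raw.mean(dim=-1, keepdim=True) + target_dlogp / D
print(log_sigma.sum(dim=-1))         # ~= -0.7 in every row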
30 changes: 30 additions & 0 deletions tests/factory/test_generator_builder.py
@@ -225,3 +225,33 @@ def test_constrain_chirality(ala2, ctx):
b, a, t, *_ = crd_transform.forward(samples)
assert torch.all(t[:, chiral_torsions] >= 0.5)
assert torch.all(t[:, chiral_torsions] <= 1.0)


def test_volume_preserving_context(ctx):
shape_info = ShapeDictionary()
shape_info[BONDS] = (10, )
shape_info[ANGLES] = (20, )
builder = BoltzmannGeneratorBuilder(
shape_info,
**ctx
)
builder.targets[BONDS] = NormalDistribution(10, torch.zeros(10, **ctx))
builder.targets[ANGLES] = NormalDistribution(20, torch.zeros(20, **ctx))
# transform some fields
with builder.volume_preserving_block(volume_sink=ANGLES):
builder.add_layer(
CDFTransform(
TruncatedNormalDistribution(
torch.zeros(10, **ctx),
lower_bound=-torch.tensor(np.inf, **ctx)
),
),
what=[BONDS],
inverse=True,
param_groups=("group1", )
)
generator = builder.build_generator()
results = generator.sample(10, with_latent=True, with_dlogp=True, with_energy=True)
x, z, dlogp, energy = results[:2], results[2], results[3], results[4]
assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-5)
assert torch.allclose(generator.energy(*x), generator.prior.energy(*z), atol=1e-5)