# Source code for surjectors._src.bijectors.affine_masked_coupling
from collections.abc import Callable
import distrax
from jax import numpy as jnp
from surjectors._src.bijectors.masked_coupling import MaskedCoupling
from surjectors._src.distributions.transformed_distribution import Array
# pylint: disable=too-many-arguments, arguments-renamed,too-many-ancestors
# [docs]
class AffineMaskedCoupling(MaskedCoupling):
    """An affine masked coupling layer.

    Specialises :class:`MaskedCoupling` with a fixed inner bijector: the
    conditioner output is split into two equal halves along its last axis,
    interpreted as a shift and a log-scale, and used to build a
    ``distrax.ScalarAffine`` bijector.

    Args:
        mask: a boolean mask of length n_dim. A value of True indicates that
            the corresponding input remains unchanged
        conditioner: a function that computes the parameters of the inner
            bijector. Its output is split in two along the last axis, so
            the size of that axis must be even
        event_ndims: the number of array dimensions the bijector operates on
        inner_event_ndims: the number of array dimensions the inner bijector
            operates on

    References:
        .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP".
            International Conference on Learning Representations, 2017.

    Examples:
        >>> from surjectors import AffineMaskedCoupling
        >>> from surjectors.nn import make_mlp
        >>> from surjectors.util import make_alternating_binary_mask
        >>>
        >>> layer = AffineMaskedCoupling(
        ...     mask=make_alternating_binary_mask(10, True),
        ...     conditioner=make_mlp([8, 8, 10 * 2]),
        ... )
    """

    def __init__(
        self,
        mask: Array,
        conditioner: Callable,
        event_ndims: int | None = None,
        inner_event_ndims: int = 0,
    ):
        def _bijector_fn(params):
            # First half of the last axis is the shift, second half the
            # log-scale; exponentiating keeps the scale strictly positive.
            means, log_scales = jnp.split(params, 2, -1)
            return distrax.ScalarAffine(means, jnp.exp(log_scales))

        super().__init__(
            mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims
        )