Source code for surjectors._src.bijectors.masked_coupling

from collections.abc import Callable

import distrax
from distrax._src.utils import math
from jax import numpy as jnp

from surjectors._src.bijectors.bijector import Bijector
from surjectors._src.distributions.transformed_distribution import Array


# ruff: noqa: PLR0913
class MaskedCoupling(Bijector, distrax.MaskedCoupling):
    """A masked coupling layer.

    Entries selected by the mask pass through the layer unchanged and are
    the only entries of the input shown to the conditioner (all others are
    replaced by zeros). The conditioner output parameterizes the inner
    bijector, whose result is used for the remaining entries.

    Args:
        mask: a boolean mask of length n_dim. A value of True indicates that
            the corresponding input remains unchanged
        conditioner: a function that computes the parameters of the inner
            bijector
        bijector_fn: a callable that returns the inner bijector that will be
            used to transform the input
        event_ndims: the number of array dimensions the bijector operates on
        inner_event_ndims: the number of array dimensions the inner bijector
            operates on

    Examples:
        >>> import distrax
        >>> from jax import numpy as jnp
        >>> from surjectors import MaskedCoupling
        >>> from surjectors.nn import make_mlp
        >>> from surjectors.util import make_alternating_binary_mask
        >>>
        >>> def bijector_fn(params):
        >>>     means, log_scales = jnp.split(params, 2, -1)
        >>>     return distrax.ScalarAffine(means, jnp.exp(log_scales))
        >>>
        >>> layer = MaskedCoupling(
        >>>     mask=make_alternating_binary_mask(10, True),
        >>>     bijector_fn=bijector_fn,
        >>>     conditioner=make_mlp([8, 8, 10 * 2]),
        >>> )

    References:
        .. [1] Dinh, Laurent, et al. "Density estimation using RealNVP".
            International Conference on Learning Representations, 2017.
    """

    def __init__(
        self,
        mask: Array,
        conditioner: Callable,
        bijector_fn: Callable,
        event_ndims: int | None = None,
        inner_event_ndims: int = 0,
    ):
        # Arguments are forwarded positionally, unchanged, to the next
        # initializer in the method resolution order.
        super().__init__(
            mask, conditioner, bijector_fn, event_ndims, inner_event_ndims
        )

    def _forward_and_likelihood_contribution(self, z, x=None, **kwargs):
        """Transform ``z`` forward and compute the log-determinant term.

        Args:
            z: the array to transform
            x: optional conditioning array; when given it is concatenated to
                the conditioner input along the last axis
            **kwargs: accepted but not used by this method

        Returns:
            A tuple ``(y, log_det)`` of the transformed array and the
            log-determinant contribution of the non-masked entries.
        """
        self._check_forward_input_shape(z)

        # The conditioner only sees the pass-through entries of z; every
        # other entry is replaced by zero.
        context = jnp.where(self._event_mask, z, 0.0)
        if x is not None:
            context = jnp.concatenate([context, x], axis=-1)

        inner = self._inner_bijector(self._conditioner(context))
        transformed, inner_log_det = inner.forward_and_log_det(z)

        # Keep the masked entries of z, take the transformed value elsewhere.
        y = jnp.where(self._event_mask, z, transformed)

        # Masked entries contribute zero to the log-determinant.
        n_reduced_dims = self._event_ndims - self._inner_event_ndims
        log_det = math.sum_last(
            jnp.where(self._mask, 0.0, inner_log_det), n_reduced_dims
        )
        return y, log_det

    def _inverse_and_likelihood_contribution(self, y, x=None, **kwargs):
        """Transform ``y`` backward and compute the log-determinant term.

        Args:
            y: the array to invert
            x: optional conditioning array; when given it is concatenated to
                the conditioner input along the last axis
            **kwargs: accepted but not used by this method

        Returns:
            A tuple ``(z, log_det)`` of the inverted array and the
            log-determinant contribution of the non-masked entries.
        """
        self._check_inverse_input_shape(y)

        # The conditioner only sees the pass-through entries of y; every
        # other entry is replaced by zero.
        context = jnp.where(self._event_mask, y, 0.0)
        if x is not None:
            context = jnp.concatenate([context, x], axis=-1)

        inner = self._inner_bijector(self._conditioner(context))
        inverted, inner_log_det = inner.inverse_and_log_det(y)

        # Keep the masked entries of y, take the inverted value elsewhere.
        z = jnp.where(self._event_mask, y, inverted)

        # Masked entries contribute zero to the log-determinant.
        n_reduced_dims = self._event_ndims - self._inner_event_ndims
        log_det = math.sum_last(
            jnp.where(self._mask, 0.0, inner_log_det), n_reduced_dims
        )
        return z, log_det