Source code for surjectors._src.bijectors.rq_masked_coupling
from collections.abc import Callable
import distrax
from surjectors._src.bijectors.masked_coupling import MaskedCoupling
from surjectors._src.distributions.transformed_distribution import Array
# ruff: noqa: PLR0913
[docs]
class RationalQuadraticSplineMaskedCoupling(MaskedCoupling):
"""A rational quadratic spline masked coupling layer.
References:
.. [1] Dinh, Laurent, et al. "Density estimation using RealNVP".
International Conference on Learning Representations, 2017.
.. [2] Durkan, Conor, et al. "Neural Spline Flows".
Advances in Neural Information Processing Systems, 2019.
Examples:
>>> import distrax
>>> from surjectors import RationalQuadraticSplineMaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> layer = RationalQuadraticSplineMaskedCoupling(
>>> mask=make_alternating_binary_mask(10, True),
>>> conditioner=make_mlp([8, 8, 10 * 2]),
>>> range_min=-1.0,
>>> range_max=1.0
>>> )
"""
def __init__(
self,
mask: Array,
conditioner: Callable,
range_min: float,
range_max: float,
event_ndims: int | None = None,
inner_event_ndims: int = 0,
):
"""Construct a rational quadratic spline masked coupling layer.
Args:
mask: a boolean mask of length n_dim. A value
of True indicates that the corresponding input remains unchanged
conditioner: a function that computes the parameters of the inner
bijector
range_min: minimum range of the spline
range_max: maximum range of the spline
event_ndims: the number of array dimensions the bijector operates on
inner_event_ndims: the number of array dimensions the inner bijector
operates on
"""
self.range_min = range_min
self.range_max = range_max
def _bijector_fn(params: Array):
return distrax.RationalQuadraticSpline(
params, self.range_min, self.range_max
)
super().__init__(
mask, conditioner, _bijector_fn, event_ndims, inner_event_ndims
)