# Source code for surjectors._src.bijectors.masked_autoregressive
from collections.abc import Callable
from distrax._src.utils import math
from jax import numpy as jnp
from surjectors._src.bijectors.bijector import Bijector
from surjectors._src.conditioners.nn.made import MADE
# pylint: disable=too-many-arguments,arguments-renamed
# [docs]
class MaskedAutoregressive(Bijector):
    """A masked autoregressive layer.

    The layer asks a MADE conditioner for parameters, builds an inner
    bijector from them via ``bijector_fn`` and applies that bijector to the
    data.

    Args:
        conditioner: a MADE network
        bijector_fn: a callable that takes the conditioner output and returns
            the inner bijector that will be used to transform the input
        event_ndims: the number of array dimensions this bijector operates on
        inner_event_ndims: the number of array dimensions the inner bijector
            returned by ``bijector_fn`` operates on; must not exceed
            ``event_ndims``

    Raises:
        ValueError: if ``event_ndims`` is smaller than ``inner_event_ndims``
            or if ``conditioner`` is not a :class:`MADE` instance.

    References:
        .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for
           Density Estimation". Advances in Neural Information Processing
           Systems, 2017.

    Examples:
        >>> import distrax
        >>> from jax import numpy as jnp
        >>> from surjectors import MaskedAutoregressive
        >>> from surjectors._src.conditioners.nn.made import MADE
        >>> from surjectors.util import unstack
        >>>
        >>> def bijector_fn(params):
        ...     means, log_scales = unstack(params, -1)
        ...     return distrax.ScalarAffine(means, jnp.exp(log_scales))
        >>>
        >>> layer = MaskedAutoregressive(
        ...     conditioner=MADE(10, [8, 8], 2),
        ...     bijector_fn=bijector_fn
        ... )
    """

    def __init__(
        self,
        conditioner: MADE,
        bijector_fn: Callable,
        event_ndims: int = 1,
        inner_event_ndims: int = 0,
    ):
        if event_ndims is not None and event_ndims < inner_event_ndims:
            raise ValueError(
                f"'event_ndims={event_ndims}' should be at least as"
                f" large as 'inner_event_ndims={inner_event_ndims}'."
            )
        if not isinstance(conditioner, MADE):
            raise ValueError(
                "conditioner should be a MADE when used "
                "MaskedAutoregressive flow"
            )
        self._event_ndims = event_ndims
        self._inner_event_ndims = inner_event_ndims
        self.conditioner = conditioner
        # NOTE: despite the attribute name this stores the *factory*; the
        # actual inner bijector is built on every call from fresh parameters.
        self._inner_bijector = bijector_fn

    def _reduce_log_det(self, log_det):
        """Reduce the inner log-determinant to this layer's event shape.

        The inner bijector covers ``inner_event_ndims`` dimensions while this
        layer covers ``event_ndims``; the difference is passed to
        ``math.sum_last`` (presumably summing over that many trailing axes --
        see the distrax utility for the exact contract).
        """
        return math.sum_last(
            log_det, self._event_ndims - self._inner_event_ndims
        )

    def _forward_and_likelihood_contribution(self, z, x=None, **kwargs):
        """Map ``z`` forward and return the result with its log-determinant.

        Args:
            z: array to transform; its last axis length sets the number of
                refinement passes and must be at least one, otherwise no
                log-determinant is ever computed
            x: optional second argument forwarded to the conditioner
            **kwargs: accepted for interface compatibility and ignored

        Returns:
            A tuple ``(y, log_det)`` of the transformed array and the reduced
            log-determinant.
        """
        y = jnp.zeros_like(z)
        # The conditioner is fed the current estimate of ``y``, so the result
        # is refined once per element of the last axis (presumably because
        # the MADE output for dimension i only depends on earlier dimensions
        # -- confirm against the conditioner). A plain ``range`` over the
        # static shape is sufficient; no device array is needed to count.
        for _ in range(z.shape[-1]):
            params = self.conditioner(y, x)
            y, log_det = self._inner_bijector(params).forward_and_log_det(z)
        return y, self._reduce_log_det(log_det)

    def _inverse_and_likelihood_contribution(self, y, x=None, **kwargs):
        """Map ``y`` backward and return the result with its log-determinant.

        Unlike the forward direction, a single conditioner evaluation on
        ``y`` is used.

        Args:
            y: array to invert
            x: optional second argument forwarded to the conditioner
            **kwargs: accepted for interface compatibility and ignored

        Returns:
            A tuple ``(z, log_det)`` of the inverted array and the reduced
            log-determinant.
        """
        params = self.conditioner(y, x)
        z, log_det = self._inner_bijector(params).inverse_and_log_det(y)
        return z, self._reduce_log_det(log_det)