Source code for surjectors._src.bijectors.affine_masked_autoregressive
import distrax
from jax import numpy as jnp
from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive
from surjectors._src.conditioners.nn.made import MADE
from surjectors.util import unstack
# pylint: disable=too-many-arguments,arguments-renamed
[docs]
class AffineMaskedAutoregressive(MaskedAutoregressive):
    """An affine masked autoregressive layer.

    Each dimension is transformed by an element-wise affine bijector
    (``distrax.ScalarAffine``). Its shift and scale are computed by a MADE
    conditioner. The conditioner emits a log-scale, which is exponentiated
    so the scale is always positive.

    Args:
        conditioner: a MADE network. Its output is split along the last axis
            into exactly two parameter arrays: the shift and the log-scale.
        event_ndims: the number of array dimensions the bijector operates on
        inner_event_ndims: the number of array dimensions the inner
            (element-wise affine) bijector operates on

    References:
        .. [1] Papamakarios, George, et al. "Masked Autoregressive Flow for
           Density Estimation". Advances in Neural Information Processing
           Systems, 2017.

    Examples:
        >>> import haiku as hk
        >>> from jax import numpy as jnp
        >>> from jax import random as jr
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        >>> from surjectors import AffineMaskedAutoregressive, TransformedDistribution
        >>> from surjectors._src.conditioners.nn.made import MADE
        >>> @hk.without_apply_rng
        ... @hk.transform
        ... def fn(inputs):
        ...     base_distribution = tfd.Independent(
        ...         tfd.Normal(jnp.zeros(10), jnp.ones(10)),
        ...         reinterpreted_batch_ndims=1,
        ...     )
        ...     td = TransformedDistribution(
        ...         base_distribution,
        ...         AffineMaskedAutoregressive(MADE(10, (64, 64), 2))
        ...     )
        ...     return td.log_prob(inputs)
        >>> data = jr.normal(jr.PRNGKey(1), shape=(10, 10))
        >>> params = fn.init(jr.key(0), data)
        >>> lps = fn.apply(params, data)
    """

    def __init__(
        self,
        conditioner: MADE,
        event_ndims: int = 1,
        inner_event_ndims: int = 0,
    ):
        def bijector_fn(params):
            """Build the element-wise affine bijector from conditioner output.

            Args:
                params: conditioner output whose last axis holds the two
                    affine parameters (shift, log-scale); presumably of shape
                    ``(..., 2)`` -- TODO confirm against ``MADE``.

            Returns:
                a ``distrax.ScalarAffine`` with the given shift and a strictly
                positive scale ``exp(log_scale)``.
            """
            means, log_scales = unstack(params, -1)
            # The conditioner predicts log-scales; exponentiating keeps the
            # scale positive so the affine map stays invertible.
            return distrax.ScalarAffine(means, jnp.exp(log_scales))

        super().__init__(conditioner, bijector_fn, event_ndims, inner_event_ndims)