Source code for surjectors._src.conditioners.nn.made
from collections.abc import Callable
import haiku as hk
import jax
from jax import Array
from jax import numpy as jnp
from tensorflow_probability.substrates.jax.bijectors.masked_autoregressive import ( # noqa: E501
_make_dense_autoregressive_masks,
)
from surjectors._src.conditioners.nn.masked_linear import MaskedLinear
# ruff: noqa: PLR0913
[docs]
class MADE(hk.Module):
"""Masked Autoregressive Density Estimator.
Passing a value through a MADE will output a tensor of shape
[..., input_size, n_params]
Examples:
>>> from surjectors.nn import MADE
>>> made = MADE(10, (32, 32), 2)
"""
def __init__(
self,
input_size: int,
hidden_layer_sizes: tuple[int, ...],
n_params: int,
w_init: hk.initializers.Initializer | None = None,
b_init: hk.initializers.Initializer | None = None,
activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
):
"""Construct a MADE network.
Args:
input_size: number of input features
hidden_layer_sizes: list/tuple of ints describing the number of
nodes in the hidden layers
n_params: number of output parameters. For instance, if used as
a conditioner of an affine bijector should be 2 (mean and scale)
w_init: a Haiku initializer
b_init: a Haiku initializer
activation: n activation function
"""
super().__init__()
self.input_size = input_size
self.output_sizes = hidden_layer_sizes
self.n_params = n_params
self.w_init = w_init
self.b_init = b_init
self.activation = activation
masks = _make_dense_autoregressive_masks(
n_params, self.input_size, self.output_sizes
)
layers = []
for mask in masks:
layers.append(
MaskedLinear(
mask=mask.astype(jnp.float_),
w_init=w_init,
b_init=b_init,
)
)
self.layers = tuple(layers)
[docs]
def __call__(self, y: Array, x: Array = None):
"""Apply the MADE network.
Args:
y: input to be transformed
x: conditioning variable
Returns:
the transformed value
"""
output = self.layers[0](y)
if x is not None:
context = hk.Linear(
self.output_sizes[0], w_init=self.w_init, b_init=self.b_init
)(x)
output += context
output = self.activation(output)
for i, layer in enumerate(self.layers[1:]):
output = layer(output)
if i < len(self.layers[1:]) - 1:
output = self.activation(output)
output = hk.Reshape((self.input_size, self.n_params))(output)
return output