surjectors.nn


surjectors.nn contains utility functions and classes for constructing neural networks, as well as several network architectures, such as density networks, that serve as building blocks of normalizing flows.

MADE(input_size, hidden_layer_sizes, n_params)
    Masked Autoregressive Density Estimator.

make_mlp(dims[, activation, w_init, b_init])
    Create a conditioner network based on an MLP.

make_transformer(output_size[, num_heads, ...])
    Create a conditioner network based on a transformer.

class surjectors.nn.MADE(input_size, hidden_layer_sizes, n_params, w_init=None, b_init=None, activation=<jax._src.custom_derivatives.custom_jvp object>)

Masked Autoregressive Density Estimator.

Passing a value through a MADE network outputs a tensor of shape […, input_size, n_params].

Examples

>>> from surjectors.nn import MADE
>>> made = MADE(10, [32, 32], 2)
Parameters:
  • input_size (int) – dimensionality of the input

  • hidden_layer_sizes (list[int] | tuple[int]) – sizes of the hidden layers

  • n_params (int) – number of output parameters per input dimension

  • w_init (Callable[[Sequence[int], Any], Array] | None) – weight initializer

  • b_init (Callable[[Sequence[int], Any], Array] | None) – bias initializer

  • activation (Callable[[Array], Array]) – activation function

__call__(y, x=None)

Apply the MADE network.

Parameters:
  • y (Array) – input to be transformed

  • x (Array) – conditioning variable (optional)

Returns:

the transformed value, a tensor of shape […, input_size, n_params]
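
Since the networks in surjectors are built on haiku, a MADE module is typically constructed and called inside a function that is then passed to hk.transform. Below is a minimal usage sketch under that assumption; the batch size, PRNG keys, and input values are arbitrary.

>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import MADE
>>>
>>> def _fn(y):
...     # construct the module inside the transformed function
...     made = MADE(10, [32, 32], 2)
...     return made(y)
>>>
>>> fn = hk.transform(_fn)
>>> y = jnp.ones((4, 10))
>>> params = fn.init(jax.random.PRNGKey(0), y)
>>> out = fn.apply(params, jax.random.PRNGKey(1), y)
>>> # out has shape (4, 10, 2), i.e., [..., input_size, n_params]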

surjectors.nn.make_mlp(dims, activation=<function gelu>, w_init=<haiku._src.initializers.TruncatedNormal object>, b_init=<function zeros>)

Create a conditioner network based on an MLP.

Parameters:
  • dims – sizes of the hidden layers and the last layer

  • activation – a JAX activation function

  • w_init – a haiku weight initializer

  • b_init – a haiku bias initializer

Returns:

a transformable haiku neural network module
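
A sketch of how the returned conditioner might be used, assuming it is constructed and called inside an hk.transform-ed function like any other haiku module; the layer sizes and input shape below are arbitrary, and the last entry of dims determines the output size.

>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import make_mlp
>>>
>>> def _fn(x):
...     # two hidden layers of size 64 and an output layer of size 2
...     mlp = make_mlp([64, 64, 2])
...     return mlp(x)
>>>
>>> fn = hk.transform(_fn)
>>> x = jnp.ones((8, 5))
>>> params = fn.init(jax.random.PRNGKey(0), x)
>>> out = fn.apply(params, jax.random.PRNGKey(1), x)
>>> # out has shape (8, 2), assuming the MLP maps [..., d] to [..., dims[-1]]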

surjectors.nn.make_transformer(output_size, num_heads=4, num_layers=4, key_size=32, dropout_rate=0.1, widening_factor=4)

Create a conditioner network based on a transformer.

Parameters:
  • output_size – output size of the last layer

  • num_heads – number of attention heads

  • num_layers – number of transformer layers

  • key_size – size of the attention keys

  • dropout_rate – dropout rate

  • widening_factor – factor by which the MLP after the attention block is widened

Returns:

a transformable haiku neural network module
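
A sketch along the same lines, assuming the transformer conditioner accepts inputs of shape [..., sequence, features] and that dropout requires a PRNG key at apply time; all sizes below are arbitrary.

>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import make_transformer
>>>
>>> def _fn(x):
...     transformer = make_transformer(2, num_heads=2, num_layers=2, key_size=16)
...     return transformer(x)
>>>
>>> fn = hk.transform(_fn)
>>> x = jnp.ones((4, 10, 8))
>>> params = fn.init(jax.random.PRNGKey(0), x)
>>> out = fn.apply(params, jax.random.PRNGKey(1), x)  # dropout needs an rng key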