surjectors.nn

surjectors.nn contains utility functions and classes to construct neural
networks, as well as several network architectures for building normalizing
flows, such as masked autoregressive density estimators.
| MADE | Masked Autoregressive Density Estimator. |
| make_mlp | Create a conditioner network based on an MLP. |
| make_transformer | Create a conditioner network based on a transformer. |
- class surjectors.nn.MADE(input_size, hidden_layer_sizes, n_params, w_init=None, b_init=None, activation=<jax._src.custom_derivatives.custom_jvp object>)
Masked Autoregressive Density Estimator.
Passing a value through a MADE outputs a tensor of shape […, input_size, n_params] (see the usage sketch after the parameter list).
Examples
>>> from surjectors.nn import MADE
>>> made = MADE(10, [32, 32], 2)
- Parameters:
input_size (int) – dimensionality of the input
hidden_layer_sizes (list[int] | tuple[int]) – sizes of the hidden layers
n_params (int) – number of parameters to output per input dimension
w_init (Callable[[Sequence[int], Any], Array] | None) – a haiku weight initializer
b_init (Callable[[Sequence[int], Any], Array] | None) – a JAX activation function used as a haiku bias initializer is not; this is a haiku bias initializer
activation (Callable[[Array], Array]) – a JAX activation function
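A fuller usage sketch, under the assumption that MADE is a haiku module that is constructed and called on a single input array inside hk.transform; the batch shape below is illustrative:
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import MADE
>>> def _fn(x):
...     made = MADE(10, [32, 32], 2)
...     return made(x)
>>> fn = hk.transform(_fn)
>>> x = jnp.ones((4, 10))
>>> params = fn.init(jax.random.PRNGKey(0), x)
>>> y = fn.apply(params, jax.random.PRNGKey(1), x)
>>> # y has shape (4, 10, 2): n_params=2 outputs per input dimension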
- surjectors.nn.make_mlp(dims, activation=jax.nn.gelu, w_init=hk.initializers.TruncatedNormal(...), b_init=jnp.zeros)
Create a conditioner network based on an MLP (see the usage sketch below).
- Parameters:
dims – dimensions of hidden layers and last layer
activation – a JAX activation function
w_init – a haiku initializer
b_init – a haiku initializer
- Returns:
a transformable haiku neural network module
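A minimal sketch of using the returned module, assuming dims lists the hidden-layer sizes followed by the output size, and that the module is created and called inside hk.transform; shapes are illustrative:
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import make_mlp
>>> def _fn(x):
...     mlp = make_mlp([32, 32, 4])  # two hidden layers, 4 output units
...     return mlp(x)
>>> fn = hk.transform(_fn)
>>> x = jnp.ones((8, 5))
>>> params = fn.init(jax.random.PRNGKey(0), x)
>>> y = fn.apply(params, None, x)  # deterministic MLP, so no rng is needed
In a flow, such an MLP typically serves as the conditioner that maps its input to the parameters of a bijective transformation.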
- surjectors.nn.make_transformer(output_size, num_heads=4, num_layers=4, key_size=32, dropout_rate=0.1, widening_factor=4)
Create a conditioner network based on a transformer (see the usage sketch below).
- Parameters:
output_size – output size of the last layer
num_heads – number of attention heads
num_layers – number of transformer layers
key_size – size of the attention keys
dropout_rate – dropout rate
widening_factor – factor by which the MLP following each attention block is widened
- Returns:
a transformable haiku neural network module
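A minimal sketch, assuming the returned module is called on a single array inside hk.transform; the sequence-shaped input is an assumption, and an rng key is passed to apply since the network uses dropout:
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from surjectors.nn import make_transformer
>>> def _fn(x):
...     transformer = make_transformer(output_size=4)
...     return transformer(x)
>>> fn = hk.transform(_fn)
>>> x = jnp.ones((8, 16, 5))  # assumed (batch, sequence, feature) layout
>>> params = fn.init(jax.random.PRNGKey(0), x)
>>> y = fn.apply(params, jax.random.PRNGKey(1), x)  # rng key for dropout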