# Source code for surjectors._src.conditioners.mlp
import haiku as hk
import jax
from jax import numpy as jnp
# type: ignore[B008]
# [docs]
def make_mlp(
    dims: tuple[int, ...],
    activation=jax.nn.swish,
    w_init=hk.initializers.TruncatedNormal(stddev=0.01),
    b_init=jnp.zeros,
):
    """Build an MLP that serves as a conditioner network.

    Args:
        dims: output sizes of the layers, i.e. the hidden layers followed
            by the last layer
        activation: a JAX activation function
        w_init: a haiku initializer, forwarded as the MLP's ``w_init``
        b_init: a haiku initializer, forwarded as the MLP's ``b_init``

    Returns:
        a transformable haiku neural network module
    """
    # All arguments are handed to haiku's MLP constructor unchanged; they are
    # gathered here only so the forwarding is visible in one place.
    mlp_kwargs = {
        "output_sizes": dims,
        "activation": activation,
        "w_init": w_init,
        "b_init": b_init,
    }
    return hk.nets.MLP(**mlp_kwargs)