surjectors
From a computational perspective, a normalizing flow has three components:
A base distribution, for which we use the probability distributions from Distrax.
A forward transformation \(f\) whose Jacobian determinant (for surjections, whose likelihood contribution) can be evaluated efficiently. These are the bijectors and surjectors below.
A transformed distribution that represents the pushforward of the base distribution under the transformation, i.e., the distribution induced by the transformation.
Hence, every normalizing flow can be composed by defining these three components. See an example below.
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, LULinear, Chain
>>> from surjectors import TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
...     return _fn
>>>
>>> # assuming 10-dimensional data: Slice keeps 5 dimensions, the decoder models the other 5
>>> base_distribution = distrax.Independent(distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1)
>>> flow = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)
Regardless of how the chain of transformations (called flow above) is defined, each pushforward has access to four methods: sample, sample_and_log_prob, log_prob, and inverse_and_log_prob. The exact method declarations can be found in the API below.
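The four methods are tied together by the change-of-variables formula: for \(y = f(z)\), \(\log p(y) = \log p_{\text{base}}(f^{-1}(y)) + \log |\det J_{f^{-1}}(y)|\). As a minimal sketch in plain JAX, and not the surjectors API, the hypothetical class below pushes a standard normal through an elementwise affine map \(y = \text{scale} \cdot z + \text{shift}\) (scale assumed positive) and implements the four methods by hand. Unlike the library's methods, whose signatures are listed in the API below, these take an explicit PRNG key.

```python
from jax import numpy as jnp, random as jr
from jax.scipy.stats import norm


class ToyAffinePushforward:
    """Hypothetical pushforward of N(0, I) under y = scale * z + shift."""

    def __init__(self, shift, scale):
        self.shift = jnp.asarray(shift)
        self.scale = jnp.asarray(scale)  # assumed strictly positive

    def _base_log_prob(self, z):
        # Independent standard normals: sum log-densities over the event axis.
        return jnp.sum(norm.logpdf(z), axis=-1)

    def inverse_and_log_prob(self, y):
        z = (y - self.shift) / self.scale  # z = f^{-1}(y)
        log_det_inv = -jnp.sum(jnp.log(self.scale))  # log|det J_{f^{-1}}|
        return z, self._base_log_prob(z) + log_det_inv

    def log_prob(self, y):
        return self.inverse_and_log_prob(y)[1]

    def sample_and_log_prob(self, key, sample_shape=()):
        z = jr.normal(key, tuple(sample_shape) + self.shift.shape)
        y = self.scale * z + self.shift
        # Same change of variables, written in the forward direction.
        log_det_fwd = jnp.sum(jnp.log(self.scale))
        return y, self._base_log_prob(z) - log_det_fwd

    def sample(self, key, sample_shape=()):
        return self.sample_and_log_prob(key, sample_shape)[0]


flow = ToyAffinePushforward(shift=[1.0, -2.0], scale=[2.0, 0.5])
y, lp = flow.sample_and_log_prob(jr.PRNGKey(0), (3,))
```

In this sketch, sample_and_log_prob reuses the drawn z, so no inverse pass is needed to score the sample.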
General
TransformedDistribution – Distribution of a random variable transformed by a function.
Chain – Chain of normalizing flows.
TransformedDistribution
- class surjectors.TransformedDistribution(base_distribution, transform)
Distribution of a random variable transformed by a function.
Can be used to define a pushforward measure.
- Parameters:
base_distribution (Distribution) – a distribution object
transform (Surjector) – the transformation that defines the pushforward, e.g., a bijector, a surjector, or a Chain of them
Examples
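>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, Chain, TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
...     return _fn
>>>
>>> a = Slice(10, decoder_fn(10))
>>> b = Slice(5, decoder_fn(5))
>>> ab = Chain([a, b])
>>>
>>> TransformedDistribution(
...     distrax.Normal(jnp.zeros(5), jnp.ones(5)),
...     ab
... )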
- log_prob(y, x=None)
Calculate the log probability of an event, optionally conditional on another event x.
- Parameters:
y (Array) – event for which the log probability is computed
x (Array) – optional event that is used to condition
- Returns:
array of floats of log probabilities
- inverse_and_log_prob(y, x=None)
Compute the inverse transformation and its log probability.
- Parameters:
y (Array) – event for which the inverse and the log probability are computed
x (Array) – optional event that is used to condition
- Returns:
tuple of two arrays of floats. The first one is the inverse transformation of y, the second one is its log probability
- sample(sample_shape=(), x=None)
Sample an event.
- Parameters:
sample_shape – the shape of the sample to be drawn
x (Array) – optional event that is used to condition the samples. If x is given, sample_shape is ignored
- Returns:
a sample from the transformed distribution
- sample_and_log_prob(sample_shape=(), x=None)
Sample an event and compute its log probability.
- Parameters:
sample_shape – the shape of the sample to be drawn
x (Array) – optional event that is used to condition the samples. If x is given, sample_shape is ignored
- Returns:
tuple of two arrays of floats. The first one is the drawn sample, the second one is its log probability
Chain
- class surjectors.Chain(transforms)
Chain of normalizing flows.
Can be used to concatenate several normalizing flows together.
- Parameters:
transforms (list[Transform]) – a list of transformations, such as bijections or surjections
Examples
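>>> from surjectors import Slice, Chain
>>> # decoder_fn as defined in the introductory example at the top of this page
>>> a = Slice(10, decoder_fn(10))
>>> b = Slice(5, decoder_fn(5))
>>> ab = Chain([a, b])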
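Chaining works because log-Jacobian determinants of composed bijections add: \(\log|\det J_{g \circ h}(x)| = \log|\det J_h(x)| + \log|\det J_g(h(x))|\). The plain-JAX sketch below, with hypothetical helpers affine, sinh_layer, and chain that are not part of surjectors, composes two elementwise bijections and returns the summed log-determinant. It applies the list from left to right; the ordering convention used by surjectors.Chain is not stated on this page.

```python
import jax
from jax import numpy as jnp


def affine(shift, scale):
    """Hypothetical elementwise affine bijection returning (y, log|det J|)."""
    def fwd(x):
        return scale * x + shift, jnp.sum(jnp.log(jnp.abs(scale)))
    return fwd


def sinh_layer():
    """Hypothetical elementwise bijection y = sinh(x), with dy/dx = cosh(x)."""
    def fwd(x):
        return jnp.sinh(x), jnp.sum(jnp.log(jnp.cosh(x)))
    return fwd


def chain(transforms):
    """Apply transforms in list order and accumulate their log-determinants."""
    def fwd(x):
        total = jnp.zeros(())
        for transform in transforms:
            x, log_det = transform(x)
            total = total + log_det
        return x, total
    return fwd


flow = chain([affine(jnp.array([0.1, 0.2, -0.3]), jnp.array([2.0, 0.5, 3.0])), sinh_layer()])
x = jnp.array([0.3, -1.2, 0.5])
y, log_det = flow(x)
# Cross-check against the log|det| of the full Jacobian computed by autodiff.
jac = jax.jacfwd(lambda v: flow(v)[0])(x)
```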
Bijective layers
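MaskedAutoregressive – A masked autoregressive layer.
AffineMaskedAutoregressive – An affine masked autoregressive layer.
MaskedCoupling – A masked coupling layer.
AffineMaskedCoupling – An affine masked coupling layer.
RationalQuadraticSplineMaskedCoupling – A rational quadratic spline masked coupling layer.
Permutation – Permute the dimensions of a vector.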
Autoregressive bijections
- class surjectors.MaskedAutoregressive(conditioner, bijector_fn, event_ndims=1, inner_event_ndims=0)
A masked autoregressive layer.
- Parameters:
conditioner (MADE) – a MADE network
bijector_fn (Callable) – a callable that returns the inner bijector that will be used to transform the input
event_ndims (int) – the number of array dimensions the bijector operates on
inner_event_ndims (int) – the number of array dimensions the inner bijector operates on
References
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedAutoregressive
>>> from surjectors.nn import MADE
>>> from surjectors.util import unstack
>>>
>>> def bijector_fn(params):
...     means, log_scales = unstack(params, -1)
...     return distrax.ScalarAffine(means, jnp.exp(log_scales))
>>>
>>> layer = MaskedAutoregressive(
...     conditioner=MADE(10, [8, 8], 2),
...     bijector_fn=bijector_fn
... )
- class surjectors.AffineMaskedAutoregressive(conditioner, event_ndims=1, inner_event_ndims=0)
An affine masked autoregressive layer.
- Parameters:
conditioner (MADE) – a MADE network
event_ndims (int) – the number of array dimensions the bijector operates on
inner_event_ndims (int) – the number of array dimensions the inner bijector operates on
References
[1] Papamakarios, George, et al. “Masked Autoregressive Flow for Density Estimation”. Advances in Neural Information Processing Systems, 2017.
Examples
>>> from surjectors import AffineMaskedAutoregressive
>>> from surjectors.nn import MADE
>>>
>>> layer = AffineMaskedAutoregressive(
...     conditioner=MADE(10, [8, 8], 2),
... )
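The defining property of an affine autoregressive layer, as in [1], is that shift_i and log_scale_i depend only on x_{<i}. The Jacobian is therefore triangular with log|det J| = Σ log_scale_i, while the inverse must be computed one coordinate at a time. The plain-JAX sketch below illustrates this; a strictly lower-triangular linear map stands in for the MADE conditioner, and all names are hypothetical rather than part of surjectors.

```python
import jax
from jax import numpy as jnp, random as jr

D = 4
k_shift, k_scale = jr.split(jr.PRNGKey(0))
# Strictly lower-triangular mask: output i may only depend on inputs j < i.
mask = jnp.tril(jnp.ones((D, D)), k=-1)
W_shift = jr.normal(k_shift, (D, D)) * mask
W_log_scale = 0.1 * jr.normal(k_scale, (D, D)) * mask


def conditioner(x):
    """Hypothetical stand-in for a MADE network: two masked linear maps."""
    return W_shift @ x, W_log_scale @ x


def forward(x):
    shift, log_scale = conditioner(x)
    y = x * jnp.exp(log_scale) + shift
    # Triangular Jacobian with diagonal exp(log_scale).
    return y, jnp.sum(log_scale)


def inverse(y):
    # Sequential: coordinate i needs x_{<i}, which earlier iterations recovered.
    x = jnp.zeros_like(y)
    for i in range(D):
        shift, log_scale = conditioner(x)
        x = x.at[i].set((y[i] - shift[i]) * jnp.exp(-log_scale[i]))
    return x


x = jr.normal(jr.PRNGKey(1), (D,))
y, log_det = forward(x)
```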
Coupling bijections
- class surjectors.MaskedCoupling(mask, conditioner, bijector_fn, event_ndims=None, inner_event_ndims=0)
A masked coupling layer.
- Parameters:
mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged
conditioner (Callable) – a function that computes the parameters of the inner bijector
bijector_fn (Callable) – a callable that returns the inner bijector that will be used to transform the input
event_ndims (int | None) – the number of array dimensions the bijector operates on
inner_event_ndims (int) – the number of array dimensions the inner bijector operates on
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> def bijector_fn(params):
...     means, log_scales = jnp.split(params, 2, -1)
...     return distrax.ScalarAffine(means, jnp.exp(log_scales))
>>>
>>> layer = MaskedCoupling(
...     mask=make_alternating_binary_mask(10, True),
...     bijector_fn=bijector_fn,
...     conditioner=make_mlp([8, 8, 10 * 2]),
... )
References
[1] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
- class surjectors.AffineMaskedCoupling(mask, conditioner, event_ndims=None, inner_event_ndims=0)
An affine masked coupling layer.
- Parameters:
mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged
conditioner (Callable) – a function that computes the parameters of the inner bijector
event_ndims (int | None) – the number of array dimensions the bijector operates on
inner_event_ndims (int) – the number of array dimensions the inner bijector operates on
References
[1] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
Examples
>>> from surjectors import AffineMaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> layer = AffineMaskedCoupling(
...     mask=make_alternating_binary_mask(10, True),
...     conditioner=make_mlp([8, 8, 10 * 2]),
... )
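In an affine coupling layer the conditioner sees only the coordinates where the mask is True; the remaining coordinates are shifted and scaled with the predicted parameters. The layer is therefore inverted in a single pass, and its log-determinant is the sum of the log-scales over the transformed coordinates. The plain-JAX sketch below illustrates this with a hypothetical linear conditioner in place of make_mlp; its hand-built alternating mask need not match make_alternating_binary_mask.

```python
import jax
from jax import numpy as jnp, random as jr

D = 6
# Alternating mask; True marks coordinates that pass through unchanged.
mask = jnp.arange(D) % 2 == 0
W = 0.1 * jr.normal(jr.PRNGKey(0), (2 * D, D))


def conditioner(x_cond):
    """Hypothetical conditioner: a linear map standing in for make_mlp([..., D * 2])."""
    shift, log_scale = jnp.split(W @ x_cond, 2, -1)
    return shift, log_scale


def forward(x):
    # The conditioner only sees the unchanged coordinates.
    shift, log_scale = conditioner(jnp.where(mask, x, 0.0))
    y = jnp.where(mask, x, x * jnp.exp(log_scale) + shift)
    log_det = jnp.sum(jnp.where(mask, 0.0, log_scale))
    return y, log_det


def inverse(y):
    # Unchanged coordinates of y equal those of x, so the same parameters are recovered.
    shift, log_scale = conditioner(jnp.where(mask, y, 0.0))
    return jnp.where(mask, y, (y - shift) * jnp.exp(-log_scale))


x = jr.normal(jr.PRNGKey(1), (D,))
y, log_det = forward(x)
```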
- class surjectors.RationalQuadraticSplineMaskedCoupling(mask, conditioner, range_min, range_max, event_ndims=None, inner_event_ndims=0)
A rational quadratic spline masked coupling layer.
References
[1] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
[2] Durkan, Conor, et al. “Neural Spline Flows”. Advances in Neural Information Processing Systems, 2019.
Examples
>>> import distrax
>>> from surjectors import RationalQuadraticSplineMaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> layer = RationalQuadraticSplineMaskedCoupling(
...     mask=make_alternating_binary_mask(10, True),
...     conditioner=make_mlp([8, 8, 10 * 2]),
...     range_min=-1.0,
...     range_max=1.0
... )
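- Parameters:
mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged
conditioner (Callable) – a function that computes the parameters of the inner bijector
range_min (float) – minimum range of the spline
range_max (float) – maximum range of the spline
event_ndims (int | None) – the number of array dimensions the bijector operates on
inner_event_ndims (int) – the number of array dimensions the inner bijector operates on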
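For reference, the monotone rational-quadratic transform of Durkan et al. [2] uses knots \((x_k, y_k)\) and positive knot derivatives \(\delta_k\). Within bin k, with \(s_k = (y_{k+1}-y_k)/(x_{k+1}-x_k)\) and \(\xi = (x-x_k)/(x_{k+1}-x_k)\), it is \(g(x) = y_k + (y_{k+1}-y_k)\,\frac{s_k\xi^2 + \delta_k\xi(1-\xi)}{s_k + (\delta_{k+1}+\delta_k-2s_k)\xi(1-\xi)}\), and outside [range_min, range_max] it is the identity. The plain-JAX sketch below evaluates this for a scalar input. This page does not describe how surjectors maps the conditioner output to widths, heights, and derivatives, so the function rq_spline and its arguments are hypothetical.

```python
import jax
from jax import numpy as jnp


def rq_spline(x, widths, heights, derivs, range_min=-1.0, range_max=1.0):
    """Scalar monotone rational-quadratic spline; identity outside the range.

    widths, heights: positive bin sizes, shape (K,), normalised to span the range.
    derivs: positive derivatives at the K + 1 knots.
    Returns (y, log dy/dx).
    """
    span = range_max - range_min
    zero = jnp.zeros(1)
    xk = range_min + span * jnp.concatenate([zero, jnp.cumsum(widths / widths.sum())])
    yk = range_min + span * jnp.concatenate([zero, jnp.cumsum(heights / heights.sum())])
    inside = (x >= range_min) & (x <= range_max)
    # Bin index, clipped so that out-of-range inputs still index valid knots.
    k = jnp.clip(jnp.searchsorted(xk, x, side="right") - 1, 0, widths.shape[0] - 1)
    w, h = xk[k + 1] - xk[k], yk[k + 1] - yk[k]
    s = h / w
    xi = (x - xk[k]) / w
    d0, d1 = derivs[k], derivs[k + 1]
    denom = s + (d0 + d1 - 2.0 * s) * xi * (1.0 - xi)
    y = yk[k] + h * (s * xi**2 + d0 * xi * (1.0 - xi)) / denom
    dydx = s**2 * (d1 * xi**2 + 2.0 * s * xi * (1.0 - xi) + d0 * (1.0 - xi) ** 2) / denom**2
    return jnp.where(inside, y, x), jnp.where(inside, jnp.log(dydx), 0.0)


widths = jnp.array([1.0, 2.0, 1.0, 3.0])
heights = jnp.array([2.0, 1.0, 1.0, 1.0])
derivs = jnp.array([1.0, 0.5, 2.0, 1.5, 1.0])
xs = jnp.linspace(-0.99, 0.99, 50)
ys, log_dets = jax.vmap(lambda v: rq_spline(v, widths, heights, derivs))(xs)
```

Setting the two outermost derivatives to 1, as here, makes the spline join the identity tails smoothly.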
Other bijections
- class surjectors.Permutation(permutation, event_ndims_in)
Permute the dimensions of a vector.
- Parameters:
permutation – a vector of integer indexes representing the order of the elements
event_ndims_in (int) – number of input event dimensions
Examples
>>> from surjectors import Permutation
>>> from jax import numpy as jnp
>>>
>>> order = jnp.arange(10)
>>> perm = Permutation(order, 1)
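A permutation is volume-preserving: its Jacobian is a permutation matrix with |det| = 1, so it contributes nothing to the log-likelihood, and its inverse is given by argsort. The plain-JAX sketch below uses one plausible indexing convention; permute and inverse_permute are hypothetical helpers, not surjectors functions. Note that jnp.arange(10), as used in the example above, is the identity permutation; jax.random.permutation draws a random order instead.

```python
import jax
from jax import numpy as jnp, random as jr


def permute(x, order):
    """Reorder the last axis: y[..., i] = x[..., order[i]]."""
    return x[..., order]


def inverse_permute(y, order):
    # argsort(order) is the inverse permutation of order.
    return y[..., jnp.argsort(order)]


order = jr.permutation(jr.PRNGKey(0), 10)
x = jnp.arange(10.0)
y = permute(x, order)
# The Jacobian is a permutation matrix, so log|det J| = 0.
jac = jax.jacfwd(lambda v: permute(v, order))(x)
```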
Inference surjection layers
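MaskedCouplingInferenceFunnel – A masked coupling inference funnel.
AffineMaskedCouplingInferenceFunnel – A masked coupling inference funnel that uses an affine transformation.
RationalQuadraticSplineMaskedCouplingInferenceFunnel – A masked coupling inference funnel that uses a rational quadratic spline.
MaskedAutoregressiveInferenceFunnel – A masked autoregressive funnel layer.
AffineMaskedAutoregressiveInferenceFunnel – A masked affine autoregressive funnel layer.
RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel – A masked autoregressive inference funnel that uses RQ-NSFs.
LULinear – A bijection based on the LU decomposition.
MLPInferenceFunnel – A multilayer perceptron inference funnel.
Slice – A slice funnel.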
Coupling inference surjections
- class surjectors.MaskedCouplingInferenceFunnel(n_keep, decoder, conditioner, bijector_fn)
A masked coupling inference funnel.
The MaskedCouplingInferenceFunnel is a coupling funnel, i.e., a dimensionality-reducing transformation, that uses a masking mechanism as in MaskedCoupling. Unlike AffineMaskedCouplingInferenceFunnel and RationalQuadraticSplineMaskedCouplingInferenceFunnel, its inner bijector has to be specified by the user through bijector_fn.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
conditioner (Callable) – a conditioning neural network
bijector_fn (Callable) – a callable that returns the inner bijector used to transform the input
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedCouplingInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> def bijector_fn(params):
...     shift, log_scale = jnp.split(params, 2, -1)
...     return distrax.ScalarAffine(shift, jnp.exp(log_scale))
>>>
>>> layer = MaskedCouplingInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=make_mlp([4, 4, 10 * 2]),
...     bijector_fn=bijector_fn
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
- class surjectors.AffineMaskedCouplingInferenceFunnel(n_keep, decoder, conditioner)
A masked coupling inference funnel that uses an affine transformation.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
conditioner (Callable) – a conditioning neural network
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import AffineMaskedCouplingInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> layer = AffineMaskedCouplingInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=make_mlp([4, 4, 10 * 2]),
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
- class surjectors.RationalQuadraticSplineMaskedCouplingInferenceFunnel(*args, **kwargs)
A masked coupling inference funnel that uses a rational quadratic spline.
- Parameters:
n_keep – number of dimensions to keep
decoder – a callable that returns a conditional probability distribution when called
conditioner – a conditioning neural network
range_min – minimum range of the spline
range_max – maximum range of the spline
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import (
...     RationalQuadraticSplineMaskedCouplingInferenceFunnel
... )
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> layer = RationalQuadraticSplineMaskedCouplingInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=make_mlp([4, 4, 10 * 2]),
...     range_min=-1.0,
...     range_max=1.0,
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Durkan, Conor, et al. “Neural Spline Flows”. Advances in Neural Information Processing Systems, 2019.
[3] Dinh, Laurent, et al. “Density estimation using RealNVP”. International Conference on Learning Representations, 2017.
Autoregressive inference surjections
- class surjectors.MaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner, bijector_fn)
A masked autoregressive funnel layer.
The MaskedAutoregressiveInferenceFunnel is an autoregressive funnel, i.e., a dimensionality-reducing transformation, that uses a masking mechanism as in MaskedAutoregressive. Unlike AffineMaskedAutoregressiveInferenceFunnel and RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel, its inner bijector has to be specified by the user through bijector_fn.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
conditioner (MADE) – a MADE neural network
bijector_fn (Callable) – a callable that returns the inner bijector used to transform the input
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedAutoregressiveInferenceFunnel
>>> from surjectors.nn import MADE, make_mlp
>>> from surjectors.util import unstack
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> def bijector_fn(params):
...     shift, log_scale = unstack(params, axis=-1)
...     return distrax.ScalarAffine(shift, jnp.exp(log_scale))
>>>
>>> layer = MaskedAutoregressiveInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=MADE(10, [8, 8], 2),
...     bijector_fn=bijector_fn
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Papamakarios, George, et al. “Masked Autoregressive Flow for Density Estimation”. Advances in Neural Information Processing Systems, 2017.
- class surjectors.AffineMaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner)
A masked affine autoregressive funnel layer.
The AffineMaskedAutoregressiveInferenceFunnel is an autoregressive funnel, i.e., a dimensionality-reducing transformation, that uses an affine transformation from data to latent space with a masking mechanism as in MaskedAutoregressive.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
conditioner (MADE) – a MADE neural network
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import AffineMaskedAutoregressiveInferenceFunnel
>>> from surjectors.nn import MADE, make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> layer = AffineMaskedAutoregressiveInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=MADE(10, [8, 8], 2),
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Papamakarios, George, et al. “Masked Autoregressive Flow for Density Estimation”. Advances in Neural Information Processing Systems, 2017.
- class surjectors.RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner, range_min, range_max)
A masked autoregressive inference funnel that uses RQ-NSFs.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
conditioner (MADE) – a conditioning neural network
range_min (float) – minimum range of the spline
range_max (float) – maximum range of the spline
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import (
...     RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel
... )
>>> from surjectors.nn import MADE, make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> layer = RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel(
...     n_keep=10,
...     decoder=decoder_fn(10),
...     conditioner=MADE(10, [8, 8], 2),
...     range_min=-1.0,
...     range_max=1.0,
... )
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
[2] Durkan, Conor, et al. “Neural Spline Flows”. Advances in Neural Information Processing Systems, 2019.
[3] Papamakarios, George, et al. “Masked Autoregressive Flow for Density Estimation”. Advances in Neural Information Processing Systems, 2017.
Other inference surjections
- class surjectors.LULinear(n_dimension, with_bias=False, dtype=<class 'jax.numpy.float32'>)
A bijection based on the LU decomposition.
- Parameters:
n_dimension – dimensionality of the input; as a bijection, the layer keeps all dimensions
with_bias – whether to use a bias term
dtype – parameter dtype
References
[1] Oliva, Junier, et al. “Transformation Autoregressive Networks”. Proceedings of the 35th International Conference on Machine Learning, 2018.
Examples
>>> from surjectors import LULinear
>>> layer = LULinear(10)
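An LU parameterisation makes both the log-determinant and the inverse cheap. With W = L U, where L is unit lower-triangular and U is upper-triangular, log|det W| = Σ log|U_ii|, and inversion needs only two triangular solves. The plain-JAX sketch below illustrates this. It omits the permutation matrix that some LU-based layers include, and its parameter names are hypothetical rather than those used by LULinear.

```python
from jax import numpy as jnp, random as jr
from jax.scipy.linalg import solve_triangular

D = 5
k_lower, k_upper, k_diag, k_x = jr.split(jr.PRNGKey(0), 4)
# Hypothetical unconstrained parameters.
lower_raw = 0.1 * jr.normal(k_lower, (D, D))
upper_raw = 0.1 * jr.normal(k_upper, (D, D))
log_diag = 0.1 * jr.normal(k_diag, (D,))
bias = jnp.zeros(D)

# L: unit lower-triangular; U: upper-triangular with positive diagonal exp(log_diag).
L = jnp.tril(lower_raw, k=-1) + jnp.eye(D)
U = jnp.triu(upper_raw, k=1) + jnp.diag(jnp.exp(log_diag))
W = L @ U


def forward(x):
    # det W = det L * det U = 1 * prod(diag U), so log|det W| = sum(log_diag).
    return W @ x + bias, jnp.sum(log_diag)


def inverse(y):
    # Two triangular solves instead of a general matrix inverse.
    z = solve_triangular(L, y - bias, lower=True, unit_diagonal=True)
    return solve_triangular(U, z, lower=False)


x = jr.normal(k_x, (D,))
y, log_det = forward(x)
```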
- class surjectors.MLPInferenceFunnel(n_keep, decoder)
A multilayer perceptron inference funnel.
- Parameters:
n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that returns a conditional probability distribution when called
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MLPInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> decoder = decoder_fn(5)
>>> a = MLPInferenceFunnel(10, decoder)
References
[1] Klein, Samuel, et al. “Funnels: Exact maximum likelihood with dimensionality reduction”. Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.
- class surjectors.Slice(n_keep, decoder)
A slice funnel.
- Parameters:
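n_keep (int) – number of dimensions to keep
decoder (Callable) – a callable that, when called with the kept dimensions, returns a conditional probability distribution over the dropped dimensions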
Examples
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(
...             distrax.Normal(mu, jnp.exp(log_scale))
...         )
...     return _fn
>>>
>>> layer = Slice(10, decoder_fn(10))
References
[1] Nielsen, Didrik, et al. “SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows”. Advances in Neural Information Processing Systems, 2020.
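In the SurVAE framework [1], a slice funnel keeps part of the input and lets the decoder supply a conditional density for the rest. Assuming the first n_keep coordinates are kept, the inference direction gives z = y[:n_keep] and log p(y) = log p(z) + log p(y[n_keep:] | z). The plain-JAX sketch below illustrates this with a standard normal base and a hypothetical fixed decoder; it is not the surjectors implementation.

```python
from jax import numpy as jnp, random as jr
from jax.scipy.stats import norm

N_KEEP = 3  # the sketch assumes 5-dimensional data, so 2 coordinates are dropped


def decoder(z):
    """Hypothetical decoder: mean and scale of the 2 dropped coordinates given z."""
    mu = 0.5 * jnp.tanh(z[..., :2])
    return mu, jnp.ones_like(mu)


def slice_inverse_and_log_contrib(y):
    # Inference direction: keep the first N_KEEP coordinates, score the rest.
    z, y_dropped = y[..., :N_KEEP], y[..., N_KEEP:]
    mu, scale = decoder(z)
    return z, jnp.sum(norm.logpdf(y_dropped, mu, scale), axis=-1)


def log_prob(y):
    z, log_contrib = slice_inverse_and_log_contrib(y)
    return jnp.sum(norm.logpdf(z), axis=-1) + log_contrib


def sample(key, n):
    # Generative direction: draw z from the base, then the dropped part from the decoder.
    k_base, k_dec = jr.split(key)
    z = jr.normal(k_base, (n, N_KEEP))
    mu, scale = decoder(z)
    y_dropped = mu + scale * jr.normal(k_dec, mu.shape)
    return jnp.concatenate([z, y_dropped], axis=-1)


ys = sample(jr.PRNGKey(0), 4)
```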