surjectors#


From a computational perspective, normalizing flows consist of three components:

  • A base distribution for which we use the probability distributions from Distrax.

  • A forward transformation \(f\) whose Jacobian determinant can be evaluated efficiently. These are the bijectors and surjectors below.

  • A transformed distribution that represents the pushforward from a base distribution to the distribution induced by the transformation.

Hence, every normalizing flow can be composed by defining these three components. See an example below.

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, LULinear, Chain
>>> from surjectors import TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
>>>     return _fn
>>>
>>> base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(5))
>>> flow = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)

Regardless of how the chain of transformations (called flow above) is defined, each pushforward has access to four methods: sample, sample_and_log_prob, log_prob, and inverse_and_log_prob.
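
For instance, with the pushforward defined above, a minimal usage sketch (omitting the functional context in which the conditioner networks are constructed) looks like this:

>>> y = pushforward.sample(sample_shape=(100,))
>>> lp = pushforward.log_prob(y)
>>> y, lp = pushforward.sample_and_log_prob(sample_shape=(100,))
>>> z, lp = pushforward.inverse_and_log_prob(y)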

The exact method declarations can be found in the API below.

General#

TransformedDistribution(base_distribution, ...)

Distribution of a random variable transformed by a function.

Chain(transforms)

Chain of normalizing flows.

TransformedDistribution#

class surjectors.TransformedDistribution(base_distribution, transform)[source]#

Distribution of a random variable transformed by a function.

Can be used to define a pushforward measure.

Parameters:
  • base_distribution (Distribution) – a distribution object

  • transform (Surjector) – some transformation

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, Chain, TransformedDistribution
>>>
>>> # decoder_fn as defined in the example at the top of this page
>>> a = Slice(10, decoder_fn(10))
>>> b = Slice(5, decoder_fn(5))
>>> ab = Chain([a, b])
>>>
>>> pushforward = TransformedDistribution(
>>>     distrax.Normal(jnp.zeros(5), jnp.ones(5)),
>>>     ab
>>> )
log_prob(y, x=None)[source]#

Calculate the log probability of an event conditional on another.

Parameters:
  • y (Array) – event for which the log probability is computed

  • x (Array) – optional event that is used to condition

Returns:

array of floats of log probabilities

inverse_and_log_prob(y, x=None)[source]#

Compute the inverse transformation and its log probability.

Parameters:
  • y (Array) – event for which the inverse and log probability are computed

  • x (Array) – optional event that is used to condition

Returns:

tuple of two arrays of floats. The first one is the inverse transformation, the second one is the log probability

sample(sample_shape=(), x=None)[source]#

Sample an event.

Parameters:
  • sample_shape – the size of the sample to be drawn

  • x (Array) – optional event that is used to condition the samples. If x is given, sample_shape is ignored

Returns:

a sample from the transformed distribution

sample_and_log_prob(sample_shape=(), x=None)[source]#

Sample an event and compute its log probability.

Parameters:
  • sample_shape – the size of the sample to be drawn

  • x (Array) – optional event that is used to condition the samples. If x is given, sample_shape is ignored

Returns:

tuple of two arrays of floats. The first one is the drawn sample, the second one is its log probability

Chain#

class surjectors.Chain(transforms)[source]#

Chain of normalizing flows.

Can be used to concatenate several normalizing flows together.

Parameters:

transforms (list[Transform]) – a list of transformations, such as bijections or surjections

Examples

>>> from surjectors import Slice, Chain
>>> # decoder_fn as defined in the example at the top of this page
>>> a = Slice(10, decoder_fn(10))
>>> b = Slice(5, decoder_fn(5))
>>> ab = Chain([a, b])

Bijective layers#

MaskedAutoregressive(conditioner, bijector_fn)

A masked autoregressive layer.

AffineMaskedAutoregressive(conditioner[, ...])

An affine masked autoregressive layer.

MaskedCoupling(mask, conditioner, bijector_fn)

A masked coupling layer.

AffineMaskedCoupling(mask, conditioner[, ...])

An affine masked coupling layer.

RationalQuadraticSplineMaskedCoupling(mask, ...)

A rational quadratic spline masked coupling layer.

Permutation(permutation, event_ndims_in)

Permute the dimensions of a vector.

Autoregressive bijections#

class surjectors.MaskedAutoregressive(conditioner, bijector_fn, event_ndims=1, inner_event_ndims=0)[source]#

A masked autoregressive layer.

Parameters:
  • conditioner (MADE) – a MADE network

  • bijector_fn (Callable) – a callable that returns the inner bijector that will be used to transform the input

  • event_ndims (int) – the number of array dimensions the bijector operates on

  • inner_event_ndims (int) – the number of array dimensions the inner bijector operates on

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedAutoregressive
>>> from surjectors.nn import MADE
>>> from surjectors.util import unstack
>>>
>>> def bijector_fn(params):
>>>     means, log_scales = unstack(params, -1)
>>>     return distrax.ScalarAffine(means, jnp.exp(log_scales))
>>>
>>> layer = MaskedAutoregressive(
>>>     conditioner=MADE(10, [8, 8], 2),
>>>     bijector_fn=bijector_fn
>>> )
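
Here unstack presumably splits the conditioner's output along the stacked parameter axis; the assumed shape convention is:

>>> # assumed shapes (sketch): MADE(10, [8, 8], 2) produces params of
>>> # shape (..., 10, 2); unstack(params, -1) yields two (..., 10) arrays
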
class surjectors.AffineMaskedAutoregressive(conditioner, event_ndims=1, inner_event_ndims=0)[source]#

An affine masked autoregressive layer.

Parameters:
  • conditioner (MADE) – a MADE network

  • event_ndims (int) – the number of array dimensions the bijector operates on

  • inner_event_ndims (int) – the number of array dimensions the inner bijector operates on

Examples

>>> import distrax
>>> from surjectors import AffineMaskedAutoregressive
>>> from surjectors.nn import MADE
>>>
>>> layer = AffineMaskedAutoregressive(
>>>     conditioner=MADE(10, [8, 8], 2),
>>> )

Coupling bijections#

class surjectors.MaskedCoupling(mask, conditioner, bijector_fn, event_ndims=None, inner_event_ndims=0)[source]#

A masked coupling layer.

Parameters:
  • mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged

  • conditioner (Callable) – a function that computes the parameters of the inner bijector

  • bijector_fn (Callable) – a callable that returns the inner bijector that will be used to transform the input

  • event_ndims (int | None) – the number of array dimensions the bijector operates on

  • inner_event_ndims (int) – the number of array dimensions the inner bijector operates on

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> def bijector_fn(params):
>>>     means, log_scales = jnp.split(params, 2, -1)
>>>     return distrax.ScalarAffine(means, jnp.exp(log_scales))
>>>
>>> layer = MaskedCoupling(
>>>     mask=make_alternating_binary_mask(10, True),
>>>     bijector_fn=bijector_fn,
>>>     conditioner=make_mlp([8, 8, 10 * 2]),
>>> )
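
Conceptually, the layer computes something like the following (a sketch following the usual masked-coupling convention; hypothetical, not necessarily the library's exact implementation):

>>> # entries where the mask is True pass through unchanged; the
>>> # conditioner sees only the unchanged part:
>>> # params = conditioner(jnp.where(mask, x, 0.0))
>>> # y = jnp.where(mask, x, bijector_fn(params).forward(x))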

class surjectors.AffineMaskedCoupling(mask, conditioner, event_ndims=None, inner_event_ndims=0)[source]#

An affine masked coupling layer.

Parameters:
  • mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged

  • conditioner (Callable) – a function that computes the parameters of the inner bijector

  • event_ndims (int | None) – the number of array dimensions the bijector operates on

  • inner_event_ndims (int) – the number of array dimensions the inner bijector operates on

Examples

>>> import distrax
>>> from surjectors import AffineMaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> layer = AffineMaskedCoupling(
>>>     mask=make_alternating_binary_mask(10, True),
>>>     conditioner=make_mlp([8, 8, 10 * 2]),
>>> )
class surjectors.RationalQuadraticSplineMaskedCoupling(mask, conditioner, range_min, range_max, event_ndims=None, inner_event_ndims=0)[source]#

A rational quadratic spline masked coupling layer.

Parameters:
  • mask (Array) – a boolean mask of length n_dim. A value of True indicates that the corresponding input remains unchanged

  • conditioner (Callable) – a function that computes the parameters of the inner bijector

  • range_min (float) – minimum range of the spline

  • range_max (float) – maximum range of the spline

  • event_ndims (int | None) – the number of array dimensions the bijector operates on

  • inner_event_ndims (int) – the number of array dimensions the inner bijector operates on

Examples

>>> import distrax
>>> from surjectors import RationalQuadraticSplineMaskedCoupling
>>> from surjectors.nn import make_mlp
>>> from surjectors.util import make_alternating_binary_mask
>>>
>>> layer = RationalQuadraticSplineMaskedCoupling(
>>>     mask=make_alternating_binary_mask(10, True),
>>>     conditioner=make_mlp([8, 8, 10 * 2]),
>>>     range_min=-1.0,
>>>     range_max=1.0
>>> )

Other bijections#

class surjectors.Permutation(permutation, event_ndims_in)[source]#

Permute the dimensions of a vector.

Parameters:
  • permutation – a vector of integer indexes representing the order of the elements

  • event_ndims_in (int) – number of input event dimensions

Examples

>>> from surjectors import Permutation
>>> from jax import numpy as jnp
>>>
>>> order = jnp.arange(10)
>>> perm = Permutation(order, 1)
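
The identity order above leaves the input unchanged; a reversing permutation, as commonly placed between coupling layers, may be the more instructive case:

>>> rev = Permutation(jnp.arange(10)[::-1], 1)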

Inference surjection layers#

MaskedCouplingInferenceFunnel(n_keep, ...)

A masked coupling inference funnel.

AffineMaskedCouplingInferenceFunnel(n_keep, ...)

A masked coupling inference funnel that uses an affine transformation.

RationalQuadraticSplineMaskedCouplingInferenceFunnel(...)

A masked coupling inference funnel that uses a rational quadratic spline.

MaskedAutoregressiveInferenceFunnel(n_keep, ...)

A masked autoregressive funnel layer.

AffineMaskedAutoregressiveInferenceFunnel(...)

A masked affine autoregressive funnel layer.

RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel(...)

A masked autoregressive inference funnel that uses RQ-NSFs.

LULinear(n_dimension[, with_bias, dtype])

A bijection based on an LU decomposition.

MLPInferenceFunnel(n_keep, decoder)

A multilayer perceptron inference funnel.

Slice(n_keep, decoder)

A slice funnel.

Coupling inference surjections#

class surjectors.MaskedCouplingInferenceFunnel(n_keep, decoder, conditioner, bijector_fn)[source]#

A masked coupling inference funnel.

The MaskedCouplingInferenceFunnel is a coupling funnel, i.e., a dimensionality-reducing transformation, that uses a masking mechanism as in MaskedCoupling. In contrast to AffineMaskedCouplingInferenceFunnel and RationalQuadraticSplineMaskedCouplingInferenceFunnel, its inner bijector needs to be specified explicitly.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (Callable) – a conditioning neural network

  • bijector_fn (Callable) – an inner bijector function to be used

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedCouplingInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> def bijector_fn(params):
>>>     shift, log_scale = jnp.split(params, 2, -1)
>>>     return distrax.ScalarAffine(shift, jnp.exp(log_scale))
>>>
>>> layer = MaskedCouplingInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=make_mlp([4, 4, 10 * 2]),
>>>     bijector_fn=bijector_fn
>>> )
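
Schematically, the inference direction of a funnel keeps n_keep dimensions and accounts for the dropped ones through the decoder's conditional density; as a sketch of the bookkeeping (not the library's actual code):

>>> # z: the n_keep retained dimensions (after the coupling transform);
>>> # the log-density gains decoder(z).log_prob(x_removed) in addition
>>> # to the usual log-determinant term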

class surjectors.AffineMaskedCouplingInferenceFunnel(n_keep, decoder, conditioner)[source]#

A masked coupling inference funnel that uses an affine transformation.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (Callable) – a conditioning neural network

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import AffineMaskedCouplingInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> layer = AffineMaskedCouplingInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=make_mlp([4, 4, 10 * 2]),
>>> )

class surjectors.RationalQuadraticSplineMaskedCouplingInferenceFunnel(*args, **kwargs)[source]#

A masked coupling inference funnel that uses a rational quadratic spline.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (Callable) – a conditioning neural network

  • range_min (float) – minimum range of the spline

  • range_max (float) – maximum range of the spline

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import (
>>>     RationalQuadraticSplineMaskedCouplingInferenceFunnel
>>> )
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> layer = RationalQuadraticSplineMaskedCouplingInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=make_mlp([4, 4, 10 * 2]),
>>>     range_min=-1.0,
>>>     range_max=1.0
>>> )

Autoregressive inference surjections#

class surjectors.MaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner, bijector_fn)[source]#

A masked autoregressive funnel layer.

The MaskedAutoregressiveInferenceFunnel is an autoregressive funnel, i.e., a dimensionality-reducing transformation, that uses a masking mechanism as in MaskedAutoregressive. In contrast to AffineMaskedAutoregressiveInferenceFunnel and RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel, its inner bijector needs to be specified explicitly.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (MADE) – a MADE neural network

  • bijector_fn (Callable) – an inner bijector function to be used

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MaskedAutoregressiveInferenceFunnel
>>> from surjectors.nn import MADE, make_mlp
>>> from surjectors.util import unstack
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> def bijector_fn(params):
>>>     shift, log_scale = unstack(params, axis=-1)
>>>     return distrax.ScalarAffine(shift, jnp.exp(log_scale))
>>>
>>> layer = MaskedAutoregressiveInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=MADE(10, [8, 8], 2),
>>>     bijector_fn=bijector_fn
>>> )

class surjectors.AffineMaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner)[source]#

A masked affine autoregressive funnel layer.

The AffineMaskedAutoregressiveInferenceFunnel is an autoregressive funnel, i.e., a dimensionality-reducing transformation, that maps data to the latent space with an affine transformation, using a masking mechanism as in MaskedAutoregressive.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (MADE) – a MADE neural network

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import AffineMaskedAutoregressiveInferenceFunnel
>>> from surjectors.nn import MADE, make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> layer = AffineMaskedAutoregressiveInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=MADE(10, [8, 8], 2),
>>> )

class surjectors.RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel(n_keep, decoder, conditioner, range_min, range_max)[source]#

A masked autoregressive inference funnel that uses RQ-NSFs.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

  • conditioner (MADE) – a conditioning neural network

  • range_min (float) – minimum range of the spline

  • range_max (float) – maximum range of the spline

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import (
>>>     RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel
>>> )
>>> from surjectors.nn import MADE, make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> layer = RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel(
>>>     n_keep=10,
>>>     decoder=decoder_fn(10),
>>>     conditioner=MADE(10, [8, 8], 2),
>>>     range_min=-1.0,
>>>     range_max=1.0
>>> )

Other inference surjections#

class surjectors.LULinear(n_dimension, with_bias=False, dtype=<class 'jax.numpy.float32'>)[source]#

A bijection based on an LU decomposition.

Parameters:
  • n_dimension – number of dimensions of the input

  • with_bias – use a bias term or not

  • dtype – parameter dtype

Examples

>>> from surjectors import LULinear
>>> layer = LULinear(10)
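
As background: an LU-parameterized linear layer represents its weight matrix as the product of a lower- and an upper-triangular factor, which makes the log-Jacobian determinant cheap to evaluate. A small illustrative sketch (not LULinear's internals), assuming a unit-lower-triangular factor:

>>> from jax import numpy as jnp
>>> # with W = L @ U and diag(L) = 1, log|det W| = sum(log|diag(U)|),
>>> # an O(n) computation instead of O(n^3)
>>> U = jnp.diag(jnp.arange(1.0, 11.0)) + jnp.triu(jnp.ones((10, 10)), 1)
>>> logdet = jnp.sum(jnp.log(jnp.abs(jnp.diag(U))))
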
class surjectors.MLPInferenceFunnel(n_keep, decoder)[source]#

A multilayer perceptron inference funnel.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a conditional probability function

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import MLPInferenceFunnel
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> decoder = decoder_fn(5)
>>> a = MLPInferenceFunnel(10, decoder)

class surjectors.Slice(n_keep, decoder)[source]#

A slice funnel.

Parameters:
  • n_keep (int) – number of dimensions to keep

  • decoder (Callable) – a callable that returns a conditional probability distribution when called

Examples

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
>>>     def _fn(z):
>>>         params = make_mlp([4, 4, n_dim * 2])(z)
>>>         mu, log_scale = jnp.split(params, 2, -1)
>>>         return distrax.Independent(
>>>             distrax.Normal(mu, jnp.exp(log_scale))
>>>         )
>>>     return _fn
>>>
>>> layer = Slice(10, decoder_fn(10))
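
Conceptually, the slice surjection's inference direction just splits its input: the first n_keep dimensions are kept as the latent representation, and the decoder supplies the conditional density of the dropped remainder. As a sketch (not the library's code), for a 15-dimensional input:

>>> # z, x_removed = x[:10], x[10:]
>>> # density contribution: decoder(z).log_prob(x_removed)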
