Welcome to Surjectors!
Surjection layers for density estimation with normalizing flows
Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors builds on Distrax and Haiku and is fully compatible with both of them.
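To make the idea of a dimensionality-reducing (inference) surjection concrete, here is a minimal sketch in plain JAX of the likelihood behind a slice-style layer. It is our own illustration under stated assumptions, not Surjectors' source: the layer keeps the first `n_keep` coordinates of x as the latent z and scores the discarded coordinates with a conditional Gaussian decoder, so that log p(x) = log p(z) + log p(x_rest | z). The helpers `slice_log_prob` and `make_linear_decoder` are hypothetical, and the random linear decoder stands in for a trained network.

```python
import jax.numpy as jnp
from jax import random as jr
from jax.scipy.stats import norm


def slice_log_prob(x, n_keep, decoder):
    """Log-density of x under a slice surjection onto a standard normal base.

    Hypothetical helper: z = x[..., :n_keep] is kept as the latent, and the
    discarded coordinates x[..., n_keep:] are scored by a conditional
    Gaussian decoder p(x_rest | z).
    """
    z, x_rest = x[..., :n_keep], x[..., n_keep:]
    # Base distribution: independent standard normal over the kept coordinates.
    log_p_z = jnp.sum(norm.logpdf(z), axis=-1)
    # Decoder: mean and log-scale of the discarded coordinates given z.
    means, log_scales = decoder(z)
    log_p_rest = jnp.sum(
        norm.logpdf(x_rest, loc=means, scale=jnp.exp(log_scales)), axis=-1
    )
    return log_p_z + log_p_rest


def make_linear_decoder(key, n_keep, n_rest):
    """Hypothetical decoder: one random linear map instead of a trained MLP."""
    w = 0.1 * jr.normal(key, (n_keep, 2 * n_rest))

    def decoder(z):
        means, log_scales = jnp.split(z @ w, 2, axis=-1)
        return means, log_scales

    return decoder


x = jr.normal(jr.PRNGKey(1), (1, 10))
decoder = make_linear_decoder(jr.PRNGKey(2), n_keep=5, n_rest=5)
lp = slice_log_prob(x, n_keep=5, decoder=decoder)  # one log-density per row
```

A decoder that always returns zero means and zero log-scales reduces the result to a standard normal density over all ten coordinates, which makes a convenient sanity check.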
Surjectors makes use of
Haiku's module system for neural networks,
Distrax for probability distributions and some base bijectors,
Optax for gradient-based optimization,
JAX for autodiff and XLA computation.
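Because the log-densities in such a model are ordinary JAX functions, fitting them is gradient-based maximum likelihood. The sketch below is our own illustration, not Surjectors code: it fits a one-dimensional affine flow x = mu + exp(log_sigma) * z with z ~ N(0, 1) by plain gradient descent on the negative log-likelihood using `jax.grad`. The hand-written update is a stand-in for an Optax optimizer, and the learning rate and number of steps are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp
from jax import random as jr
from jax.scipy.stats import norm

LEARNING_RATE = 0.1  # assumed step size for this toy problem


def log_prob(params, x):
    # Affine flow x = mu + exp(log_sigma) * z with base z ~ N(0, 1):
    # log p(x) = log N(z) - log_sigma, where z = (x - mu) / exp(log_sigma).
    mu, log_sigma = params
    z = (x - mu) / jnp.exp(log_sigma)
    return norm.logpdf(z) - log_sigma


def loss(params, x):
    # Negative mean log-likelihood of the data.
    return -jnp.mean(log_prob(params, x))


@jax.jit
def step(params, x):
    # One step of plain gradient descent (a stand-in for an Optax optimizer).
    grads = jax.grad(loss)(params, x)
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))


data = 3.0 + 2.0 * jr.normal(jr.PRNGKey(0), (1000,))
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = step(params, data)
mu, log_sigma = params
```

For this flow the maximum-likelihood solution is the sample mean and (population) standard deviation of the data, so the fitted `mu` and `exp(log_sigma)` can be checked against `jnp.mean(data)` and `jnp.std(data)`.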
Example
You can, for instance, construct a simple normalizing flow like this:
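import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp


def decoder_fn(n_dim):
    # Conditional Gaussian decoder used by the Slice surjection:
    # an MLP maps the kept coordinates z to means and log-scales.
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))

    return _fn


@hk.without_apply_rng
@hk.transform
def flow(x):
    # Base distribution: standard normal over the 5-dimensional latent space.
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # Transformation: a Slice surjection between the 10-dimensional data
    # and the 5-dimensional latent space, followed by an LULinear bijection.
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    # Transformed distribution: combines the base distribution and the transformation.
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)


# A batch with one 10-dimensional observation; init creates the parameters,
# apply evaluates the log-density.
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)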
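The flow is constructed using three objects: a base distribution, a transformation (here a Chain of a Slice surjection and an LULinear bijection), and a transformed distribution that combines the two and evaluates log-densities.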
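The same three-object structure can be mirrored in a few lines of plain JAX. In this sketch, which is our own illustration and not Surjectors' implementation, the base distribution is a standard normal, the transformation is an invertible linear map x = W z (a stand-in for a bijection such as LULinear), and the transformed distribution's log-density follows the change-of-variables formula log p(x) = log p_base(W^{-1} x) - log|det W|. The function names are hypothetical.

```python
import jax.numpy as jnp
from jax import random as jr
from jax.scipy.stats import multivariate_normal, norm


def base_log_prob(z):
    # Base distribution: independent standard normal over all coordinates.
    return jnp.sum(norm.logpdf(z), axis=-1)


def linear_inverse_and_log_det(w, x):
    # Transformation: x = W z, so z = W^{-1} x and log|det dz/dx| = -log|det W|.
    z = jnp.linalg.solve(w, x.T).T
    _, log_abs_det = jnp.linalg.slogdet(w)
    return z, -log_abs_det


def transformed_log_prob(w, x):
    # Transformed distribution: base log-density plus log-det of the inverse.
    z, log_det = linear_inverse_and_log_det(w, x)
    return base_log_prob(z) + log_det


n_dim = 5
w = jnp.eye(n_dim) + 0.1 * jr.normal(jr.PRNGKey(0), (n_dim, n_dim))
x = jr.normal(jr.PRNGKey(1), (3, n_dim))
lp = transformed_log_prob(w, x)  # one log-density per row

# Sanity check: x = W z with z ~ N(0, I) is Gaussian with covariance W W^T.
expected = multivariate_normal.logpdf(x, jnp.zeros(n_dim), w @ w.T)
```

Keeping the inverse and its log-determinant together in one function reflects the usual design of flow layers, since density evaluation always needs both.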
Installation
To install from PyPI, call:
pip install surjectors
To install the latest GitHub release, run the following on the command line, replacing <RELEASE> with the release tag:
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
See also the installation instructions for JAX, if you plan to use Surjectors on GPU/TPU.
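After installing a GPU- or TPU-enabled JAX build, a quick way to confirm which devices JAX (and hence any model built with it) will run on is to query the backend. This is a generic JAX check rather than a Surjectors feature; on a CPU-only installation it simply reports the CPU.

```python
import jax

# Name of the default backend JAX runs computations on, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())
# All devices visible to that backend.
print(jax.devices())
```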
Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:
1. Clone surjectors and install uv,
2. install all dependencies using uv sync,
3. create a new branch locally with git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
4. implement your contribution and ideally a test case,
5. test it by calling make format, make lints and make tests on the (Unix) command line,
6. submit a PR.
Citing Surjectors
@article{dirmeier2024surjectors,
author = {Simon Dirmeier},
title = {Surjectors: surjection layers for density estimation with normalizing flows},
year = {2024},
journal = {Journal of Open Source Software},
publisher = {The Open Journal},
volume = {9},
number = {94},
pages = {6188},
doi = {10.21105/joss.06188}
}
License
Surjectors is licensed under the Apache 2.0 License.