Dimensionality reduction with surjections#
Surjective normalizing flows use dimensionality-reducing transformations instead of dimensionality-preserving bijective ones. Below we implement several surjective normalizing flows for a density estimation problem and compare them to a conventional bijective flow.
[1]:
import distrax
import haiku as hk
import jax
import numpy as np
import optax
import pandas as pd
from collections import namedtuple
from jax import jit
from jax import numpy as jnp
from jax import random as jr
from tqdm import tqdm
from surjectors import (
    Chain,
    LULinear,
    MaskedCoupling,
    MaskedCouplingInferenceFunnel,
    MLPInferenceFunnel,
    TransformedDistribution,
)
from surjectors.nn import make_mlp
from surjectors.util import (
    as_batch_iterator,
    make_alternating_binary_mask,
)
We first define a training function that we can reuse for all density estimation tasks below.
[2]:
def train(rng_key, data, model, n_iter=1000):
    # convert the data set to an iterator
    batch_key, rng_key = jr.split(rng_key)
    train_iter = as_batch_iterator(batch_key, data, 100, True)
    # initialize the model
    init_key, rng_key = jr.split(rng_key)
    params = model.init(init_key, method="log_prob", **train_iter(0))
    # create an optimizer
    optimizer = optax.adam(1e-4)
    state = optimizer.init(params)

    # gradient step
    @jit
    def step(params, state, **batch):
        def loss_fn(params):
            lp = model.apply(params, None, method="log_prob", **batch)
            return -jnp.mean(lp)

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return loss, new_params, new_state

    losses = np.zeros(n_iter)
    # training loop
    for i in tqdm(range(n_iter)):
        train_loss = 0.0
        # iterate over batches
        for j in range(train_iter.num_batches):
            batch = train_iter(j)
            batch_loss, params, state = step(params, state, **batch)
            train_loss += batch_loss
        losses[i] = train_loss
    return params, losses
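To make the batching pattern in `train` concrete, here is a minimal NumPy sketch of how a data set of 1000 rows splits into shuffled batches of 100. The helper `batch_indices` is a hypothetical stand-in written for illustration; it is not the implementation of `as_batch_iterator`.

```python
import numpy as np


def batch_indices(n_rows, batch_size, shuffle, seed=0):
    """Hypothetical stand-in for a batch iterator: returns a list of index arrays."""
    idx = np.arange(n_rows)
    if shuffle:
        # shuffle once so that batches are random subsets of the rows
        idx = np.random.default_rng(seed).permutation(n_rows)
    n_batches = int(np.ceil(n_rows / batch_size))
    return [idx[b * batch_size : (b + 1) * batch_size] for b in range(n_batches)]


batches = batch_indices(n_rows=1000, batch_size=100, shuffle=True)
n_batches = len(batches)  # number of gradient steps per pass over the data
```

Note that `losses[i]` in `train` accumulates the sum of the batch losses of one pass over the data, so dividing it by the number of batches gives the average batch loss.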
Data#
We simulate data from a factor model for testing. The data should be easy to embed in a lower-dimensional space via a linear transformation.
[3]:
rng_key_seq = hk.PRNGSequence(0)
[4]:
n_train, n_test = 1000, 200
n = n_train + n_test
p_data, p_latent = 20, 5
z = jr.normal(next(rng_key_seq), (n, p_latent))
W = jr.normal(next(rng_key_seq), (p_data, p_latent)) * 0.1
y = (W @ z.T).T + jr.normal(next(rng_key_seq), (n, p_data)) * 0.1
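The low-dimensional structure of this factor model can be checked on its covariance matrix, which is \(W W^T + \sigma^2 I\) for noise standard deviation \(\sigma = 0.1\). The following is a minimal NumPy sketch under the assumption that we draw a fresh `W` with NumPy's generator, so the numbers differ from the JAX draw above; only the number of eigenvalues above the noise floor matters.

```python
import numpy as np

rng = np.random.default_rng(0)
p_data, p_latent, noise_sd = 20, 5, 0.1

# loading matrix drawn as in the simulation above (different random numbers)
W = rng.normal(size=(p_data, p_latent)) * 0.1

# covariance of y = W z + noise with z ~ N(0, I) and noise ~ N(0, noise_sd^2 I)
cov = W @ W.T + noise_sd**2 * np.eye(p_data)

# eigenvalues in descending order
eigvals = np.sort(np.linalg.eigvalsh(cov))[::-1]

# directions whose variance exceeds the noise floor
n_signal = int(np.sum(eigvals > noise_sd**2 + 1e-8))
```

Only `p_latent` eigenvalues exceed the noise variance; the remaining directions carry noise alone, which is why a dimensionality-reducing flow is a natural fit for this data.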
A bijective baseline#
We start with a simple baseline: a masked coupling flow with rational quadratic splines as transforms. The baseline is not dimensionality-reducing and hence estimates the density on the full \(20\)-dimensional space.
An RQ spline flow requires defining the range on which the splines act, for which we use the lower and upper bounds of the data.
[5]:
range_min, range_max = float(np.min(y)), float(np.max(y))
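As a variation, one could derive the range from the training rows only and pad it by a small margin, so that held-out points are less likely to fall outside the interval. This is a sketch of that alternative, not what the notebook does; `spline_range` and `margin` are hypothetical names, and the toy array is made up for illustration.

```python
import numpy as np


def spline_range(train_data, margin=0.1):
    """Hypothetical helper: padded (min, max) of the training data."""
    lo, hi = float(np.min(train_data)), float(np.max(train_data))
    pad = margin * (hi - lo)  # widen the interval by a fraction of its length
    return lo - pad, hi + pad


# toy training data, made up for illustration
y_train_toy = np.array([[-1.0, 0.5], [2.0, 0.0]])
range_lo, range_hi = spline_range(y_train_toy, margin=0.1)
```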
Next we define the conditioner function.
[6]:
def make_rq_conditioner(event_shape, hidden_sizes, n_bins):
    n_params = 3 * n_bins + 1
    return hk.Sequential(
        [
            make_mlp(hidden_sizes + [event_shape * n_params]),
            hk.Reshape((event_shape,) + (n_params,), preserve_dims=-1),
        ]
    )
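The count `3 * n_bins + 1` comes from the spline parameterization: each dimension needs `n_bins` bin widths, `n_bins` bin heights and `n_bins + 1` knot slopes. The NumPy sketch below mimics the reshape step on a dummy array of zeros; the ordering of widths, heights and slopes along the last axis is an assumption based on our reading of the distrax documentation.

```python
import numpy as np

n_bins, event_shape, batch_size = 4, 20, 100
n_params = 3 * n_bins + 1  # widths + heights + knot slopes

# dummy stand-in for the flat MLP output: one long vector per example
flat = np.zeros((batch_size, event_shape * n_params))

# same effect as the reshape layer: one parameter vector per data dimension
spline_params = flat.reshape(batch_size, event_shape, n_params)

# assumed layout along the last axis
widths = spline_params[..., :n_bins]
heights = spline_params[..., n_bins : 2 * n_bins]
slopes = spline_params[..., 2 * n_bins :]
```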
We create a baseline that uses five masked coupling layers.
[7]:
def make_baseline(n_dimensions):
    def flow(**kwargs):
        def bijector_fn(params):
            return distrax.RationalQuadraticSpline(
                params, range_min=range_min, range_max=range_max
            )

        layers = []
        for i in range(5):
            layer = MaskedCoupling(
                mask=make_alternating_binary_mask(n_dimensions, i % 2 == 0),
                conditioner=make_rq_conditioner(n_dimensions, [128, 128], 4),
                bijector_fn=bijector_fn,
            )
            layers.append(layer)
        transform = Chain(layers)
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dimensions), jnp.ones(n_dimensions)),
            reinterpreted_batch_ndims=1,
        )
        pushforward = TransformedDistribution(base_distribution, transform)
        return pushforward(**kwargs)

    td = hk.transform(flow)
    return td
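The argument `i % 2 == 0` flips the mask from layer to layer, so dimensions that are left untouched in one coupling layer get transformed in the next. Here is a small NumPy sketch of that idea; `alternating_mask` is a hypothetical stand-in, and which parity maps to `True` in `make_alternating_binary_mask` is an assumption.

```python
import numpy as np


def alternating_mask(n_dim, even_idx_as_true):
    """Hypothetical stand-in: boolean mask that alternates along the dimensions."""
    mask = np.arange(n_dim) % 2 == 0
    return mask if even_idx_as_true else ~mask


# one mask per layer of a five-layer flow on six dimensions
masks = [alternating_mask(6, i % 2 == 0) for i in range(5)]
```

Consecutive masks are complements of each other, so after two layers every dimension has been transformed once.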
Training of the baseline is done as follows:
[8]:
baseline = make_baseline(p_data)
data = namedtuple("named_dataset", "y")(y[:n_train])
params_baseline, _ = train(next(rng_key_seq), data, baseline)
100%|██████████| 1000/1000 [01:42<00:00, 9.72it/s]
A surjective MLP funnel#
As a first surjective flow, we implement an MLPInferenceFunnel. The surjection uses an LU decomposition as inner bijector and a conditional probability density parameterized by an MLP as a decoder. We again use a flow of five layers. The first two and the last two are dimensionality-preserving LULinear bijections. The layer in the middle is a dimensionality-reducing funnel.
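To illustrate how such a funnel evaluates a density, the NumPy sketch below splits a data point into a kept part and a dropped part, maps the kept part to the latent space with a linear bijection, and scores the dropped part with a decoder conditioned on the latent. It is a simplified illustration and not the library's implementation: the plain matrix `A` stands in for the LU-parameterized bijector, and a linear decoder with unit scale (`decoder_weights`) stands in for the MLP.

```python
import numpy as np


def std_normal_logpdf(x):
    """Log-density of independent standard normals, summed over the last axis."""
    return np.sum(-0.5 * x**2 - 0.5 * np.log(2.0 * np.pi), axis=-1)


def normal_logpdf(x, mean, scale):
    """Log-density of independent normals, summed over the last axis."""
    standardized = (x - mean) / scale
    return np.sum(
        -0.5 * standardized**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi),
        axis=-1,
    )


def funnel_log_prob(y, n_keep, A, decoder_weights):
    """Toy inference funnel: log p(y) = log p(z) + log|det A| + log p(y_minus | z)."""
    y_plus, y_minus = y[..., :n_keep], y[..., n_keep:]
    z = y_plus @ A.T  # inner linear bijection, data -> latent
    log_det = np.linalg.slogdet(A)[1]  # log |det A|
    mean = z @ decoder_weights  # toy linear decoder with unit scale
    return std_normal_logpdf(z) + log_det + normal_logpdf(y_minus, mean, 1.0)


rng = np.random.default_rng(0)
y_toy = rng.normal(size=(3, 4))
lp = funnel_log_prob(y_toy, n_keep=2, A=np.eye(2), decoder_weights=np.zeros((2, 2)))
```

With an identity bijection and a decoder that ignores the latent, the funnel reduces to a standard normal density on all four dimensions, which is a convenient sanity check.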
[9]:
def make_surjective_mlp_funnel(n_dimensions):
    def flow(**kwargs):
        def decoder_fn(n_dim):
            def fn(z):
                params = make_mlp([32, 32, n_dim * 2])(z)
                mu, log_scale = jnp.split(params, 2, -1)
                return distrax.Independent(
                    distrax.Normal(mu, jnp.exp(log_scale))
                )

            return fn

        n_dim = n_dimensions
        layers = []
        for i in range(5):
            if i == 2:
                layer = MLPInferenceFunnel(
                    n_keep=int(n_dim / 2), decoder=decoder_fn(int(n_dim / 2))
                )
                n_dim = int(n_dim / 2)
            else:
                layer = LULinear(n_dim)
            layers.append(layer)
        transform = Chain(layers)
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        pushforward = TransformedDistribution(base_distribution, transform)
        return pushforward(**kwargs)

    td = hk.transform(flow)
    return td
[10]:
surjective_mlp_funnel = make_surjective_mlp_funnel(p_data)
data = namedtuple("named_dataset", "y")(y[:n_train])
params_surjective_mlp_funnel, _ = train(
    next(rng_key_seq), data, surjective_mlp_funnel
)
100%|██████████| 1000/1000 [00:05<00:00, 167.40it/s]
A surjective affine masked coupling flow#
As a second surjector, we implement a MaskedCouplingInferenceFunnel with affine transformations. The surjection uses an affine masked coupling layer as inner bijector and a conditional probability density parameterized by an MLP as a decoder. We use the surjection in the middle of five flow layers. The other four are conventional masked coupling flows.
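The mechanics of a single affine masked coupling step can be sketched in NumPy as follows. Masked dimensions pass through unchanged and parameterize the shift and scale applied to the remaining ones, which makes the inverse and the log-determinant cheap to compute. The `conditioner` here is a toy deterministic function standing in for the MLP, and all names are chosen for illustration only.

```python
import numpy as np


def conditioner(x_masked):
    """Toy stand-in for the MLP conditioner: returns (shift, log_scale) per dimension."""
    neighbour = np.roll(x_masked, 1, axis=-1)
    return 0.5 * neighbour, 0.1 * neighbour


def forward(x, mask):
    """Transform the unmasked dimensions, conditioned on the masked ones."""
    shift, log_scale = conditioner(np.where(mask, x, 0.0))
    y = np.where(mask, x, x * np.exp(log_scale) + shift)
    log_det = np.sum(np.where(mask, 0.0, log_scale), axis=-1)
    return y, log_det


def inverse(y, mask):
    """Undo the transformation; the masked dimensions give the same parameters."""
    shift, log_scale = conditioner(np.where(mask, y, 0.0))
    x = np.where(mask, y, (y - shift) * np.exp(-log_scale))
    log_det = -np.sum(np.where(mask, 0.0, log_scale), axis=-1)
    return x, log_det


mask = np.arange(4) % 2 == 0
x = np.array([[0.3, -1.2, 0.8, 2.0]])
y_out, fwd_log_det = forward(x, mask)
x_rec, inv_log_det = inverse(y_out, mask)
```

Because the conditioner only ever sees the masked dimensions, which are identical before and after the transformation, the inverse recovers the input exactly regardless of how complex the conditioner is.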
[11]:
def make_surjective_affine_masked_coupling(n_dimensions):
    def flow(**kwargs):
        def bijector_fn(params):
            means, log_scales = jnp.split(params, 2, -1)
            return distrax.ScalarAffine(means, jnp.exp(log_scales))

        def decoder_fn(n_dim):
            def fn(z):
                params = make_mlp([32, 32, n_dim * 2])(z)
                mu, log_scale = jnp.split(params, 2, -1)
                return distrax.Independent(
                    distrax.Normal(mu, jnp.exp(log_scale))
                )

            return fn

        n_dim = n_dimensions
        layers = []
        for i in range(5):
            if i == 2:
                layer = MaskedCouplingInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=decoder_fn(int(n_dim / 2)),
                    conditioner=make_mlp([128, 128, 2 * n_dim]),
                    bijector_fn=bijector_fn,
                )
                n_dim = int(n_dim / 2)
            else:
                layer = MaskedCoupling(
                    mask=make_alternating_binary_mask(n_dim, i % 2 == 0),
                    conditioner=make_mlp([128, 128, 2 * n_dim]),
                    bijector_fn=bijector_fn,
                )
            layers.append(layer)
        transform = Chain(layers)
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        pushforward = TransformedDistribution(base_distribution, transform)
        return pushforward(**kwargs)

    td = hk.transform(flow)
    return td
[12]:
surjective_affine_masked_coupling = make_surjective_affine_masked_coupling(
    p_data
)
data = namedtuple("named_dataset", "y")(y[:n_train])
params_surjective_affine_masked_coupling, _ = train(
    next(rng_key_seq), data, surjective_affine_masked_coupling
)
100%|██████████| 1000/1000 [00:26<00:00, 37.88it/s]
A surjective rational quadratic masked coupling flow#
Finally, we implement a MaskedCouplingInferenceFunnel with rational quadratic transformations. The flow is the same as before, but with the affine transformations replaced by rational quadratic splines.
[13]:
def make_surjective_rq_masked_coupling(n_dimensions):
    def flow(**kwargs):
        def bijector_fn(params):
            return distrax.RationalQuadraticSpline(
                params, range_min=range_min, range_max=range_max
            )

        def decoder_fn(n_dim):
            def fn(z):
                params = make_mlp([32, 32, n_dim * 2])(z)
                mu, log_scale = jnp.split(params, 2, -1)
                return distrax.Independent(
                    distrax.Normal(mu, jnp.exp(log_scale))
                )

            return fn

        n_dim = n_dimensions
        layers = []
        for i in range(5):
            if i == 2:
                layer = MaskedCouplingInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=decoder_fn(int(n_dim / 2)),
                    conditioner=make_rq_conditioner(n_dim, [128, 128], 4),
                    bijector_fn=bijector_fn,
                )
                n_dim = int(n_dim / 2)
            else:
                layer = MaskedCoupling(
                    mask=make_alternating_binary_mask(n_dim, i % 2 == 0),
                    conditioner=make_rq_conditioner(n_dim, [128, 128], 4),
                    bijector_fn=bijector_fn,
                )
            layers.append(layer)
        transform = Chain(layers)
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        pushforward = TransformedDistribution(base_distribution, transform)
        return pushforward(**kwargs)

    td = hk.transform(flow)
    return td
[14]:
surjective_rq_masked_coupling = make_surjective_rq_masked_coupling(p_data)
data = namedtuple("named_dataset", "y")(y[:n_train])
params_surjective_rq_masked_coupling, _ = train(
    next(rng_key_seq), data, surjective_rq_masked_coupling
)
100%|██████████| 1000/1000 [01:31<00:00, 10.88it/s]
Density comparisons#
Having trained the baseline and surjectors, let's compute density estimates of the training and test data sets using the four models.
[15]:
model_list = [
    baseline,
    surjective_mlp_funnel,
    surjective_affine_masked_coupling,
    surjective_rq_masked_coupling,
]
param_list = [
    params_baseline,
    params_surjective_mlp_funnel,
    params_surjective_affine_masked_coupling,
    params_surjective_rq_masked_coupling,
]
lps = []
for model, params in zip(model_list, param_list):
    lp_training = model.apply(params, None, method="log_prob", y=y[:n_train])
    lp_test = model.apply(params, None, method="log_prob", y=y[n_train:])
    lp_training = jnp.mean(lp_training)
    lp_test = jnp.mean(lp_test)
    lps.append(np.array([lp_training, lp_test]))
Not surprisingly, the MLP funnel attains the highest average log-density on both the training and the test set. The baseline, which does not reduce dimensionality, performs worst.
[16]:
df = pd.DataFrame(lps, columns=["Training density", "Test density"])
df.insert(
    0,
    "Model",
    [
        "Baseline",
        "MLP funnel",
        "Affine masked coupling funnel",
        "RQ masked coupling funnel",
    ],
)
df
[16]:
| | Model | Training density | Test density |
|---|---|---|---|
| 0 | Baseline | 2.973894 | 1.514763 |
| 1 | MLP funnel | 10.492463 | 10.487739 |
| 2 | Affine masked coupling funnel | 10.481473 | 10.244188 |
| 3 | RQ masked coupling funnel | 6.814237 | 5.800568 |
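The gap between training and test values indicates how well each model generalizes. The short sketch below computes it from the numbers reported in the table; the variable names are chosen for illustration.

```python
# (training, test) average log-densities copied from the table above
results = {
    "Baseline": (2.973894, 1.514763),
    "MLP funnel": (10.492463, 10.487739),
    "Affine masked coupling funnel": (10.481473, 10.244188),
    "RQ masked coupling funnel": (6.814237, 5.800568),
}

# difference between training and test value per model
gaps = {name: train - test for name, (train, test) in results.items()}

# model with the highest value on the test set
best_test = max(results, key=lambda name: results[name][1])

for name, gap in gaps.items():
    print(f"{name}: gap = {gap:.6f}")
```

The MLP funnel has both the highest test value and the smallest gap, while the baseline and the RQ funnel lose more than one unit of average log-density between training and test data.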
Session info#
[17]:
import session_info
session_info.show(html=False)
-----
distrax 0.1.5
haiku 0.0.11
jax 0.4.23
jaxlib 0.4.23
numpy 1.26.3
optax 0.1.8
pandas 2.2.0
session_info 1.0.0
surjectors 0.3.0
tqdm 4.66.1
-----
IPython 8.21.0
jupyter_client 8.6.0
jupyter_core 5.7.1
jupyterlab 4.0.12
notebook 7.0.7
-----
Python 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:38:07) [Clang 16.0.6 ]
macOS-13.0.1-arm64-arm-64bit
-----
Session information updated at 2024-02-01 16:57
References#
[1] Klein, Samuel, et al. "Funnels: Exact maximum likelihood with dimensionality reduction". Workshop on Bayesian Deep Learning, Advances in Neural Information Processing Systems, 2021.