Source code for surjectors._src.bijectors.lu_linear

import haiku as hk
import jax.nn
import numpy as np
from jax import numpy as jnp

from surjectors._src.bijectors.bijector import Bijector


# pylint: disable=arguments-differ,too-many-instance-attributes
class LULinear(Bijector, hk.Module):
    """A bijection based on the LU decomposition.

    The transformation is parameterized by a lower-triangular matrix ``L``
    with unit diagonal and an upper-triangular matrix ``U`` with a strictly
    positive diagonal. The inverse pass computes ``z = y @ U @ L``.

    Args:
        n_dimension: dimensionality of the data, i.e., the size of the last
            axis of the inputs
        with_bias: use a bias term or not. Bias terms are not implemented;
            passing a truthy value raises a ``NotImplementedError``
        dtype: parameter dtype

    References:
        .. [1] Oliva, Junier, et al. "Transformation Autoregressive Networks".
            Proceedings of the 35th International Conference on
            Machine Learning, 2018.

    Examples:
        >>> import haiku as hk
        >>> from jax import numpy as jnp
        >>> from jax import random as jr
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        >>> from surjectors import LULinear, TransformedDistribution
        >>>
        >>> @hk.without_apply_rng
        ... @hk.transform
        ... def fn(inputs):
        ...     base_distribution = tfd.Independent(
        ...         tfd.Normal(jnp.zeros(5), jnp.ones(5)),
        ...         reinterpreted_batch_ndims=1,
        ...     )
        ...     td = TransformedDistribution(base_distribution, LULinear(5))
        ...     return td.log_prob(inputs)
        >>>
        >>> data = jr.normal(jr.PRNGKey(1), shape=(10, 5))
        >>> params = fn.init(jr.key(0), data)
        >>> lps = fn.apply(params, data)
    """

    def __init__(self, n_dimension, with_bias=False, dtype=jnp.float32):
        super().__init__()
        if with_bias:
            raise NotImplementedError()

        self.n_dimension = n_dimension
        self.with_bias = with_bias
        self.dtype = dtype

        # number of entries strictly below (or strictly above) the diagonal
        n_triangular_entries = ((n_dimension - 1) * n_dimension) // 2

        # static index sets used to scatter the flat parameter vectors into
        # square matrices
        self._lower_indices = np.tril_indices(n_dimension, k=-1)
        self._upper_indices = np.triu_indices(n_dimension, k=1)
        self._diag_indices = np.diag_indices(n_dimension)

        # off-diagonal entries start at zero, so initially L is the identity
        # and U is diagonal
        self._lower_entries = hk.get_parameter(
            "lower_entries",
            [n_triangular_entries],
            dtype=dtype,
            init=jnp.zeros,
        )
        self._upper_entries = hk.get_parameter(
            "upper_entries",
            [n_triangular_entries],
            dtype=dtype,
            init=jnp.zeros,
        )
        # unconstrained values; mapped to a positive diagonal in `_upper_diag`
        self._unconstrained_upper_diag_entries = hk.get_parameter(
            "diag_entries", [n_dimension], dtype=dtype, init=jnp.ones
        )

    def _to_lower_and_upper_matrices(self):
        """Assemble the triangular factors from the flat parameter vectors.

        Returns:
            a tuple ``(L, U)`` of ``(n_dimension, n_dimension)`` matrices,
            where ``L`` is lower-triangular with ones on the diagonal and
            ``U`` is upper-triangular with ``_upper_diag`` on the diagonal
        """
        L = jnp.zeros((self.n_dimension, self.n_dimension), dtype=self.dtype)
        L = L.at[self._lower_indices].set(self._lower_entries)
        L = L.at[self._diag_indices].set(1.0)

        U = jnp.zeros((self.n_dimension, self.n_dimension), dtype=self.dtype)
        U = U.at[self._upper_indices].set(self._upper_entries)
        U = U.at[self._diag_indices].set(self._upper_diag)
        return L, U

    @property
    def _upper_diag(self):
        """Positive diagonal of ``U``.

        Softplus maps the unconstrained parameters to positive values; the
        small additive constant keeps them bounded away from zero.
        """
        return jax.nn.softplus(self._unconstrained_upper_diag_entries) + 1e-4

    def _inverse_likelihood_contribution(self):
        """Compute the log-determinant of the inverse transformation.

        ``L`` has a unit diagonal and both factors are triangular, hence
        the determinant of ``U @ L`` is the product of the diagonal of ``U``.

        Returns:
            a scalar, the sum of the logarithms of the diagonal of ``U``
        """
        return jnp.sum(jnp.log(self._upper_diag))

    def _inverse_and_likelihood_contribution(self, y, x=None, **kwargs):
        """Map data to the latent space and compute the log-determinant.

        Args:
            y: array whose last axis has size ``n_dimension``; all leading
                axes are treated as batch axes
            x: unused by this bijector
            **kwargs: unused by this bijector

        Returns:
            a tuple ``(z, lc)`` where ``z = y @ U @ L`` has the same shape
            as ``y`` and ``lc`` holds the log-determinant, one value per
            batch element
        """
        L, U = self._to_lower_and_upper_matrices()
        z = jnp.dot(jnp.dot(y, U), L)
        # The log-determinant does not depend on the input, so it is
        # replicated over all batch axes. Using `z.shape[:-1]` rather than
        # `z.shape[0]` keeps this correct for unbatched inputs (scalar
        # result) and for inputs with more than one batch axis.
        lc = self._inverse_likelihood_contribution() * jnp.ones(z.shape[:-1])
        return z, lc

    def _forward_and_likelihood_contribution(self, z, x=None, **kwargs):
        """Forward transformation; not implemented for this bijector.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError()