Source code for surjectors._src.distributions.transformed_distribution
import chex
import distrax
import haiku as hk
from distrax import Distribution
from jax import Array
from tensorflow_probability.substrates.jax import distributions as tfd
from surjectors._src.surjectors.surjector import Surjector
[docs]
class TransformedDistribution:
"""Distribution of a random variable transformed by a function.
Can be used to define a pushforward measure.
Args:
base_distribution: a distribution object
transform: some transformation
Examples:
>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, Chain, TransformedDistribution
>>>
>>> a = Slice(10)
>>> b = Slice(5)
>>> ab = Chain([a, b])
>>>
>>> TransformedDistribution(
>>> distrax.Normal(jnp.zeros(5), jnp.ones(5)),
>>> Chain([a, b])
>>> )
"""
def __init__(
self,
base_distribution: Distribution | tfd.Distribution,
transform: Surjector,
):
self.base_distribution = base_distribution
self.transform = transform
def __call__(self, method, **kwargs):
"""Call the TransformedDistribution object.
Depending on "method", computes log-probability of an event or samples
from the distribution.
Args:
method: either "sample" or "log_prob"
**kwargs: several keyword arguments that are dispatched to
whatever method is called.
Returns:
returns whatever 'method' returns
"""
return getattr(self, method)(**kwargs)
[docs]
def log_prob(self, y: Array, x: Array = None):
"""Calculate the log probability of an event conditional on another.
Args:
y: event for which the log probability is computed
x: optional event that is used to condition
Returns:
array of floats of log probabilities
"""
_, lp = self.inverse_and_log_prob(y, x)
return lp
[docs]
def inverse_and_log_prob(self, y: Array, x: Array = None):
"""Compute the inverse transformation and its log probability.
Args:
y: event for which the inverse and log probability is computed
x: optional event that is used to condition
Returns:
tuple of two arrays of floats. The first one is the inverse
transformation, the second one is the log probability
"""
if x is not None:
chex.assert_equal_rank([y, x])
chex.assert_axis_dimension(y, 0, x.shape[0])
if isinstance(self.transform, distrax.Bijector):
z, lc = self.transform.inverse_and_log_det(y)
else:
z, lc = self.transform.inverse_and_likelihood_contribution(y, x=x)
lp_z = self.base_distribution.log_prob(z)
lp = lp_z + lc
return z, lp
[docs]
def sample(self, sample_shape=(), x: Array = None):
"""Sample an event.
Args:
sample_shape: the size of the sample to be drawn
x: optional event that is used to condition the samples.
If x is given sample_shape is ignored
Returns:
a sample from the transformed distribution
"""
y, _ = self.sample_and_log_prob(sample_shape, x)
return y
[docs]
def sample_and_log_prob(self, sample_shape=(), x: Array = None):
"""Sample an event and compute its log probability.
Args:
sample_shape: the size of the sample to be drawn
x: optional event that is used to condition the samples.
If x is given sample_shape is ignored
Returns:
tuple of two arrays of floats. The first one is the drawn sample
transformation, the second one is its log probability
"""
if x is not None and len(sample_shape) == 0:
sample_shape = (x.shape[0],)
if x is not None:
chex.assert_equal(sample_shape[0], x.shape[0])
try:
z, lp_z = self.base_distribution.sample_and_log_prob(
seed=hk.next_rng_key(),
sample_shape=sample_shape,
)
except AttributeError:
z, lp_z = self.base_distribution.experimental_sample_and_log_prob(
seed=hk.next_rng_key(),
sample_shape=sample_shape,
)
y, fldj = self.transform.forward_and_likelihood_contribution(z, x=x)
lp = lp_z - fldj
return y, lp