# Source code for surjectors._src.bijectors.permutation
from jax import numpy as jnp
from surjectors._src.bijectors.bijector import Bijector
# pylint: disable=arguments-renamed
[docs]
class Permutation(Bijector):
    """Permute the dimensions of a vector.

    The forward pass reorders the last axis of its input according to
    ``permutation``. The inverse pass applies the inverse reordering. A
    permutation only reorders entries, so both directions report a
    likelihood contribution of zero.

    Args:
        permutation: a vector of integer indexes representing the order of
            the elements. Any array-like works: a JAX or NumPy array, or a
            list/tuple of ints.
        event_ndims_in: number of input event dimensions

    Examples:
        >>> from surjectors import Permutation
        >>> from jax import numpy as jnp
        >>>
        >>> order = jnp.arange(10)
        >>> perm = Permutation(order, 1)
    """

    def __init__(self, permutation, event_ndims_in: int):
        # Stored exactly as given. It is converted to a JAX array where it
        # is used, so plain Python sequences are accepted too.
        self.permutation = permutation
        self.event_ndims_in = event_ndims_in

    def _forward_and_likelihood_contribution(self, z, **kwargs):
        """Reorder the last axis of ``z`` by ``self.permutation``.

        Returns:
            A tuple ``(z[..., permutation], zeros)``. The zeros array has
            shape ``z.shape[:-1]``.
        """
        permutation = jnp.asarray(self.permutation)
        # NOTE(review): the contribution has shape z.shape[:-1], i.e. one
        # value per vector along the last axis. That matches
        # event_ndims_in == 1; confirm the intent for event_ndims_in > 1.
        return z[..., permutation], jnp.full(jnp.shape(z)[:-1], 0.0)

    def _inverse_and_likelihood_contribution(self, y, **kwargs):
        """Undo the forward reordering on the last axis of ``y``.

        Returns:
            A tuple ``(y[..., inverse_permutation], zeros)``. The zeros array
            has shape ``y.shape[:-1]``.
        """
        # asarray lets list/tuple permutations expose `.size` and be used
        # as a scatter index (the original failed here for Python lists).
        permutation = jnp.asarray(self.permutation)
        size = permutation.size
        # Scatter: inverse[permutation[i]] = i. Indexing with `inverse`
        # after `permutation` therefore restores the original order.
        permutation_inv = (
            jnp.zeros(size, dtype=jnp.result_type(int))
            .at[permutation]
            .set(jnp.arange(size))
        )
        return y[..., permutation_inv], jnp.full(jnp.shape(y)[:-1], 0.0)