surjectors.util


surjectors.util contains general utility functions.

as_batch_iterator(rng_key, data, batch_size)

Create a batch iterator for a data set.

make_alternating_binary_mask(n_dim[, ...])

Create a binary masked array.

unstack(x[, axis])

Unstack a tensor.

surjectors.util.as_batch_iterator(rng_key, data, batch_size, shuffle=True)

Create a batch iterator for a data set.

Parameters:
  • rng_key (PRNGKey) – a JAX random key

  • data (named_dataset) – the data set to iterate over. It must be a NamedTuple with at least one field named y. If a conditional flow is to be trained, its second field must be named x.

  • batch_size (int) – size of each batch of data that is returned by the iterator

  • shuffle – if True, the data are shuffled before batches are created

Examples

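>>> from collections import namedtuple
>>> from jax import numpy as jnp, random as jr
>>> from surjectors.util import as_batch_iterator
>>>
>>> # unconditional flow: the data set only has a field y
>>> y = jr.normal(jr.PRNGKey(0), (1000, 2))
>>> itr = as_batch_iterator(jr.PRNGKey(1), namedtuple("data", "y")(y), 100)
>>>
>>> # conditional flow: the second field must be called x
>>> x = jr.normal(jr.PRNGKey(1), (1000, 2))
>>> itr = as_batch_iterator(
...     jr.PRNGKey(1), namedtuple("data", "y x")(y, x), 100
... )

Returns: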

a data loader object
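The interface of the returned loader is not documented on this page, so the following is only a minimal sketch of the batching that the parameters above describe, assuming a plain Python generator. The name iterate_batches is hypothetical and is not part of surjectors; it only illustrates how every field of the NamedTuple (y and, for conditional flows, x) is indexed with the same, optionally shuffled, rows.

```python
from collections import namedtuple
from typing import Iterator, NamedTuple

import jax.numpy as jnp
import jax.random as jr


def iterate_batches(
    rng_key, data: NamedTuple, batch_size: int, shuffle: bool = True
) -> Iterator[NamedTuple]:
    """Hypothetical helper (not the surjectors API): yield batches of `data`."""
    n = data.y.shape[0]  # all fields are assumed to share the leading axis
    # a random permutation of 0..n-1 if shuffling, else the original order
    idxs = jr.permutation(rng_key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        batch_idxs = idxs[start:start + batch_size]
        # select the same rows from every field so y and x stay aligned
        yield type(data)(*(field[batch_idxs] for field in data))


Data = namedtuple("Data", "y x")
key_y, key_x, key_perm = jr.split(jr.PRNGKey(0), 3)
data = Data(jr.normal(key_y, (1000, 2)), jr.normal(key_x, (1000, 2)))
batches = list(iterate_batches(key_perm, data, 100))
print(len(batches), batches[0].y.shape)  # 10 (100, 2)
```

This sketch keeps a final partial batch when the number of observations is not a multiple of batch_size; whether the library drops or keeps such a remainder is not stated here.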


surjectors.util.make_alternating_binary_mask(n_dim, even_idx_as_true=False)

Create a binary masked array.

Parameters:
  • n_dim (int) – length of the masked array to be created

  • even_idx_as_true (bool) – selects which indices are set to True. If even_idx_as_true=True, all even indices [0, 2, 4, …] are True; otherwise all odd indices [1, 3, 5, …] are True

Returns:

a boolean array in which every even or every odd index is True
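A minimal sketch of the documented behaviour, assuming only what the parameter description states; alternating_mask is a hypothetical re-implementation, not the library's source.

```python
import jax.numpy as jnp


def alternating_mask(n_dim: int, even_idx_as_true: bool = False) -> jnp.ndarray:
    """Hypothetical sketch: boolean mask that alternates along n_dim entries."""
    idx = jnp.arange(n_dim)
    # True at even indices if even_idx_as_true, otherwise True at odd indices
    return (idx % 2 == 0) if even_idx_as_true else (idx % 2 == 1)


print(alternating_mask(4, even_idx_as_true=True).tolist())  # [True, False, True, False]
print(alternating_mask(4).tolist())  # [False, True, False, True]
```

In coupling-style layers such a mask typically selects the dimensions that pass through unchanged, while its negation ~mask selects the complementary dimensions that are transformed.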

surjectors.util.unstack(x, axis=0)

Unstack a tensor.

Splits x along axis into separate arrays, mirroring the behaviour of tf.unstack.

Parameters:
  • x – array to unstack

  • axis – integer index of the axis along which x is unstacked (default: 0)

Returns:

unstacked array
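To illustrate the tf.unstack semantics referred to above, here is a minimal sketch; unstack_sketch is a hypothetical name, and returning a Python list is an assumption about the output container rather than a statement about the library's return type.

```python
import jax.numpy as jnp


def unstack_sketch(x, axis=0):
    """Hypothetical sketch: slice x along `axis`, dropping that axis from each slice."""
    # jnp.take with a scalar index removes the indexed axis, like tf.unstack
    return [jnp.take(x, i, axis=axis) for i in range(x.shape[axis])]


x = jnp.arange(6).reshape(2, 3)
parts = unstack_sketch(x, axis=1)
print([p.tolist() for p in parts])  # [[0, 3], [1, 4], [2, 5]]
```

Stacking the pieces back along the same axis, jnp.stack(parts, axis=1), recovers the original array, which is the sense in which unstacking inverts stacking.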