surjectors.util


surjectors.util contains general utility functions.

as_batch_iterator(rng_key, data, batch_size)

Create a batch iterator for a data set.

make_alternating_binary_mask(n_dim[, ...])

Create a binary masked array.

unstack(x[, axis])

Unstack a tensor.

surjectors.util.as_batch_iterator(rng_key, data, batch_size, shuffle=True)

Create a batch iterator for a data set.

Return type:

_DataLoader

Parameters:
  • rng_key (PRNGKey) – a JAX random key

  • data (named_dataset) – the data set for which an iterator is created. It must be a NamedTuple with at least one element named y. If a conditional flow is to be trained, the second element must be named x.

  • batch_size (int) – size of each batch of data returned by the iterator

  • shuffle – if True, shuffles the data before creating batches

Examples

>>> from collections import namedtuple
>>> from jax import numpy as jnp, random as jr
>>> from surjectors.util import as_batch_iterator
>>>
>>> # unconditional data set: a single field named y
>>> y = jr.normal(jr.PRNGKey(0), (1000, 2))
>>> loader = as_batch_iterator(
...   jr.PRNGKey(1), namedtuple("data", "y")(y), 100
... )
>>> # conditional data set: y plus the conditioning variables x
>>> x = jr.normal(jr.PRNGKey(1), (1000, 2))
>>> loader = as_batch_iterator(
...   jr.PRNGKey(1), namedtuple("data", "y x")(y, x), 100
... )

Returns:

a data loader object
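The batching behavior described above can be sketched in plain NumPy. This is a hypothetical illustration, not the surjectors implementation: the name `simple_batch_iterator` is made up, it takes an integer seed for `numpy.random.default_rng` instead of a JAX `PRNGKey`, it yields dictionaries keyed by field name, and it keeps a smaller final batch when the number of samples is not a multiple of `batch_size`. The real `_DataLoader` may differ on each of these points.

```python
from collections import namedtuple
from typing import Iterator

import numpy as np


def simple_batch_iterator(
    seed: int, data: tuple, batch_size: int, shuffle: bool = True
) -> Iterator[dict]:
    """Yield mini-batches of every field of the namedtuple `data`.

    Hypothetical sketch, not the surjectors implementation. Assumes every
    field shares the same leading (sample) dimension as the field `y`.
    """
    n = np.asarray(data.y).shape[0]  # sample count from the required field y
    idxs = np.arange(n)
    if shuffle:
        # one permutation shared by all fields keeps rows of y and x aligned
        idxs = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        batch_idxs = idxs[start : start + batch_size]
        # the final batch is smaller if n is not a multiple of batch_size
        yield {
            name: np.asarray(getattr(data, name))[batch_idxs]
            for name in data._fields
        }


Data = namedtuple("Data", "y x")
rng = np.random.default_rng(0)
y = rng.normal(size=(1000, 2))
data = Data(y=y, x=2 * y)  # x is a function of y, so alignment is easy to check
batches = list(simple_batch_iterator(1, data, batch_size=100))
print(len(batches), batches[0]["y"].shape)  # 10 (100, 2)
```

Shuffling indices once and applying them to every field is what keeps each conditioning row x paired with its target row y.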


surjectors.util.make_alternating_binary_mask(n_dim, even_idx_as_true=False)

Create a binary masked array.

Parameters:
  • n_dim (int) – length of the masked array to be created

  • even_idx_as_true (bool) – a boolean indicating which indices are set to True. If even_idx_as_true=True, all even indices [0, 2, 4, …] are set to True; otherwise, all odd indices are

Returns:

boolean masked array in which every even or every odd index is True
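A NumPy sketch of the documented mask semantics follows. The name `alternating_mask` is hypothetical, and the code follows the description above (even indices True when even_idx_as_true=True, odd indices otherwise) rather than the library source, which may be written with jax.numpy.

```python
import numpy as np


def alternating_mask(n_dim: int, even_idx_as_true: bool = False) -> np.ndarray:
    """Hypothetical sketch of an alternating boolean mask of length n_dim."""
    mask = np.zeros(n_dim, dtype=bool)
    # even indices start at 0, odd indices start at 1
    start = 0 if even_idx_as_true else 1
    mask[start::2] = True
    return mask


print(alternating_mask(5, even_idx_as_true=True).tolist())  # [True, False, True, False, True]
print(alternating_mask(5).tolist())  # [False, True, False, True, False]
```

The two settings produce complementary masks. In coupling-based flows, a common pattern is to alternate the mask between consecutive layers so that every dimension is transformed by at least one layer.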

surjectors.util.unstack(x, axis=0)

Unstack a tensor.

Unstacks a tensor along the given axis, as tf.unstack does.

Parameters:
  • x – array to unstack

  • axis – the axis to unstack along, given as an integer index

Returns:

the unstacked arrays, one per index along axis
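For intuition, here is a NumPy sketch of tf.unstack-style behavior. The name `unstack_sketch` is hypothetical; it returns a Python list of arrays, as tf.unstack does, and whether surjectors returns exactly a list is an assumption. `jax.numpy.take` accepts the same `axis` argument, so the same pattern should carry over to JAX arrays.

```python
import numpy as np


def unstack_sketch(x: np.ndarray, axis: int = 0) -> list:
    """Hypothetical sketch: split x into its slices along `axis`."""
    # np.take with a scalar index removes the indexed axis from each slice
    return [np.take(x, i, axis=axis) for i in range(x.shape[axis])]


x = np.arange(6).reshape(2, 3)
rows = unstack_sketch(x)          # two arrays of shape (3,)
cols = unstack_sketch(x, axis=1)  # three arrays of shape (2,)
print([r.tolist() for r in rows])  # [[0, 1, 2], [3, 4, 5]]
print([c.tolist() for c in cols])  # [[0, 3], [1, 4], [2, 5]]
```

Stacking the slices back along the same axis (np.stack(cols, axis=1)) recovers the original array.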