surjectors.util
surjectors.util contains general utility functions.
- as_batch_iterator: Create a batch iterator for a data set.
- make_alternating_binary_mask: Create a binary masked array.
- Unstack a tensor.
- surjectors.util.as_batch_iterator(rng_key, data, batch_size, shuffle=True)
Create a batch iterator for a data set.
- Parameters:
rng_key (PRNGKey) – a JAX random key
data (named_dataset) – the data set for which an iterator is created. It must be a NamedTuple with at least one field named y. To train a conditional flow, it must also have a second field named x.
batch_size (int) – size of each batch of data returned by the iterator
shuffle (bool) – if True, shuffles the data before creating batches
Examples
>>> from collections import namedtuple
>>> from jax import numpy as jnp, random as jr
>>> from surjectors.util import as_batch_iterator
>>>
>>> y = jr.normal(jr.PRNGKey(0), (1000, 2))
>>> as_batch_iterator(jr.PRNGKey(1), namedtuple("data", "y")(y), 100)
>>>
>>> x = jr.normal(jr.PRNGKey(1), (1000, 2))
>>> as_batch_iterator(
...     jr.PRNGKey(1), namedtuple("data", "y x")(y, x), 100
... )
- Returns:
a data loader object
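The batching behaviour described above (optionally shuffle once, then slice every field of the NamedTuple with the same row indices) can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the surjectors implementation: simple_batch_iterator is a hypothetical helper that returns a plain Python generator rather than the library's data loader object, and it assumes every field of the data set shares the same leading (sample) dimension.

```python
from collections import namedtuple

from jax import numpy as jnp, random as jr


def simple_batch_iterator(rng_key, data, batch_size, shuffle=True):
    """Hypothetical sketch: yield NamedTuple batches of `data`.

    Not the surjectors implementation; assumes all fields of `data`
    have the same number of rows along axis 0.
    """
    n = data.y.shape[0]
    # Permute the row indices once up front, or keep the original order.
    idxs = jr.permutation(rng_key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        batch_idxs = idxs[start:start + batch_size]
        # Index every field with the same rows so that y and x stay aligned.
        yield data._make(field[batch_idxs] for field in data)


# Usage: a conditional data set with fields y and x.
y = jr.normal(jr.PRNGKey(0), (1000, 2))
x = jr.normal(jr.PRNGKey(1), (1000, 2))
data = namedtuple("data", "y x")(y, x)
batches = list(simple_batch_iterator(jr.PRNGKey(1), data, 100))
# 1000 rows / 100 per batch -> 10 batches, each field of shape (100, 2)
```

Because the permutation is computed once and applied to every field, each batch pairs y[i] with its own x[i]. This sketch keeps a final partial batch when the number of rows is not a multiple of batch_size; whether as_batch_iterator keeps or drops such a batch is not stated here.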
- surjectors.util.make_alternating_binary_mask(n_dim, even_idx_as_true=False)
Create a binary masked array.
- Parameters:
n_dim (int) – length of the masked array to be created
even_idx_as_true (bool) – a boolean indicating which indices are set to True. If even_idx_as_true=True, all even indices [0, 2, 4, …] are set to True; otherwise all odd indices [1, 3, 5, …] are.
- Returns:
a boolean mask array in which either every even or every odd index is True, depending on even_idx_as_true
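The documented semantics (even indices True when even_idx_as_true=True, odd indices True otherwise) can be reproduced with a parity test on jnp.arange. The sketch below is a hypothetical re-implementation named alternating_mask, written only from the description above; it is not taken from the surjectors source.

```python
from jax import numpy as jnp


def alternating_mask(n_dim, even_idx_as_true=False):
    # Hypothetical sketch of the documented behaviour, not surjectors' code.
    # True at even positions 0, 2, 4, ...
    is_even = jnp.arange(n_dim) % 2 == 0
    # Return the even pattern or its complement (odd positions True).
    return is_even if even_idx_as_true else ~is_even


even = alternating_mask(5, even_idx_as_true=True)
odd = alternating_mask(5, even_idx_as_true=False)
print(even.tolist())  # [True, False, True, False, True]
print(odd.tolist())   # [False, True, False, True, False]
```

The two settings produce complementary masks, which is why alternating masks of this kind are commonly used in masked coupling layers: one setting selects the dimensions that pass through unchanged and the other selects the dimensions that are transformed.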