surjectors.util
surjectors.util contains general utility functions.
as_batch_iterator | Create a batch iterator for a data set.
make_alternating_binary_mask | Create a binary masked array.
unstack | Unstack a tensor.
- surjectors.util.as_batch_iterator(rng_key, data, batch_size, shuffle=True)
Create a batch iterator for a data set.
- Parameters:
rng_key (PRNGKey) – a JAX random key
data (named_dataset) – the data set to create an iterator for. It must be a NamedTuple with at least one field named y. To train a conditional flow, its second field must be named x.
batch_size (int) – the number of samples in each batch returned by the iterator
shuffle (bool) – if True, shuffle the data before creating batches
Examples
>>> from collections import namedtuple
>>> from jax import numpy as jnp, random as jr
>>> from surjectors.util import as_batch_iterator
>>>
>>> y = jr.normal(jr.PRNGKey(0), (1000, 2))
>>> loader = as_batch_iterator(
...     jr.PRNGKey(1), namedtuple("data", "y")(y), 100
... )
>>> x = jr.normal(jr.PRNGKey(1), (1000, 2))
>>> loader = as_batch_iterator(
...     jr.PRNGKey(1), namedtuple("data", "y x")(y, x), 100
... )
- Returns:
a data loader object
- Return type:
_DataLoader
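To illustrate what batching a NamedTuple data set involves, here is a minimal sketch under stated assumptions. It is not the surjectors implementation, and the names Data and iterate_batches are hypothetical. A seeded NumPy generator stands in for the JAX PRNG key, and the sketch assumes that a trailing incomplete batch is dropped. Every field (y and, if present, x) is indexed with the same shuffled rows so that conditioning variables stay aligned with their targets.

```python
from typing import Iterator, NamedTuple, Optional

import numpy as np


class Data(NamedTuple):
    # Hypothetical container following the documented convention:
    # `y` holds the targets, `x` optionally holds conditioning variables.
    y: np.ndarray
    x: Optional[np.ndarray] = None


def iterate_batches(
    seed: int, data: Data, batch_size: int, shuffle: bool = True
) -> Iterator[Data]:
    """Yield mini-batches of `data` (hypothetical helper, not surjectors' API)."""
    n = data.y.shape[0]
    idxs = np.arange(n)
    if shuffle:
        # A seeded NumPy generator stands in for the JAX random key.
        idxs = np.random.default_rng(seed).permutation(n)
    # Assumption: the trailing incomplete batch is dropped.
    num_batches = n // batch_size
    for b in range(num_batches):
        batch_idxs = idxs[b * batch_size : (b + 1) * batch_size]
        # Index every non-None field with the same rows so y and x stay aligned.
        yield Data(*(None if f is None else f[batch_idxs] for f in data))


y = np.random.default_rng(0).normal(size=(1000, 2))
x = np.random.default_rng(1).normal(size=(1000, 2))
batches = list(iterate_batches(1, Data(y=y, x=x), 100))
print(len(batches), batches[0].y.shape, batches[0].x.shape)  # 10 (100, 2) (100, 2)
```

Shuffling an index vector once and slicing it, rather than shuffling each field separately, is what keeps the rows of y and x paired.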
- surjectors.util.make_alternating_binary_mask(n_dim, even_idx_as_true=False)
Create a binary masked array.
- Parameters:
n_dim (int) – length of the mask array to create
even_idx_as_true (bool) – a boolean selecting which indices are set to True. If even_idx_as_true=True, all even indices [0, 2, 4, …] are set to True; otherwise all odd indices [1, 3, 5, …] are.
- Returns:
a boolean mask array in which every even (or every odd) index is True
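The behaviour described above can be sketched in a few lines of NumPy. This is a hypothetical re-implementation written from the parameter description, not the library's source, and the name alternating_mask is illustrative. The usage lines assume the common coupling-layer pattern, where the mask and its complement split an input into the dimensions that are passed through and the dimensions that are transformed.

```python
import numpy as np


def alternating_mask(n_dim: int, even_idx_as_true: bool = False) -> np.ndarray:
    """Hypothetical re-implementation of the documented behaviour."""
    # True at positions 0, 2, 4, ...; False at positions 1, 3, 5, ...
    is_even = np.arange(n_dim) % 2 == 0
    # Even indices are True if even_idx_as_true, otherwise odd indices are.
    return is_even if even_idx_as_true else ~is_even


print(alternating_mask(5, even_idx_as_true=True))  # [ True False  True False  True]
print(alternating_mask(5))                         # [False  True False  True False]

# Split an input with the mask and its complement (coupling-layer style).
z = np.arange(6.0)
mask = alternating_mask(6, even_idx_as_true=True)
z_fixed, z_transformed = z[mask], z[~mask]  # [0. 2. 4.] and [1. 3. 5.]
```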