# Source code for surjectors.util

"""Utility functions."""

from collections import namedtuple

import numpy as np
from jax import lax
from jax import numpy as jnp
from jax import random as jr

# Public names exported by a star-import of this module.
__all__ = ["make_alternating_binary_mask", "as_batch_iterator", "unstack"]

# Record type with exactly two fields, `y` and `x`, in that order.
named_dataset = namedtuple("named_dataset", "y x")


# pylint: disable=too-few-public-methods
class _DataLoader:
  """Callable that hands out batches by batch index.

  Attributes are plain public fields:
    num_batches: the value passed in as the number of batches
    idxs: the default index collection forwarded to `get_batch`
    get_batch: callable invoked as `get_batch(idx, idxs)` to build a batch
  """

  def __init__(self, num_batches, idxs, get_batch):
    self.num_batches, self.idxs, self.get_batch = num_batches, idxs, get_batch

  def __call__(self, idx, idxs=None):
    """Return `get_batch(idx, idxs)`.

    Args:
      idx: batch index forwarded unchanged to `get_batch`
      idxs: optional replacement for the stored indices; when None the
        indices given at construction time are used

    Returns:
      whatever `get_batch` returns
    """
    chosen = self.idxs if idxs is None else idxs
    return self.get_batch(idx, chosen)


def make_alternating_binary_mask(n_dim: int, even_idx_as_true: bool = False):
  """Build a boolean mask whose entries alternate between False and True.

  Args:
    n_dim: length of the masked array to be created
    even_idx_as_true: selects which positions are True. With the default
      (False) the odd indices [1, 3, 5, ...] are True; with True the even
      indices [0, 2, 4, ...] are True instead.

  Returns:
    boolean masked array where every even or uneven index is True
  """
  n_elements = np.prod(n_dim)
  # Parity of the running index: 0 at even positions, 1 at odd positions.
  parity = jnp.arange(0, n_elements) % 2
  odd_is_true = jnp.reshape(parity, n_dim).astype(bool)
  return jnp.logical_not(odd_is_true) if even_idx_as_true else odd_is_true
def as_batch_iterator(
  rng_key: jr.PRNGKey, data: named_dataset, batch_size: int, shuffle=True
) -> _DataLoader:
  """Build an index-addressable batch loader for a data set.

  Args:
    rng_key: a JAX random key; it is only used when `shuffle` is true
    data: the data set to create batches from. It needs to be a
      NamedTuple with at least one element called `y`, whose leading axis
      determines the number of samples. If a conditional flow is to be
      trained, the second element has to be called `x`.
    batch_size: size of each batch of data returned by the loader. If the
      data set has fewer samples than this, one batch with all samples is
      used instead.
    shuffle: if true, permutes the sample indices once before batching

  Examples:
    >>> from collections import namedtuple
    >>> from jax import numpy as jnp, random as jr
    >>>
    >>> y = jr.normal(jr.PRNGKey(0), (1000, 2))
    >>> loader = as_batch_iterator(
    ...   jr.PRNGKey(1), namedtuple("data", "y")(y), 100
    ... )
    >>> x = jr.normal(jr.PRNGKey(1), (1000, 2))
    >>> loader = as_batch_iterator(
    ...   jr.PRNGKey(1), namedtuple("data", "y x")(y, x), 100
    ... )

  Returns:
    a data loader object; calling it with a batch index returns a
    dictionary that maps every field name of `data` to the rows of that
    field belonging to the batch
  """
  n = data.y.shape[0]
  if n < batch_size:
    # Fewer samples than one requested batch: serve them all at once.
    num_batches, batch_size = 1, n
  else:
    n_full, n_leftover = divmod(n, batch_size)
    # Left-over samples need one additional, shorter batch.
    num_batches = int(n_full) if n_leftover == 0 else int(n_full) + 1

  idxs = jnp.arange(n)
  if shuffle:
    idxs = jr.permutation(rng_key, idxs)

  def get_batch(idx, idxs=idxs):
    offset = idx * batch_size
    # The last batch may hold fewer than `batch_size` samples.
    length = jnp.minimum(n - offset, batch_size)
    selected = lax.dynamic_slice_in_dim(idxs, offset, length)
    batch = {}
    for name, array in zip(data._fields, data):
      # Gather the selected rows along the leading axis of each field.
      batch[name] = lax.index_take(array, (selected,), axes=(0,))
    return batch

  return _DataLoader(num_batches, idxs, get_batch)
def unstack(x, axis=0):
  """Split a tensor into a list of slices along one axis.

  Mirrors what tf.unstack does: the chosen axis is removed from every
  returned slice.

  Args:
    x: array to unstack
    axis: the axis as integer index

  Returns:
    a list with `x.shape[axis]` arrays, one per position along `axis`
  """
  n_slices = x.shape[axis]
  pieces = []
  for position in range(n_slices):
    pieces.append(lax.index_in_dim(x, position, axis, keepdims=False))
  return pieces