Chex


Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

You can install the latest released version of Chex from PyPI via:

pip install chex

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/chex.git

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that make it easy to specify typed data structures with minimal boilerplate code. However, they are not compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation built on top of Python dataclasses.

The Chex dataclass implementation registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which allows them to be processed (e.g. flattened and unflattened) by dm-tree methods like regular Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support constructor arguments provided in the same format as the Python dict constructor. If necessary, dataclasses can be converted to and from tuples with the to_tuple and from_tuple methods.

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same shape, and that tensors t4 and t5 have rank 2 and rank 3 or 4, respectively.

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

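from chex import (assert_shape, assert_rank, assert_type, assert_equal_shape,
                  assert_trees_all_close, assert_tree_all_finite,
                  assert_devices_available, assert_tpu_available,
                  assert_numerical_grads)

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2, 3)])     # x is scalar and y has shape (2, 3)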

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

See asserts.py documentation to find all supported assertions.

If you cannot find a specific assertion, please consider making a pull request or opening an issue on the bug tracker.

Optional Arguments

All Chex assertions support the following optional kwargs for customizing the emitted exceptions:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message in the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

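import chex

# load_dataset, init_params, update_params and num_steps stand in for
# your own training code.
dataset = load_dataset()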
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

will raise a ValueError that includes the step number when params get polluted with NaN or infinite values.

Static and Value (aka Runtime) Assertions

Chex divides all assertions into two classes: static and value assertions.

  1. Static assertions use anything except the concrete values of tensors. Examples: assert_shape, assert_trees_all_equal_dtypes, assert_max_traces.

  2. Value assertions require access to tensor values, which are not available during JAX tracing (see How JAX primitives work), so such assertions need special treatment in jitted code.

To enable value assertions in a jitted function, decorate it with the chex.chexify() wrapper. Example:

  import chex
  import jax
  import jax.numpy as jnp

  @chex.chexify
  @jax.jit
  def logp1_abs_safe(x: chex.Array) -> chex.Array:
    chex.assert_tree_all_finite(x)
    return jnp.log(jnp.abs(x) + 1)

  logp1_abs_safe(jnp.ones(2))  # OK
  logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

  # The error will be raised either at the next line OR at the next
  # `logp1_abs_safe` call. See the docs for more detail on async mode.
  logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.

See the chex.chexify() docstring for more details.

JAX Tracing Assertions

JAX re-traces a jitted function every time the shapes, dtypes, or structure of its arguments change. Often this behavior is inadvertent and leads to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests relying on @chex.assert_max_traces.

Examples:

  import chex
  import jax
  import jax.numpy as jnp

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
  fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!

Can be used with jax.pmap() as well:

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

See the How JAX primitives work section for more information about tracing.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, which can make it hard to ensure that code is properly tested. For instance, just testing a Python function that uses JAX will not cover the code path that is actually executed when the function is jitted, and that path will also differ depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure, hard-to-catch bugs where XLA changes led to undesirable behaviours that only manifest under one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

import chex

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

import chex
from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. identity transformation.
  • with_device -- places all arguments (except those specified in the ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; see the examples in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) performs a parallel map of fn onto multiple devices. Since most tests run in a single-device environment (i.e. with access to a single CPU or GPU), where jax.pmap is functionally equivalent to jax.jit, the with_pmap variant is skipped by default (although it works fine with a single device). The Faking multi-device test environments section below describes how to properly test fn if it is meant to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping with_pmap variants in the single-device case, add --chex_skip_pmap_variant_if_single_device=false to your test command.

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

The same functionality can also be invoked with start and stop:

fake_pmap = chex.fake_pmap()
fake_pmap.start()
# ... your JAX code ...
fake_pmap.stop()

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See the section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from XLA's perspective because they expose the same interface and use identical abstractions.

Chex has a flag chex_n_cpu_devices that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

import chex

def setUpModule():
  chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in a multi-device regime. Note that all tests within a module will have access to N devices.

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Using named dimension sizes

Chex comes with a small utility that allows you to package a collection of dimension sizes into a single object. The basic idea is:

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])

String lookups are translated into integer tuples. For instance, if batch_size == 3, sequence_len == 5 and embedding_dim == 7, then

dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...

You can also assign dimension sizes dynamically as follows:

dims['XY'] = some_matrix.shape
dims.Z = 13

For more examples, see chex.Dimensions documentation.

Citing Chex

This repository is part of the DeepMind JAX Ecosystem; to cite Chex, please use the DeepMind JAX Ecosystem citation.
