  • Language
    Python
  • License
    Apache License 2.0
  • Created about 4 years ago
  • Updated 3 months ago

Repository Details

JAX-based neural network library

Haiku: Sonnet for JAX

Overview | Why Haiku? | Quickstart | Installation | Examples | User manual | Documentation | Citing Haiku

What is Haiku?

Haiku is a tool
For building neural networks
Think: "Sonnet for JAX"

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.

Overview

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.

hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.

Why Haiku?

There are a number of neural network libraries for JAX. Why should you choose Haiku?

Haiku has been tested by researchers at DeepMind at scale.

  • DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.

Haiku is a library, not a framework.

  • Haiku is designed to make specific things simpler: managing model parameters and other model state.
  • Haiku can be expected to compose with other libraries and work well with the rest of JAX.
  • Haiku otherwise is designed to get out of your way - it does not define custom optimizers, checkpointing formats, or replication APIs.

Haiku does not reinvent the wheel.

  • Haiku builds on the programming model and APIs of Sonnet, a neural network library with near universal adoption at DeepMind. It preserves Sonnet's Module-based programming model for state management while retaining access to JAX's function transformations.
  • Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users have found Sonnet to be a productive programming model in TensorFlow; Haiku enables the same experience in JAX.

Transitioning to Haiku is easy.

  • By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
  • Outside of new features (e.g. hk.transform), Haiku aims to match the API of Sonnet 2. Modules, methods, argument names, defaults, and initialization schemes should match.

Haiku makes other aspects of JAX simpler.

  • Haiku offers a trivial model for working with random numbers. Within a transformed function, hk.next_rng_key() returns a unique rng key.
  • These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations.

Quickstart

Let's take a look at an example neural network, loss function, and training loop. (For more examples, see our examples directory. The MNIST example is a good place to start.)

import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(update_rule, params, grads)

The core of Haiku is hk.transform. The transform function allows you to write neural network functions that rely on parameters (here the weights of the Linear layers) without requiring you to explicitly write the boilerplate for initialising those parameters. transform does this by transforming the function into a pair of pure functions (as required by JAX): init and apply.

init

The init function, with signature params = init(rng, ...) (where ... are the arguments to the untransformed function), allows you to collect the initial value of any parameters in the network. Haiku does this by running your function, keeping track of any parameters requested through hk.get_parameter (called by e.g. hk.Linear) and returning them to you.

The params object returned is a nested data structure of all the parameters in your network, designed for you to inspect and manipulate. Concretely, it is a mapping of module name to module parameters, where a module parameter is a mapping of parameter name to parameter value. For example:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}

apply

The apply function, with signature result = apply(params, rng, ...), allows you to inject parameter values into your function. Whenever hk.get_parameter is called, the value returned will come from the params you provide as input to apply:

loss = loss_fn_t.apply(params, rng, images, labels)

Note that since the actual computation performed by our loss function doesn't rely on random numbers, passing in a random number generator is unnecessary, so we could also pass in None for the rng argument. (Note that if your computation does use random numbers, passing in None for rng will cause an error to be raised.) In our example above, we ask Haiku to do this for us automatically with:

loss_fn_t = hk.without_apply_rng(loss_fn_t)

Since apply is a pure function we can pass it to jax.grad (or any of JAX's other transforms):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)

Training

The training loop in this example is very simple. One detail to note is the use of jax.tree_util.tree_map to apply the update_rule function across all matching entries in params and grads. The result has the same structure as the previous params and can again be used with apply.

Installation

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install Haiku using pip:

$ pip install git+https://github.com/deepmind/dm-haiku

Alternatively, you can install via PyPI:

$ pip install -U dm-haiku

Our examples rely on additional libraries (e.g. bsuite). You can install the full set of additional requirements using pip:

$ pip install -r examples/requirements.txt

User manual

Writing your own modules

In Haiku, all modules are a subclass of hk.Module. You can implement any method you like (nothing is special-cased), but typically modules implement __init__ and __call__.

Let's work through implementing a linear layer:

import haiku as hk
import jax.numpy as jnp
import numpy as np

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

All modules have a name. When no name argument is passed to the module, its name is inferred from the name of the Python class (for example MyLinear becomes my_linear). Modules can have named parameters that are accessed using hk.get_parameter(param_name, ...). We use this API (rather than just using object properties) so that we can convert your code into a pure function using hk.transform.

When using modules you need to define functions and transform them into a pair of pure functions using hk.transform. See our quickstart for more details about the functions returned from transform:

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)

Working with stochastic models

Some models may require random sampling as part of the computation. For example, in variational autoencoders with the reparametrization trick, a random sample from the standard normal distribution is needed. For dropout we need a random mask to drop units from the input. The main hurdle in making this work with JAX is in management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated with modules: hk.next_rng_key() (or hk.next_rng_keys() for multiple keys):

class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)

For a more complete look at working with stochastic models, please see our VAE example.

Note: hk.next_rng_key() is not functionally pure which means you should avoid using it alongside JAX transformations which are inside hk.transform. For more information and possible workarounds, please consult the docs on Haiku transforms and available wrappers for JAX transforms inside Haiku networks.

Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in batch normalization a moving average of values encountered during training is maintained.

In Haiku we provide a simple API for maintaining mutable state that is associated with modules: hk.set_state and hk.get_state. When using these functions you need to transform your function using hk.transform_with_state since the signature of the returned pair of functions is different:

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)

If you forget to use hk.transform_with_state, don't worry: we will raise a clear error pointing you to hk.transform_with_state rather than silently dropping your state.

Distributed training with jax.pmap

The pure functions returned from hk.transform (or hk.transform_with_state) are fully compatible with jax.pmap. For more details on SPMD programming with jax.pmap, see the JAX documentation.

One common use of jax.pmap with Haiku is for data-parallel training on many accelerators, potentially across multiple hosts. With Haiku, that might look like this:

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)

For a more complete look at distributed Haiku training, take a look at our ResNet-50 on ImageNet example.

Citing Haiku

To cite this repository:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.10},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from haiku/__init__.py, and the year corresponds to the project's open-source release.
