  • Stars: 2,029
  • Rank: 22,818 (Top 0.5%)
  • Language: Python
  • License: Apache License 2.0
  • Created: over 3 years ago
  • Updated: 2 months ago

Repository Details

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Equinox

Equinox is a JAX library for parameterised functions (e.g. neural networks) offering:

  • a PyTorch-like API...
  • ...that's fully compatible with native JAX transformations...
  • ...with no new concepts you have to learn.

If you're completely new to JAX, then start with this CNN on MNIST example.
If you're already familiar with JAX, then the main idea is to represent parameterised functions (such as neural networks) as PyTrees, so that they can pass across JIT/grad/etc. boundaries smoothly.

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.

Installation

pip install equinox

Requires Python 3.9+ and JAX 0.4.11+.

Documentation

Available at https://docs.kidger.site/equinox.

Quick example

Models are defined using PyTorch-like syntax:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

and fully compatible with normal JAX operations:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.

Citation

If you found this library to be useful in academic work, then please cite: (arXiv link)

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Lineax: linear solvers and linear least squares.

jaxtyping: type annotations for shape/dtype of arrays.

Eqxvision: computer vision models.

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

More Repositories

  1. torchtyping (Python, 1,380 stars): Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
  2. diffrax (Python, 1,369 stars): Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
  3. jaxtyping (Python, 1,117 stars): Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/
  4. NeuralCDE (Python, 610 stars): Code for "Neural Controlled Differential Equations for Irregular Time Series" (Neurips 2020 Spotlight)
  5. torchcde (Python, 411 stars): Differentiable controlled differential equation solvers for PyTorch with GPU support and memory-efficient adjoint backpropagation.
  6. lineax (Python, 344 stars): Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
  7. mkposters (Python, 324 stars): Make posters from Markdown files.
  8. sympy2jax (Python, 313 stars): Turn SymPy expressions into trainable JAX expressions.
  9. optimistix (Python, 299 stars): Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
  10. signatory (C++, 258 stars): Differentiable computations of the signature and logsignature transforms, on both CPU and GPU. (ICLR 2021)
  11. torchcubicspline (Python, 214 stars): Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.
  12. sympytorch (Python, 139 stars): Turning SymPy expressions into PyTorch modules.
  13. quax (Python, 100 stars): Multiple dispatch over abstract array types in JAX.
  14. Deep-Signature-Transforms (Jupyter Notebook, 87 stars): Code for "Deep Signature Transforms" (NeurIPS 2019)
  15. FasterNeuralDiffEq (Python, 86 stars): Code for "'Hey, that's not an ODE:' Faster ODE Adjoints via Seminorms" (ICML 2021)
  16. typst_pyimage (Python, 72 stars): Typst extension, adding support for generating figures using inline Python code
  17. generalised_shapelets (Jupyter Notebook, 52 stars): Code for "Generalised Interpretable Shapelets for Irregular Time Series"
  18. PatModules.jl (Julia, 18 stars): A better import/module system for Julia.
  19. exvoker (Rust, 9 stars): A CLI tool. Extract regexes from stdout (e.g. URLs) and invoke commands on them (e.g. open the webpage).
  20. action_update_python_project (8 stars): Github Action to: Check version / Test / git tag / GitHub Release / Deploy to PyPI
  21. pytkdocs_tweaks (Python, 5 stars): Some custom tweaks to the results produced by pytkdocs.
  22. Learning-Interpolation (Jupyter Notebook, 4 stars): Applying machine learning to help numerically solve the Camassa-Holm equation.
  23. matching (Python, 3 stars): Round robin matching algorithm.
  24. candle (Python, 3 stars): Simple PyTorch helpers. (I think we've probably all written one of these for ourselves!)
  25. tools (Python, 3 stars): Helpful abstract tools (functions, classes, ... ) for coding in Python.
  26. pdfscraper (Python, 3 stars): Saves a webpage and all linked pdfs.
  27. ktools (Python, 2 stars): Tools for working with Keras.
  28. loccounter (Python, 2 stars): Counts lines of Python code.
  29. Dissertation (2 stars): Master's Dissertation: Polynomial Approximation of Holomorphic Functions
  30. py2annotate (Python, 2 stars): An extension to Sphinx autodoc to augment Sphinx documentation with type annotations, when using Python 2 style type annotations.
  31. adventuregame (Python, 2 stars): The very start of a game I was toying with before I got distracted by the PhD...
  32. MPE-CDT-Project (Jupyter Notebook, 2 stars): A simple machine learning project for weather observations.
  33. tfext (Python, 2 stars): Some extra stuff for using with TensorFlow.
  34. mkdocs_include_exclude_files (Python, 1 star): Modify which files MkDocs includes or excludes.
  35. patrick-kidger (1 star)
  36. rl-test (Python, 1 star)