Learning Differential Equations that are Easy to Solve
Code for the paper:
Jacob Kelly*, Jesse Bettencourt*, Matthew James Johnson, David Duvenaud. "Learning Differential Equations that are Easy to Solve." Neural Information Processing Systems (2020). [arxiv] [bibtex]
*Equal Contribution
Includes JAX implementations of the following models:
- Neural ODEs for classification
- Latent ODEs for time series
- FFJORD for density estimation
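All three models share the same core: a hidden state z(t) evolved by a learned vector field f with parameters θ, which one of the numerical solvers listed below integrates. FFJORD additionally tracks the log-density of z(t) with the instantaneous change of variables formula. The standard formulations are recalled here for reference. They are not code from this repository.

```latex
\frac{d\mathbf{z}(t)}{dt} = f(\mathbf{z}(t), t, \theta), \qquad
\mathbf{z}(t_1) = \mathbf{z}(t_0) + \int_{t_0}^{t_1} f(\mathbf{z}(t), t, \theta)\, dt

\frac{\partial \log p(\mathbf{z}(t))}{\partial t}
  = -\operatorname{Tr}\!\left(\frac{\partial f}{\partial \mathbf{z}(t)}\right)
```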
Includes JAX implementations of the following numerical solvers (solver identifier first). All are adaptive-stepping except RK4:
- heun: Heun-Euler (2nd order)
- fehlberg: Fehlberg RK1(2) (2nd order)
- bosh: Bogacki-Shampine (3rd order)
- cash_karp: Cash-Karp (4th order)
- rk_fehlberg: Fehlberg (4th order)
- owrenzen: Owren-Zennaro (4th order)
- dopri: Dormand-Prince (5th order)
- owrenzen5: Owren-Zennaro (5th order)
- tanyam: Tanyam (7th order)
- adams: Adams (adaptive order)
- rk4: RK4 (4th order, fixed step size)
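An adaptive solver pairs two methods of different order, uses their difference as a local error estimate, and grows or shrinks the step to keep that estimate under a tolerance. The cost of a solve is therefore measured in function evaluations (NFE), which rises when the dynamics are harder to integrate. The following pure-Python sketch of the embedded Heun-Euler pair (orders 2 and 1) illustrates the idea for a scalar ODE. It is not the repository's JAX implementation; the function name heun_euler, the tolerances and the step-size controller constants are illustrative assumptions.

```python
import math


def heun_euler(f, y0, t0, t1, rtol=1e-6, atol=1e-8, h=0.01):
    """Integrate scalar dy/dt = f(t, y) from t0 to t1 with an embedded
    Heun-Euler pair. Returns (y(t1), number of f evaluations)."""
    t, y, nfe = t0, y0, 0
    while t < t1:
        h = min(h, t1 - t)                      # do not step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        y_euler = y + h * k1                    # 1st-order estimate
        y_heun = y + 0.5 * h * (k1 + k2)        # 2nd-order estimate
        err = abs(y_heun - y_euler)             # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_heun))
        if err <= tol:
            t, y = t + h, y_heun                # accept; keep the higher-order value
        # The estimate scales like h^2, hence the square root; clamp the change.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return y, nfe


# Usage: smooth vs. rapidly varying dynamics on [0, 1].
def smooth(t, y):
    return -y


def wiggly(t, y):
    return -y + math.sin(40 * t)


y_smooth, nfe_smooth = heun_euler(smooth, 1.0, 0.0, 1.0)
y_wiggly, nfe_wiggly = heun_euler(wiggly, 1.0, 0.0, 1.0)
print(y_smooth, nfe_smooth, nfe_wiggly)
```

The rapidly varying system forces smaller steps and so costs more evaluations, which is the quantity the regularizers in this repository aim to reduce.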
Requirements
Python
Use Python 3.8 or later (python>=3.8).
JAX
Follow the official JAX installation instructions.
Haiku
Follow the official Haiku installation instructions.
Tensorflow Datasets
Required for the MNIST dataset; follow the official TensorFlow Datasets installation instructions.
Usage
Different scripts are provided for each task and dataset.
MNIST Classification
python mnist.py --reg r3 --lam 6e-5
Latent ODEs
python latent_ode.py --reg r3 --lam 1e-2
FFJORD (Tabular)
python ffjord_tabular.py --reg r2 --lam 1e-2
FFJORD (MNIST)
python ffjord_mnist.py --reg r2 --lam 3e-4
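The --reg and --lam flags select a regularizer and its weight. The paper penalizes R_K(θ) = ∫ ||d^K z(t)/dt^K||² dt over the solve interval and adds λ·R_K to the task loss; assuming that r2 and r3 denote this penalty with K = 2 and K = 3 and that --lam sets λ, the following pure-Python sketch estimates R_2 for a small autonomous system dz/dt = f(z). For such a system d²z/dt² = (∂f/∂z) f(z), and R_2 is integrated alongside z as an extra state with fixed-step RK4. This is not the repository's implementation: the paper computes these total derivatives with Taylor-mode automatic differentiation, whereas this sketch swaps in a central finite-difference Jacobian-vector product. The names jvp_fd, r2_penalty and regularized_loss are hypothetical.

```python
import math


def jvp_fd(f, z, v, eps=1e-5):
    """Central-difference approximation of the Jacobian-vector product (df/dz) v."""
    zp = [zi + eps * vi for zi, vi in zip(z, v)]
    zm = [zi - eps * vi for zi, vi in zip(z, v)]
    return [(a - b) / (2 * eps) for a, b in zip(f(zp), f(zm))]


def r2_penalty(f, z0, t0=0.0, t1=1.0, steps=200):
    """Estimate R_2 = integral of ||d^2 z/dt^2||^2 dt for dz/dt = f(z) by
    integrating the augmented state (z, r) with fixed-step RK4."""
    def aug(state):
        z = state[:-1]
        dz = f(z)
        d2z = jvp_fd(f, z, dz)                  # d^2z/dt^2 = (df/dz) f(z)
        return dz + [sum(a * a for a in d2z)]   # dr/dt = ||d^2z/dt^2||^2

    state = list(z0) + [0.0]                    # r starts at 0
    h = (t1 - t0) / steps
    for _ in range(steps):
        k1 = aug(state)
        k2 = aug([s + 0.5 * h * k for s, k in zip(state, k1)])
        k3 = aug([s + 0.5 * h * k for s, k in zip(state, k2)])
        k4 = aug([s + h * k for s, k in zip(state, k3)])
        state = [s + h / 6 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4)]
    return state[-1]


def regularized_loss(task_loss, f, z0, lam):
    """Hypothetical total objective: task loss + lam * R_2."""
    return task_loss + lam * r2_penalty(f, z0)


# Usage on two systems with known answers over [0, 1].
def decay(z):       # dz/dt = -z; analytic R_2 = (1 - e^-2) / 2
    return [-z[0]]


def rotation(z):    # unit-speed rotation; ||d^2z/dt^2|| = 1, so R_2 = 1
    return [-z[1], z[0]]


print(r2_penalty(decay, [1.0]), r2_penalty(rotation, [1.0, 0.0]))
```

Larger λ trades task loss for dynamics whose higher derivatives are small, which an adaptive solver can integrate in fewer steps.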
Datasets
MNIST
tensorflow-datasets (see the installation instructions above) downloads the data when called from the training script.
Physionet
The file physionet_data.py, adapted from Latent ODEs for Irregularly-Sampled Time Series, downloads and processes the data when called from the training script. A preprocessed version is available in the repository's releases.
Tabular (FFJORD)
Data must be downloaded following the instructions from gpapamak/maf and placed in data/. Only MINIBOONE is needed for the experiments in the paper.
Code in datasets/, adapted from Free-form Jacobian of Reversible Dynamics (FFJORD), provides an interface to the MINIBOONE dataset once it has been downloaded. It is called from the training script.
Acknowledgements
Code in lib is modified from google/jax under its license.
Several numerical solvers were adapted from torchdiffeq and DifferentialEquations.jl.
BibTeX
@inproceedings{kelly2020easynode,
title={Learning Differential Equations that are Easy to Solve},
author={Kelly, Jacob and Bettencourt, Jesse and Johnson, Matthew James and Duvenaud, David},
booktitle={Neural Information Processing Systems},
year={2020},
url={https://arxiv.org/abs/2007.04504}
}