  • Stars: 218
  • Rank: 181,805 (Top 4 %)
  • Language: Python
  • License: MIT License
  • Created almost 6 years ago
  • Updated about 4 years ago

Repository Details

Tensorflow implementation of Ordinary Differential Equation Solvers with full GPU support

Tensorflow Ordinary Differential Equation Solvers

A library that replicates the TorchDiffEq library, which was built for the Neural Ordinary Differential Equations paper by Chen et al., running entirely on Tensorflow Eager Execution.

All credits for the codebase go to @rtqichen for providing an excellent base to reimplement from.

Similar to the PyTorch codebase, this library provides ordinary differential equation (ODE) solvers implemented in Tensorflow Eager. For usage of ODE solvers in deep learning applications, see the Neural Ordinary Differential Equations paper.

Supports Augmented Neural ODE architectures from the paper Augmented Neural ODEs as well, which have been shown to solve certain problems that Neural ODEs may struggle with.

Supports Universal Differential Equations (ODE case only) from the paper Universal Differential Equations for Scientific Machine Learning. While slow and restricted to ODEs, it works well enough on the Lotka-Volterra system, as described in the example notebook.

Supports Hypersolvers from the paper Hypersolvers: Toward Fast Continuous-Depth Models. Currently implements the paper's versions of HyperEuler and HyperHeun. The paper Deep Euler method: solving ODEs by approximating the local truncation error of the Euler method proposes nearly the same approach. NOTE: These APIs are subject to change once the paper releases source code.

Now supports Adjoint methods for Dopri5 solver due to PR #3 from @eozd.

As the solvers are implemented in Tensorflow, algorithms in this repository fully support running on the GPU, and are differentiable. Also supports prebuilt ODENet and ConvODENet tf.keras Models that can be used as is or embedded in a larger architecture.

Caveats

There are a few major limitations with this project :

  • Speed is almost the same as the PyTorch codebase (±2%) if the solver is wrapped inside a tf.device block. Runge-Kutta solvers require double (float64) precision for correct gradient computations, yet Tensorflow does not provide a convenient global switch to force all created tensors to double dtype, so explicit casts are unavoidable.
    • Make sure to wrap the entire script in a with tf.device('/gpu:0') block to make full use of the GPU. Alternatively, place the main components (the model, the optimizer, the dataset and the odeint call) inside tf.device blocks locally.
    • The convenience method move_to_device is available from the library to make things easier on this front.
    • If type errors are thrown, correct them with tf.cast().

Notebooks to get started

  1. The examples folder contains a Jupyter Notebook, ode_usage.ipynb, with examples of several ODE solutions. It explains various methods and demonstrates the visualization functions available in this library. The Notebook can also be viewed on Google Colab : Colaboratory Link
  2. An example of Augmented Neural ODEs and Prebuilt ODENet models is available on Google Colab : Colaboratory Link
  3. An example of Universal Differential Equations for the Lotka-Volterra system is available on Google Colab : Colaboratory Link

Basic Usage

Note: This is taken directly from the original PyTorch codebase. Almost all concepts apply here as well.

This library provides one main interface odeint which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,

dy/dt = f(t, y)    y(t_0) = y_0.

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.

To solve an IVP using the default solver:

from tfdiffeq import odeint

odeint(func, y0, t)

where func is any callable implementing the ordinary differential equation f(t, y), y0 is an any-D Tensor or a tuple of any-D Tensors representing the initial values, and t is a 1-D Tensor containing the evaluation points. The initial time is taken to be t[0].
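
To make the calling convention concrete, here is a minimal pure-Python sketch of a solver with the same (func, y0, t) shape. It is not the library's implementation: euler_odeint and decay are hypothetical names, it uses plain lists instead of Tensors, and it takes one fixed Euler step between consecutive evaluation points.

```python
import math


def euler_odeint(func, y0, t):
    """Fixed-step Euler sketch: returns one state per evaluation point in t."""
    ys = [list(y0)]  # the state at t[0] is the initial value
    for t0, t1 in zip(t[:-1], t[1:]):
        h = t1 - t0
        y = ys[-1]
        dy = func(t0, y)
        ys.append([yi + h * di for yi, di in zip(y, dy)])
    return ys


def decay(t, y):
    # dy/dt = -y, whose exact solution is y(t) = y0 * exp(-t)
    return [-yi for yi in y]


times = [i / 1000 for i in range(1001)]  # evaluation points on [0, 1]
trajectory = euler_odeint(decay, [1.0], times)

# Distance from the exact value exp(-1) at the final time
print(abs(trajectory[-1][0] - math.exp(-1.0)))
```

The returned list has one entry per element of t, mirroring how the evaluation points determine where the solution is reported.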

Backpropagation through odeint goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method explained in the Neural Ordinary Differential Equations paper, which allows solving with as many steps as necessary due to O(1) memory usage.
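
For orientation, the adjoint method from the Neural Ordinary Differential Equations paper can be summarized as follows. The notation is adapted to this README: L is a scalar loss evaluated at the final time t_1, and θ (the parameters of f) is a symbol introduced here for illustration.

```latex
\begin{aligned}
a(t) &= \frac{\partial L}{\partial y(t)}, \\
\frac{d a(t)}{d t} &= -a(t)^{\top} \frac{\partial f(t, y(t), \theta)}{\partial y}, \\
\frac{d L}{d \theta} &= -\int_{t_1}^{t_0} a(t)^{\top} \frac{\partial f(t, y(t), \theta)}{\partial \theta} \, dt .
\end{aligned}
```

The adjoint state a(t) is integrated backwards from t_1 to t_0 together with y(t), so intermediate solver states do not need to be stored.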

Example of an ODE Model

import tensorflow as tf

class LotkaVolterra(tf.keras.Model):

  def __init__(self, a, b, c, d):
    super().__init__()
    self.a, self.b, self.c, self.d = a, b, c, d

  @tf.function
  def call(self, t, y):
    # y = [R, F]
    r, f = tf.unstack(y)

    dR_dT = self.a * r - self.b * r * f
    dF_dT = -self.c * f + self.d * r * f

    return tf.stack([dR_dT, dF_dT])
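
As a quick sanity check of the equations, the same right-hand side can be written without TensorFlow. This is only a sketch: lotka_volterra is a hypothetical free function, and the parameter values are illustrative rather than taken from the library's examples. At the non-trivial equilibrium (R, F) = (c/d, a/b), both derivatives vanish.

```python
def lotka_volterra(t, y, a, b, c, d):
    """Plain-Python Lotka-Volterra right-hand side, y = [R, F]."""
    r, f = y
    dr_dt = a * r - b * r * f   # prey grows, and is eaten by predators
    df_dt = -c * f + d * r * f  # predators die off, and grow by eating prey
    return [dr_dt, df_dt]


# Illustrative parameter values (an assumption, not from the library)
a, b, c, d = 1.5, 1.0, 3.0, 1.0

equilibrium = [c / d, a / b]
print(lotka_volterra(0.0, equilibrium, a, b, c, d))  # [0.0, 0.0]
```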

Examples of using Hypersolvers

Please refer to the Colab Notebook - hyper_solvers, to see a demonstration of a HyperHeun network trained to solve the chaotic Lorenz Attractor ODE.

The notebook above includes code blocks that can be used as a reference to train and evaluate such Hypersolvers.

import tensorflow as tf
from tfdiffeq.hyper import HyperHeun

# Create a Hyper network that HyperHeun will use as the Solver network `g`

class HyperSolverModule(tf.keras.Model):
  def __init__(self, func_input_dim, hidden_dim=64):
    super().__init__(dtype='float64')
    self.func_input_dim = func_input_dim
    
    # Input dim isn't used (Keras handles it automatically)
    # But for illustration purposes, it is provided
    # Computed as ~ dim(y) + dim(dy) + 1 (for time axis)
    self.input_dim = 2 * func_input_dim + 1
    self.hidden_dim = hidden_dim
    self.output_dim = func_input_dim

    self.g = tf.keras.Sequential([
        tf.keras.layers.Dense(self.hidden_dim),
        tf.keras.layers.PReLU(),
        tf.keras.layers.Dense(self.hidden_dim),
        tf.keras.layers.PReLU(),
        tf.keras.layers.Dense(self.hidden_dim),
        tf.keras.layers.PReLU(),
        tf.keras.layers.Dense(self.output_dim)
    ])
    
  @tf.function
  def call(self, x):
    return self.g(x)

# Assume we use Lorenz Attractor as the ODE we want to model - `f`

f = Lorenz(sigma, beta, rho)
g = HyperSolverModule(func_input_dim=3, hidden_dim=64)

hyper_heun = HyperHeun(f, g)

# The steps to train this model are available in the notebook mentioned above
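
The idea behind a hypersolver can be sketched in a few lines of plain Python. Following the Hypersolvers paper, a HyperEuler step adds a correction term h² · g(t, y, dy) to the ordinary Euler update. In the sketch below, g is not a trained network: it is a hand-written "oracle" for the toy problem dy/dt = y, chosen so that the correction exactly cancels Euler's local error. All names (hyper_euler_step, make_oracle_g, growth) are hypothetical and unrelated to the tfdiffeq API.

```python
import math


def hyper_euler_step(f, g, t, y, h):
    """One HyperEuler step: Euler update plus an h^2-scaled correction from g."""
    dy = f(t, y)
    return y + h * dy + h ** 2 * g(t, y, dy)


def growth(t, y):
    # Toy ODE dy/dt = y, exact solution y(t) = exp(t) for y(0) = 1
    return y


def make_oracle_g(h):
    # For dy/dt = y the exact step is y * exp(h); Euler gives y * (1 + h).
    # The missing part divided by h^2 is what a trained g would approximate.
    coeff = (math.exp(h) - 1.0 - h) / h ** 2

    def g(t, y, dy):
        return coeff * y

    return g


h = 0.1
g = make_oracle_g(h)
t, y_plain, y_hyper = 0.0, 1.0, 1.0
for _ in range(10):
    y_plain = y_plain + h * growth(t, y_plain)           # plain Euler
    y_hyper = hyper_euler_step(growth, g, t, y_hyper, h)  # corrected step
    t += h

# Errors at t = 1 against the exact value e
print(abs(y_plain - math.e), abs(y_hyper - math.e))
```

In practice g is a neural network, such as the HyperSolverModule above, trained to approximate this residual, so the correction is only approximate.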

Prebuilt Models

This library now supports prebuilt models inside the tfdiffeq.models namespace - specifically the Neural ODENet and Convolutional Neural ODENet. In addition, both of these models inherently support Augmented Neural ODENets.

They can be used as models themselves, or can be inserted inside a larger stack of ODENet layers to build a deeper ODENet or ConvODENet model, depending on the use case.

Usage :

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten
from tfdiffeq.models import ODENet, ConvODENet

# Directly usable model
model = ODENet(hidden_dim, output_dim, augment_dim=0, time_dependent=False)
model = ConvODENet(num_filters, augment_dim=0, time_dependent=False)

# Used inside other models
x = Conv2D(...)(x)
x = Conv2D(...)(x)
x = Flatten()(x)
x = ODENet(...)(x)  # or don't use Flatten and use ConvODENet directly
x = ODENet(...)(x)  # or don't use Flatten and use ConvODENet directly
...

Keyword Arguments

  • rtol: Relative tolerance.
  • atol: Absolute tolerance.
  • method: One of the solvers listed below.
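
The sketch below shows how rtol and atol are typically combined in adaptive Runge-Kutta codes: each component of the local error estimate is scaled by atol + rtol · max(|y0|, |y1|), and the step is accepted when the resulting norm is at most 1. Whether this library uses exactly this norm is an assumption, and error_ratio is a hypothetical helper, not part of its API.

```python
def error_ratio(error_estimate, y0, y1, rtol, atol):
    """RMS norm of the error estimate scaled by the mixed tolerance."""
    total = 0.0
    for err, before, after in zip(error_estimate, y0, y1):
        tol = atol + rtol * max(abs(before), abs(after))
        total += (err / tol) ** 2
    return (total / len(error_estimate)) ** 0.5


# Hypothetical numbers: state before and after a trial step, plus an error estimate
ratio = error_ratio([1e-7, 2e-7], [1.0, 10.0], [1.1, 9.0], rtol=1e-6, atol=1e-9)
step_accepted = ratio <= 1.0
print(step_accepted)  # True
```

With this convention, rtol dominates when the state is large and atol takes over as the state approaches zero.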

List of ODE Solvers:

Adaptive-step:

  • dopri5: Runge-Kutta 4(5) [default].
  • dopri8: Runge-Kutta 8(5).
  • adams: Adaptive-order implicit Adams.

Fixed-step:

  • euler: Euler method.
  • midpoint: Midpoint method.
  • huen: Second-order Runge-Kutta.
  • adaptive_heun: Second-order Adaptive Heun method.
  • bosh3: Bogacki-Shampine solver (MATLAB ode23).
  • rk4: Fourth-order Runge-Kutta with 3/8 rule.
  • explicit_adams: Explicit Adams.
  • fixed_adams: Implicit Adams.
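
To illustrate what a fixed-step method does, here is a scalar, pure-Python sketch of one step of the fourth-order Runge-Kutta 3/8 rule next to an Euler step. This uses the textbook formula and is not the library's code; rk4_38_step, euler_step, growth and integrate are hypothetical names.

```python
import math


def rk4_38_step(f, t, y, h):
    """One step of the fourth-order Runge-Kutta 3/8 rule (scalar state)."""
    k1 = f(t, y)
    k2 = f(t + h / 3, y + h * k1 / 3)
    k3 = f(t + 2 * h / 3, y + h * (k2 - k1 / 3))
    k4 = f(t + h, y + h * (k1 - k2 + k3))
    return y + h * (k1 + 3 * k2 + 3 * k3 + k4) / 8


def euler_step(f, t, y, h):
    """One explicit Euler step (scalar state)."""
    return y + h * f(t, y)


def growth(t, y):
    # dy/dt = y with y(0) = 1, so y(1) = e
    return y


def integrate(step, n_steps):
    """Integrate growth from t = 0 to t = 1 with n_steps equal steps."""
    h = 1.0 / n_steps
    t, y = 0.0, 1.0
    for _ in range(n_steps):
        y = step(growth, t, y, h)
        t += h
    return y


# Error at t = 1 with the same ten steps for both methods
print(abs(integrate(euler_step, 10) - math.e))
print(abs(integrate(rk4_38_step, 10) - math.e))
```

With the same number of steps, the fourth-order method lands much closer to the exact value, which is the usual trade-off against its four function evaluations per step.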

Hyper-solvers (experimental)

  • HyperEuler: Hyper Euler model.
  • HyperMidpoint: Hyper Midpoint model.
  • HyperHeun: Hyper Heun model.

Compatibility

Since tensorflow doesn't yet support global setting of default datatype, the tfdiffeq library provides a few convenience methods.

  • move_to_device : Attempts to move a tf.Tensor to a certain device. Can specify the device in the normal syntax of cpu:0 or gpu:x where x must be replaced by any number representing the GPU ID. Falls back to CPU if GPU is unavailable.

  • Don't forget to add @tf.function to the call(self, t, u) methods defined in your Keras Models; this can give a significant speed-up in some cases.

Examples

The scripts for the examples can be found in the examples folder, along with the weights and results for the latent_ode.py script as it takes some time to train. Two results have been replicated from the original codebase:

  • ode_demo.py : A basic example which contains a short implementation of learning a dynamics model to mimic a spiral ODE.

The training should look similar to this:

[Image: ode spiral demo]

  • circular_ode_demo.py : A basic example similar to above which contains a short implementation of learning a dynamics model to mimic a circular ODE.

The training should look similar to this:

[Image: ode circular demo]

  • lorenz_attractor.py : A classic example of a chaotic solution for certain parameter sets and initial conditions.

Note that this is just a stress test for the library; scipy.integrate.odeint can solve this much faster thanks to its highly optimized routines. This should take roughly 1 minute on a modern machine.

[Image: lorenz attractor]

  • latent_ode.py : Another basic example which uses variational inference to learn a path along a spiral.

Results should be similar to below after 1200 iterations:

[Image: ode spiral latent]

  • ODENet on MNIST

While the Adjoint method is not yet implemented, a smaller version of ODENet can be easily trained using the fixed grid solvers (Euler or Heun) for a fast approximate solution. It has been observed that, as MNIST is an extremely easy problem, RK45 (DOPRI5) works relatively well, whereas on more complex datasets like CIFAR 10/100 it diverges in the first epoch.

Reference : ANODE: Unconditionally Accurate Memory-Efficient Gradients for Neural ODEs

Universal ODE

  • Universal Differential Equations

Following the methodology in the paper Universal Differential Equations for Scientific Machine Learning, we reproduce (sub-optimally) the Lotka-Volterra experiment in the following notebook - UniversalNeuralODE.ipynb

References : Universal Differential Equations for Scientific Machine Learning

  • Continuous Normalizing Flows

Ported the Continuous Normalizing Flow example from the torchdiffeq repository - CNF Examples.

References : FFJORD: Free-Form Continuous Dynamics for Scalable Reversible Generative Models

Hypersolvers

Example of constructing, training and evaluating Hypersolver networks as described in the paper Hypersolvers: Toward Fast Continuous-Depth Models. NOTE: The current API is experimental and subject to change when the paper releases its code.

References : Hypersolvers: Toward Fast Continuous-Depth Models

Reference

If you found this library useful in your research, please consider citing

@article{chen2018neural,
  title={Neural Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
  journal={Advances in Neural Information Processing Systems},
  year={2018}
}

Requirements

Install the required Tensorflow version along with the package using either

pip install .[tf]  # for cpu
pip install .[tf-gpu]  # for gpu
pip install .[tests]  # for cpu testing

  • Tensorflow TF 2 / 1.15.0 or above. Preferably TF 2.0 when it comes out, as the entire codebase requires Eager Execution.
  • Tensorflow Probability (for CNF example only)
  • matplotlib
  • numpy
  • scipy (for tests)
  • six
  • pysindy (for Universal Differential Equations support only)
