
Deep Equilibrium Models

(Version 2.0 released now! 😀)

News

💥2021/6: Repo updated with the multiscale DEQ (MDEQ) code, Jacobian-related analysis & regularization support, and the new, faster and simpler implicit differentiation implementation through PyTorch's backward hook! (See here.)

  • For those who would like to start with a toy version of the DEQ, the NeurIPS 2020 tutorial on "Deep Implicit Layers" has a detailed step-by-step introduction: tutorial video & colab notebooks here.

  • A JAX version of the DEQ, including a JAX implementation of Broyden's method, etc., is available here.


This repository contains the code for the deep equilibrium (DEQ) model, an implicit-depth architecture that directly solves for, and backpropagates through, the (fixed-point) equilibrium state of an (effectively) infinitely deep network. Importantly, compared to prior implicit-depth approaches (e.g., ODE-based methods), in this work we also demonstrate the power and compatibility of this implicit model with modern, structured layers like Transformers, which enables DEQ networks to achieve results on par with SOTA deep networks (in NLP and vision) without "deep" stacking (and thus with O(1) memory). Moreover, we provide tools for regularizing the stability of these implicit models.

Specifically, this repo contains the code from the following papers (see bibtex at the end of this README):

  • Deep Equilibrium Models (NeurIPS 2019)
  • Multiscale Deep Equilibrium Models (NeurIPS 2020)
  • Stabilizing Equilibrium Models by Jacobian Regularization (ICML 2021)

Prerequisite

Python >= 3.6 and PyTorch >= 1.10. We strongly recommend 4 GPUs for computational efficiency.

Data

We provide more detailed instructions for downloading/processing the datasets (WikiText-103, ImageNet, Cityscapes, etc.) in the DEQ-Sequence/ and MDEQ-Vision/ subfolders.

How to build/train a DEQ model?

Starting in 2021/6, we partitioned the repo into two sections containing the sequence-model DEQ (i.e., DEQ-Sequence/) and the vision-model DEQ (i.e., MDEQ-Vision/) networks, respectively. As these two tasks require different input processing and loss objectives, they do not directly share a training framework.

However, both frameworks share the same utility code, such as:

  • lib/solvers.py: Advanced fixed-point solvers (e.g., Anderson acceleration and Broyden's method)
  • lib/jacobian.py: Jacobian-related estimations (e.g., Hutchinson estimator and the Power method)
  • lib/optimization.py: Regularizations (e.g., weight normalization and variational dropout)
  • lib/layer_utils.py: Layer utilities

Moreover, the repo has been significantly simplified from the previous version so that users can easily extend it. In particular:

Theorem 2 (Universality of "single-layer" DEQs, very informal): Stacking multiple DEQs (with potentially different classes of transformations) does not create extra representational power over a single DEQ.

(See the paper for a formal statement.) By the theorem above, designing a better DEQ model boils down to designing a better stable transformation f_\theta. Creating and playing with a DEQ is easy, and we recommend the following 3 steps (which we adopt in this repo):

Step 1: Define a layer f = f_\theta that we'd like to iterate until equilibrium.

Typically, this is just like any deep network layer, and should be a subclass of torch.nn.Module. Evaluating this layer requires the hidden unit z and the input injection x; e.g.:

class Layer(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, z, x, **kwargs):
        return new_z
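
For concreteness, here is a minimal, hypothetical instance of such a layer (a single weight-tied fully-connected transformation with tanh; purely illustrative, not one of the architectures used in this repo):

import torch
import torch.nn as nn

class SimpleLayer(nn.Module):
    """Toy f_theta: z_new = tanh(W z + U x + b). Illustrative only."""
    def __init__(self, hidden_dim, input_dim):
        super().__init__()
        self.linear_z = nn.Linear(hidden_dim, hidden_dim)
        self.linear_x = nn.Linear(input_dim, hidden_dim)

    def forward(self, z, x, **kwargs):
        return torch.tanh(self.linear_z(z) + self.linear_x(x))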

Step 2: Prepare the fixed-point solver to use for the DEQ model.

A DEQ model can use any black-box root solver. We provide PyTorch fixed-point solver implementations anderson(...) and broyden(...) in lib/solvers.py, which output a dictionary containing basic information about the optimization process. By default, we use the relative residual difference (i.e., |f(z)-z|/|z|) as the criterion for stopping the iterative process.

The forward pass can then be reduced to 2 lines:

with torch.no_grad():
    # x is the input injection; z0 is the initial estimate of the fixed point.
    z_star = self.solver(lambda z: f(z, x, *args), z0, threshold=f_thres)['result']

where we note that the forward pass does not need to store any intermediate state, so we put it in a torch.no_grad() block.
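
If you just want to experiment without the repo's solvers, a naive fixed-point iteration with the same calling convention and stopping criterion can serve as a (much less robust) stand-in. This is a hypothetical sketch, not the anderson/broyden implementation in lib/solvers.py:

def naive_solver(f, z0, threshold=50, eps=1e-4):
    # Plain iteration z <- f(z), stopping on the relative residual |f(z)-z|/|z|.
    z, abs_trace, rel_trace = z0, [], []
    for _ in range(threshold):
        z_next = f(z)
        abs_res = (z_next - z).norm().item()
        rel_res = abs_res / (1e-8 + z_next.norm().item())
        abs_trace.append(abs_res)
        rel_trace.append(rel_res)
        z = z_next
        if rel_res < eps:
            break
    return {'result': z, 'nstep': len(rel_trace), 'abs_trace': abs_trace, 'rel_trace': rel_trace}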

Step 3: Engage with the autodiff tape to use implicit differentiation

Finally, we need to ensure there is a way to compute the backward pass of a DEQ, which relies on the implicit function theorem. To do this, we can use PyTorch's register_hook function, which registers a backward hook to be executed during the backward pass. As we noted in the paper, the backward pass amounts to solving for the fixed point of a linear system involving the Jacobian at the equilibrium.
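
Concretely, if grad = ∂ℓ/∂z^* is the gradient arriving at the equilibrium z^* = f(z^*, x), the implicit function theorem gives ∂ℓ/∂(θ, x) = grad (I - J_f)^{-1} ∂f/∂(θ, x), so the backward pass only needs the vector y = grad (I - J_f)^{-1}. This y is exactly the fixed point of the linear map y -> y J_f + grad, which the code below solves with the same black-box solver used in the forward pass: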

new_z_star = self.f(z_star.requires_grad_(), x, *args)

def backward_hook(grad):
    if self.hook is not None:
        self.hook.remove()
        torch.cuda.synchronize()   # To avoid infinite recursion
    # Compute the fixed point of yJ + grad, where J=J_f is the Jacobian of f at z_star
    new_grad = self.solver(lambda y: autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad, \
                           torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad

self.hook = new_z_star.register_hook(backward_hook)

(Optional) Additional Step: Jacobian Regularization.

The fixed-point formulation of DEQ models means that their stability is directly characterized by the Jacobian matrix J_f at the equilibrium point. Therefore, we provide code for analyzing and regularizing the Jacobian properties (based on the ICML'21 paper Stabilizing Equilibrium Models by Jacobian Regularization). Specifically, we added the following flags to the training script (a rough sketch of how they enter a training step follows the list):

  • jac_loss_weight: The strength of Jacobian regularization, where we regularize ||J_f||_F.
  • jac_loss_freq: The frequency p of the stochastic Jacobian regularization (i.e., we only apply this loss with probability p during training).
  • jac_incremental: If >0, then we increase the jac_loss_weight by 0.1 after every jac_incremental training steps.
  • spectral_radius_mode: If True, estimate the DEQ models' spectral radius when evaluating on the validation set.
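
Roughly, the first three flags enter a training step as in the following hypothetical sketch (names and structure are illustrative; the actual training scripts in DEQ-Sequence/ and MDEQ-Vision/ differ in the details):

import random

def training_step(model, x, target, task_loss_fn, optimizer, train_step,
                  jac_loss_weight, jac_loss_freq, jac_incremental):
    # Hypothetical sketch of how the Jacobian-regularization flags could be used.
    optimizer.zero_grad()
    z_pred, jac_loss = model(x)                       # DEQ forward; jac_loss from jac_loss_estimate
    loss = task_loss_fn(z_pred, target)
    if random.random() < jac_loss_freq:               # apply the regularization with probability p
        loss = loss + jac_loss_weight * jac_loss      # regularize ||J_f||_F
    loss.backward()                                   # backward_hook runs here (implicit differentiation)
    optimizer.step()
    if jac_incremental > 0 and train_step > 0 and train_step % jac_incremental == 0:
        jac_loss_weight += 0.1                        # ramp up the regularization strength
    return loss.item(), jac_loss_weight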

A full DEQ model implementation is therefore as simple as follows:

import torch
import torch.nn as nn
from torch import autograd

from lib.solvers import anderson, broyden
from lib.jacobian import jac_loss_estimate

class DEQModel(nn.Module):
    def __init__(self, ...):
        ...
        self.f = Layer(...)
        self.solver = broyden
        ...
    
    def forward(self, x, ..., **kwargs):
        z0 = torch.zeros(...)

        # Forward pass
        with torch.no_grad():
            z_star = self.solver(lambda z: self.f(z, x, *args), z0, threshold=f_thres)['result']   # See step 2 above
            new_z_star = z_star

        # (Prepare for) Backward pass, see step 3 above
        if self.training:
            new_z_star = self.f(z_star.requires_grad_(), x, *args)
            
            # Jacobian-related computations, see additional step above. For instance:
            jac_loss = jac_loss_estimate(new_z_star, z_star, vecs=1)

            def backward_hook(grad):
                if self.hook is not None:
                    self.hook.remove()
                    torch.cuda.synchronize()   # To avoid infinite recursion
                # Compute the fixed point of yJ + grad, where J=J_f is the Jacobian of f at z_star
                new_grad = self.solver(lambda y: autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad, \
                                       torch.zeros_like(grad), threshold=b_thres)['result']
                return new_grad

            self.hook = new_z_star.register_hook(backward_hook)
        return new_z_star, ...
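
For orientation, using such a model looks roughly as follows (a hypothetical sketch that assumes the forward pass returns the equilibrium and the Jacobian loss; the shapes and the dummy task loss are illustrative):

model = DEQModel(...)                           # constructor arguments elided, as above
model.train()

x = torch.randn(8, 128)                         # dummy batch of input injections (shape is illustrative)
z_star, jac_loss = model(x)                     # forward pass; the backward hook is registered here
loss = z_star.pow(2).mean() + 0.1 * jac_loss    # dummy task loss plus weighted Jacobian regularization
loss.backward()                                 # triggers backward_hook, i.e., implicit differentiation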

Fixed-point Solvers

We provide PyTorch implementations of two generic solvers, broyden(...) (based on Broyden's method) and anderson(...) (based on Anderson acceleration), in lib/solvers.py. Both functions take in the transformation f whose fixed point we would like to solve for, and return a dictionary of the following format:

{
 "result": ... (The closest estimate to the fixed point),
 "nstep": ... (The step that gives us this closest estimate),
 "abs_trace": ... (Absolute residuals along the trajectory),
 "rel_trace": ... (Relative residuals along the trajectory),
 ...
}
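
For example, to check how well the forward iteration converged, one can inspect the returned dictionary (f, x, and z0 as in Step 2; the exact set of extra keys may differ between the two solvers):

from lib.solvers import anderson

out = anderson(lambda z: f(z, x), z0, threshold=40)
z_star = out['result']
print(out['nstep'], "solver steps, final relative residual:", float(out['rel_trace'][-1]))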

Pretrained Models

See DEQ-Sequence/ and MDEQ-Vision/ sub-directories for the links.

Credits

  • The Transformer implementation and the extra modules (e.g., adaptive embeddings) are based on the Transformer-XL repo.

  • Some utility code in this repo (e.g., model summary and YAML processing) was modified from the HRNet repo.

  • We also added the RAdam optimizer as a training option (but did not make it the default). The RAdam implementation is from the RAdam repo.

Bibtex

If you find this repository useful for your research, please consider citing our work(s):

  1. Deep Equilibrium Models
@inproceedings{bai2019deep,
  author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title     = {Deep Equilibrium Models},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2019},
}
  2. Multiscale Deep Equilibrium Models
@inproceedings{bai2020multiscale,
  author    = {Shaojie Bai and Vladlen Koltun and J. Zico Kolter},
  title     = {Multiscale Deep Equilibrium Models},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2020},
}
  3. Stabilizing Equilibrium Models by Jacobian Regularization
@inproceedings{bai2021stabilizing,
  title     = {Stabilizing Equilibrium Models by Jacobian Regularization},
  author    = {Shaojie Bai and Vladlen Koltun and J. Zico Kolter},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2021}
}
