  • Language: Python
  • License: MIT License
  • Created over 4 years ago; updated over 3 years ago

[NeurIPS'20] Multiscale Deep Equilibrium Models

Multiscale Deep Equilibrium Models

💥💥💥💥

This repo is deprecated and we will soon stop actively maintaining it, as a more up-to-date (and simpler & more efficient) implementation of MDEQ covering the same set of tasks is now available in the DEQ repo.

We STRONGLY recommend using the MDEQ-Vision code in the DEQ repo (which also supports Jacobian-related analysis).

💥💥💥💥


This repository contains the code for the multiscale deep equilibrium (MDEQ) model proposed in the paper Multiscale Deep Equilibrium Models by Shaojie Bai, Vladlen Koltun and J. Zico Kolter.

Is implicit deep learning relevant for general, large-scale pattern recognition tasks? We propose the multiscale deep equilibrium (MDEQ) model, which substantially expands upon the DEQ formulation to introduce simultaneous equilibrium modeling of multiple signal resolutions. Specifically, MDEQ solves for and backpropagates through synchronized equilibria of multiple feature representation streams. This structure rectifies one of the major drawbacks of DEQ and provides natural hierarchical interfaces for auxiliary losses and compound training procedures (e.g., pretraining and finetuning). Our experiments demonstrate for the first time that "shallow" implicit models can scale to and achieve near-SOTA results on practical computer vision tasks (e.g., megapixel images on Cityscapes segmentation).
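The idea of solving for synchronized equilibria of several resolution streams can be illustrated with a toy, pure-Python sketch. It is not the MDEQ implementation in lib/: the "feature maps" are short 1-D lists at three resolutions, the layer f is a hypothetical tanh update with hand-picked weights a and b, cross-resolution fusion is plain average pooling or repetition, and the equilibrium is found by naive fixed-point iteration instead of Broyden's method. What it does show is that all streams are updated together and settle into one joint fixed point z* = f(z*, x).

```python
import math


def resample(v, n):
    """Resize a 1-D list to length n (lengths are powers of two):
    average-pool when shrinking, repeat entries when growing."""
    m = len(v)
    if m == n:
        return list(v)
    if m > n:
        k = m // n
        return [sum(v[i * k:(i + 1) * k]) / k for i in range(n)]
    k = n // m
    return [v[i // k] for i in range(n)]


def f(z, x, a=0.3, b=0.2):
    """One joint update of all streams (toy layer, not MDEQ's). Stream i sees
    its own state, every other stream resampled to its resolution, and x[i]."""
    new_z = []
    for i, zi in enumerate(z):
        n = len(zi)
        fused = [0.0] * n
        for j, zj in enumerate(z):
            if j != i:
                fused = [s + t for s, t in zip(fused, resample(zj, n))]
        new_z.append([math.tanh(a * zi[k] + b * fused[k] + x[i][k]) for k in range(n)])
    return new_z


def max_change(z1, z2):
    """Largest absolute difference across all streams (max-norm)."""
    return max(abs(p - q) for s1, s2 in zip(z1, z2) for p, q in zip(s1, s2))


def solve_equilibrium(x, tol=1e-10, max_iter=500):
    """Iterate z <- f(z, x) on all streams at once until they stop changing."""
    z = [[0.0] * len(xi) for xi in x]
    for it in range(1, max_iter + 1):
        z_next = f(z, x)
        done = max_change(z, z_next) < tol
        z = z_next
        if done:
            return z, it
    return z, max_iter


# Three "resolutions": 8, 4 and 2 positions, each with its own input injection.
x = [[0.1 * k for k in range(8)], [0.5] * 4, [-0.2, 0.3]]
z_star, iters = solve_equilibrium(x)
print(iters, max_change(z_star, f(z_star, x)))
```

The toy weights are chosen so that each update shrinks differences by a factor of at most 0.3 + 2 * 0.2 = 0.7 (tanh is 1-Lipschitz, pooling and repetition do not enlarge the max-norm, and each stream receives two other streams), which guarantees that this toy iteration converges.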

We provide in this repo the implementation and the links to the pretrained classification & segmentation MDEQ models.

If you find this repository useful for your research, please consider citing our work:

@inproceedings{bai2020multiscale,
    author    = {Shaojie Bai and Vladlen Koltun and J. Zico Kolter},
    title     = {Multiscale Deep Equilibrium Models},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year      = {2020},
}

Overview

The structure of a multiscale deep equilibrium model (MDEQ) is shown below. All components of the model are shown in this figure (in practice, we use n=4).

Examples

Some examples of MDEQ segmentation results on the Cityscapes dataset.

Requirements

PyTorch >= 1.4.0, torchvision >= 0.4.0

Datasets

  • CIFAR-10: We download the CIFAR-10 dataset using PyTorch's torchvision package (included in this repo).
  • ImageNet: We follow the implementation from the PyTorch ImageNet Training repo.
  • Cityscapes: We download the Cityscapes dataset from its official website and process it according to this repo. The Cityscapes dataset additionally requires a list folder that aligns each original image with its corresponding segmentation label image. This list folder can be downloaded here.
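A minimal sketch of reading such a list file into (image, label) path pairs. The file format is an assumption here (one image path and one label path per line, separated by whitespace), as are the helper name read_pairs and the default root directory; check both against the list files you actually downloaded.

```python
from pathlib import Path


def read_pairs(list_text, root="data/cityscapes"):
    """Parse a Cityscapes list file into (image_path, label_path) pairs.

    ASSUMED format: one image path and one label path per line, separated
    by whitespace, both relative to `root` (also an assumption).
    """
    pairs = []
    for lineno, line in enumerate(list_text.splitlines(), start=1):
        fields = line.split()
        if not fields:
            continue  # ignore blank lines
        if len(fields) != 2:
            raise ValueError(f"line {lineno}: expected 2 paths, got {len(fields)}")
        image, label = fields
        pairs.append((Path(root) / image, Path(root) / label))
    return pairs


# Usage (hypothetical list-file path):
#   pairs = read_pairs(Path("data/list/cityscapes/train.lst").read_text())
```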

All datasets should be downloaded, processed and put in the respective data/[DATASET_NAME] directory. The data/ directory should look like the following:

data/
  cityscapes/
  imagenet/
  ...          (other datasets)
  list/        (see above)
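A quick sanity check for this layout before launching a run, as a sketch: missing_dataset_dirs is a hypothetical helper (not part of the repo), and its default directory names simply mirror the tree above.

```python
from pathlib import Path


def missing_dataset_dirs(root="data", required=("cityscapes", "imagenet", "list")):
    """Return the names in `required` that are not directories under `root`."""
    base = Path(root)
    return [name for name in required if not (base / name).is_dir()]


missing = missing_dataset_dirs()
if missing:
    print("missing under data/:", ", ".join(missing))
```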

Usage

All experiment settings are provided in the .yaml files under the experiments/ folder.

To train an MDEQ classification model on ImageNet/CIFAR-10, do

python tools/cls_train.py --cfg experiments/[DATASET_NAME]/[CONFIG_FILE_NAME].yaml

To train an MDEQ segmentation model on Cityscapes, do

python -m torch.distributed.launch --nproc_per_node=4 tools/seg_train.py --cfg experiments/[DATASET_NAME]/[CONFIG_FILE_NAME].yaml

where you should provide the pretrained ImageNet model path in the corresponding configuration (.yaml) file. We provide a sample pretrained model extractor in pretrained_models/, but you can also write your own script.
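If you script several runs, it can help to build these commands programmatically. The sketch below only reproduces the two training command lines shown above; train_command is a hypothetical helper, and the argv list it returns can be passed to subprocess.run(cmd, check=True).

```python
import shlex


def train_command(task, dataset, config_name, n_gpus=4):
    """Return a training command from the Usage section as an argv list.

    task="cls": single-process classification (ImageNet / CIFAR-10).
    task="seg": Cityscapes segmentation via torch.distributed.launch.
    """
    cfg = f"experiments/{dataset}/{config_name}.yaml"
    if task == "cls":
        return ["python", "tools/cls_train.py", "--cfg", cfg]
    if task == "seg":
        return ["python", "-m", "torch.distributed.launch",
                f"--nproc_per_node={n_gpus}", "tools/seg_train.py", "--cfg", cfg]
    raise ValueError(f"unknown task {task!r}; expected 'cls' or 'seg'")


cmd = train_command("seg", "cityscapes", "seg_MDEQ_SMALL")
print(shlex.join(cmd))
# To actually launch it: subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell quoting issues when the command is run through subprocess; shlex.join is only used to print it in copy-pasteable form.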

Similarly, to test the model and generate segmentation results on Cityscapes, do

python tools/seg_test.py --cfg experiments/[DATASET_NAME]/[CONFIG_FILE_NAME].yaml

You can (and probably should) initialize the Cityscapes training with an ImageNet-pretrained MDEQ. To do so, extract the state dict from the ImageNet checkpoint and set the MODEL.PRETRAINED entry in the Cityscapes yaml file to the path of this state dict on disk.
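A sketch of the extraction step, written against plain dicts so that it runs without PyTorch. It assumes the checkpoint is either a bare state dict or a dict storing one under a "state_dict" key, and that parameter names saved from a DataParallel-wrapped model start with "module."; the sample extractor in pretrained_models/ remains the reference. The helper name and its optional drop_prefixes argument are hypothetical.

```python
def extract_state_dict(checkpoint, drop_prefixes=()):
    """Return a plain {parameter_name: tensor} dict from a training checkpoint.

    ASSUMPTIONS: `checkpoint` is a bare state dict or stores one under the
    "state_dict" key; names saved from a DataParallel-wrapped model carry a
    leading "module." which is stripped. Names starting with any of
    `drop_prefixes` are removed (optional; choose the prefixes yourself).
    """
    state = checkpoint.get("state_dict", checkpoint)
    cleaned = {}
    for name, value in state.items():
        if name.startswith("module."):
            name = name[len("module."):]
        if any(name.startswith(p) for p in drop_prefixes):
            continue
        cleaned[name] = value
    return cleaned


# With PyTorch installed (paths are placeholders):
#   import torch
#   ckpt = torch.load("path/to/imagenet_checkpoint", map_location="cpu")
#   torch.save(extract_state_dict(ckpt), "pretrained_models/[FILENAME]")
```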

The model implementation and MDEQ's algorithmic components (e.g., L-Broyden's method) can be found in lib/.
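For readers unfamiliar with the solver, here is a textbook limited-memory ("good") Broyden root finder in pure Python, applied to g(z) = f(z) - z, whose root is the equilibrium of a layer f. It keeps the inverse-Jacobian estimate as -I plus at most `memory` rank-one terms and uses no line search. This is a generic sketch, not the code in lib/, which operates on batched tensors and may differ in its stopping rules and update details; the layer weights W and input x are made up for the example.

```python
import math


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def axpy(alpha, x, y):
    """Return alpha * x + y for plain Python lists."""
    return [alpha * p + q for p, q in zip(x, y)]


def lbroyden(g, z0, memory=10, tol=1e-10, max_iter=50):
    """Find a root of g with limited-memory good Broyden.

    The inverse Jacobian is approximated as H = -I + sum_i u_i v_i^T,
    keeping only the `memory` most recent rank-one pairs (u_i, v_i).
    """
    us, vs = [], []

    def h_mul(x):   # H @ x
        out = [-p for p in x]
        for u, v in zip(us, vs):
            out = axpy(dot(v, x), u, out)
        return out

    def ht_mul(x):  # H^T @ x
        out = [-p for p in x]
        for u, v in zip(us, vs):
            out = axpy(dot(u, x), v, out)
        return out

    z, gz = list(z0), g(z0)
    for it in range(max_iter):
        if math.sqrt(dot(gz, gz)) < tol:
            return z, it
        z_new = axpy(-1.0, h_mul(gz), z)          # quasi-Newton step z - H g
        g_new = g(z_new)
        dz = [a - b for a, b in zip(z_new, z)]
        dg = [a - b for a, b in zip(g_new, gz)]
        h_dg = h_mul(dg)
        denom = dot(dz, h_dg)
        if abs(denom) > 1e-30:                    # Sherman-Morrison rank-one update
            us.append([(a - b) / denom for a, b in zip(dz, h_dg)])
            vs.append(ht_mul(dz))
            if len(us) > memory:                  # forget the oldest pair
                us.pop(0)
                vs.pop(0)
        z, gz = z_new, g_new
    return z, max_iter


W = [[0.2, -0.1, 0.0], [0.1, 0.3, -0.2], [0.0, 0.1, 0.2]]
x = [0.5, -0.3, 0.8]


def layer(z):
    return [math.tanh(dot(row, z) + xi) for row, xi in zip(W, x)]


def residual(z):
    return [a - b for a, b in zip(layer(z), z)]


z_star, iters = lbroyden(residual, [0.0, 0.0, 0.0])
print(iters, z_star)
```

The "limited-memory" part is what keeps the cost bounded: instead of storing a dense inverse Jacobian, only a fixed number of vector pairs is kept, which matters when z is a large feature map rather than three numbers.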

Pre-trained Models

We provide some reasonably good pre-trained weights here so that one can quickly play with DEQs without training from scratch.

Description                   Task             Dataset      Model
MDEQ-XL ImageNet              Classification   ImageNet     download (.pkl)
MDEQ-XL Cityscapes (val)      Segmentation     Cityscapes   download (.pkl)
MDEQ-Small ImageNet           Classification   ImageNet     download (.pkl)
MDEQ-Small Cityscapes (val)   Segmentation     Cityscapes   download (.pkl)

I. Example of how to evaluate the pretrained ImageNet model:

  1. Download the pretrained ImageNet .pkl file. (I recommend using the gdown command!)
  2. Put the model under pretrained_models/ folder with some file name [FILENAME].
  3. Run the MDEQ classification validation command:
python tools/cls_valid.py --testModel pretrained_models/[FILENAME] --cfg experiments/imagenet/cls_mdeq_[SIZE].yaml

For example, for MDEQ-Small, you should get >75% top-1 accuracy.
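For reference, top-1 accuracy is the fraction of images whose highest-scoring class matches the ground-truth label. A minimal pure-Python sketch; top1_accuracy is a hypothetical helper, not the repo's evaluation code.

```python
def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of score rows and labels")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


scores = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.3, 0.2, 0.1]]
labels = [1, 0, 1, 0]
print(top1_accuracy(scores, labels))  # 3 of 4 correct -> 0.75
```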

II. Example of how to use the pretrained ImageNet model to train on Cityscapes:

  1. Download the pretrained ImageNet .pkl file.
  2. Put the model under pretrained_models/ folder with some file name [FILENAME].
  3. In the corresponding experiments/cityscapes/seg_MDEQ_[SIZE].yaml (where SIZE is typically SMALL, LARGE or XL), set MODEL.PRETRAINED to "pretrained_models/[FILENAME]".
  4. Run the MDEQ segmentation training command (see the "Usage" section above):
python -m torch.distributed.launch --nproc_per_node=[N_GPUS] tools/seg_train.py --cfg experiments/cityscapes/seg_MDEQ_[SIZE].yaml
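Step 3 can also be done from a script. The sketch below sets a dotted option such as MODEL.PRETRAINED in a nested dict, for example one obtained by loading the .yaml file with PyYAML's yaml.safe_load; set_by_path is a hypothetical helper, and it refuses unknown keys so that a typo does not silently add a new option.

```python
def set_by_path(cfg, dotted_key, value):
    """Set a dotted option such as "MODEL.PRETRAINED" in a nested config dict.

    Raises KeyError for unknown sections or options, so a typo in the key
    fails loudly instead of silently adding a new entry.
    """
    *sections, option = dotted_key.split(".")
    node = cfg
    for name in sections:
        child = node.get(name)
        if not isinstance(child, dict):
            raise KeyError(f"{dotted_key}: no section {name!r}")
        node = child
    if option not in node:
        raise KeyError(f"{dotted_key}: no option {option!r}")
    node[option] = value
    return cfg


# Hypothetical usage with PyYAML (a third-party package):
#   import yaml
#   path = "experiments/cityscapes/seg_MDEQ_SMALL.yaml"
#   with open(path) as fh:
#       cfg = yaml.safe_load(fh)
#   set_by_path(cfg, "MODEL.PRETRAINED", "pretrained_models/[FILENAME]")
#   with open(path, "w") as fh:
#       yaml.safe_dump(cfg, fh)
```

The strict check also rejects options that the .yaml file leaves to the code defaults; add such keys to the file by hand. Note that rewriting the file with yaml.safe_dump drops comments and sorts keys by default, so a manual edit is often just as easy.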

III. Example of how to use the pretrained Cityscapes model for inference:

  1. Download the pretrained Cityscapes .pkl file.
  2. Put the model under pretrained_models/ folder with some file name [FILENAME].
  3. In the corresponding experiments/cityscapes/seg_MDEQ_[SIZE].yaml (where SIZE is typically SMALL, LARGE or XL), set TEST.MODEL_FILE to "pretrained_models/[FILENAME]".
  4. Run the MDEQ segmentation testing command (see the "Usage" section above):
python tools/seg_test.py --cfg experiments/cityscapes/seg_MDEQ_[SIZE].yaml
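Segmentation results on Cityscapes are commonly summarized as mean intersection-over-union (mIoU). The sketch below shows the standard confusion-matrix computation on flat per-pixel label lists; it is an illustration rather than the repo's own evaluation code, and treating 255 as the ignore label is an assumption (a common Cityscapes convention) to verify against your config.

```python
def mean_iou(pred, target, num_classes, ignore_index=255):
    """Mean intersection-over-union from flat lists of per-pixel class ids."""
    # conf[t][p] counts pixels of true class t predicted as class p.
    conf = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # unlabeled pixel: excluded from the metric
        conf[t][p] += 1
    ious = []
    for c in range(num_classes):
        tp = conf[c][c]
        fp = sum(conf[r][c] for r in range(num_classes)) - tp
        fn = sum(conf[c]) - tp
        if tp + fp + fn > 0:  # skip classes absent from both pred and target
            ious.append(tp / (tp + fp + fn))
    return sum(ious) / len(ious) if ious else 0.0


pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 255]
print(mean_iou(pred, target, num_classes=3))  # (0.5 + 2/3 + 1) / 3, about 0.722
```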

Tips:

  • To load the Cityscapes pretrained model, download the .pkl file and specify its path in config.[TRAIN/TEST].MODEL_FILE (which is '' by default) in the .yaml files. This is different from setting MODEL.PRETRAINED; see the point below.
  • The difference between [TRAIN/TEST].MODEL_FILE and MODEL.PRETRAINED arguments in the yaml files: the former is used to load all of the model parameters; the latter is for compound training (e.g., when transferring from ImageNet to Cityscapes, we want to discard the final classifier FC layers).
  • The repo supports checkpointing of models at each epoch. One can resume from a previously saved checkpoint by turning on the TRAIN.RESUME argument in the yaml files.
  • Just like DEQs, the MDEQ models can be slower than explicit deep networks, and even more so as the image size increases (because larger images typically require more Broyden iterations to converge well; see Figure 5 in the paper). But one can play with the forward and backward thresholds to adjust the runtime.
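To see why such a threshold trades runtime for accuracy, here is a toy fixed-point solve capped at a maximum number of iterations. It assumes the thresholds act as caps on the number of solver iterations; it uses plain fixed-point iteration on the scalar map cos(z) instead of Broyden's method, and fixed_point is a hypothetical helper.

```python
import math


def fixed_point(f, z0, threshold, tol=1e-12):
    """Iterate z <- f(z), stopping after `threshold` steps or once the update
    is smaller than `tol`. Returns (z, steps_taken, final_residual)."""
    if threshold < 1:
        raise ValueError("threshold must be at least 1")
    z = z0
    for step in range(1, threshold + 1):
        z_next = f(z)
        converged = abs(z_next - z) < tol
        z = z_next
        if converged:
            break
    return z, step, abs(f(z) - z)


# A larger cap costs more layer evaluations but leaves a smaller residual.
for cap in (5, 20, 80):
    z, steps, res = fixed_point(math.cos, 1.0, cap)
    print(f"threshold={cap:3d}  steps={steps:3d}  residual={res:.1e}")
```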

Acknowledgement

Some utility code in this repo (e.g., model summary and yaml processing) was adapted from the HRNet repo and the DEQ repo.
