[CVPR 2022] Deep Equilibrium Optical Flow Estimation

Deep Equilibrium Optical Flow Estimation

(🌟Version 2.0 released!🌟)

This is the official repo for the paper Deep Equilibrium Optical Flow Estimation (CVPR 2022), by Shaojie Bai*, Zhengyang Geng*, Yash Savani and J. Zico Kolter.

A deep equilibrium (DEQ) flow estimator directly models the flow as a path-independent, “infinite-level” fixed-point solving process. We propose to use this implicit framework to replace the existing recurrent approach to optical flow estimation. The DEQ flows converge faster, require less memory, are often more accurate, and are compatible with prior model designs (e.g., RAFT and GMA).

Demo

We provide a demo video of the DEQ flow results below.

demo.mp4

Update

🌟 2022.xx.xx - Support visualization and demo on your own datasets and videos! Coming soon!

🌟 2022.08.08 - Released version 2.0 of DEQ-Flow! DEQ-Flow will be merged into DEQ after further upgrading and unit testing.

  • A clean and decoupled DEQ lib. This is a fully featured and out-of-the-box lib. You're welcome to implement your own DEQ using our DEQ lib! We support the following features. (The DEQ lib will be available on PyPI soon for easy installation via pip.)

    • Automatic arg parser decorator. You can call this function to add the DEQ args to your program. See the explanation for args here!

      add_deq_args(parser)
    • Automatic DEQ definition. Call get_deq to get your DEQ class! It is a highly decoupled implementation, agnostic to your model design!

      DEQ = get_deq(args)
      self.deq = DEQ(args)
    • Automatic normalization for DEQ. You no longer need to add normalization manually to each weight in the DEQ func!

      if args.wnorm:
          apply_weight_norm(self.update_block)
    • Easy DEQ forward. Even for a multi-equilibria system, you can call the DEQ function in just a few lines!

      # Assume args is a list [z1, z2, ..., zn] 
      # of to-be-solved equilibrium variables.
      def func(*args):
          # A functor defined in the Pytorch forward function.
          # Having the same input and output tensor shapes.
          return args
      
      deq_func = DEQWrapper(func, args)
      z_init = deq_func.list2vec(*args) # will be merged into self.deq(...)
      z_out, info = self.deq(deq_func, z_init)
    • Automatic DEQ training. Gradients (both exact and inexact) are tracked automatically! Fixed-point correction can be customized through your arg parser. Just post-process z_out as you wish!

  • Benchmarked results and checkpoints. Using the release code base v.2.0, we've trained DEQ-Flow-H on FlyingChairs and FlyingThings for two schedules, 120k+120k (1x) and 120k+360k (3x). This implementation demonstrated a new SOTA, surpassing our previous results in performance, training speed, and memory usage.

    Notably, we also benchmark RAFT using the same model size. DEQ-Flow demonstrates a clear performance and efficiency margin over RAFT, and a much stronger scaling property (scaling up to larger models)!

    Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all
    RAFT-H-1x       | 1.36           | 2.59           | 4.47       | 16.16
    DEQ-Flow-H-1x   | 1.27           | 2.58           | 3.76       | 12.95
    DEQ-Flow-H-3x   | 1.27           | 2.48           | 3.77       | 13.41
    • 1x=120k iterations on FlyingThings, 3x=360k iterations on FlyingThings, using a batch size of 6.
    • Increasing the batch size on FlyingThings can further improve these results, e.g., a batch size of 12 can reduce the F1-all of DEQ-Flow-H-1x to around 12.5 on KITTI.

    To validate our results, download the pretrained checkpoints into the checkpoints directory. Run the following command in code.v.2.0 to infer over the Sintel train set and the KITTI train set. This is a reference log.

    bash val.sh
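
To make the "Easy DEQ forward" item above more concrete, here is a minimal, dependency-free sketch of what solving a multi-equilibria system amounts to: several variables are flattened into one vector, a function with matching input and output shapes is applied repeatedly, and iteration stops once the relative residual is small. The names pack, unpack and solve_fixed_point are hypothetical stand-ins written for this illustration and are not the DEQ lib's API. The plain fixed-point iteration is also an assumption made for simplicity; the solvers shipped with the lib may work differently.

```python
import math

def pack(*variables):
    """Flatten several lists of floats into one vector and record each length."""
    sizes = [len(v) for v in variables]
    flat = [x for v in variables for x in v]
    return flat, sizes

def unpack(flat, sizes):
    """Split a flat vector back into the original list of variables."""
    out, start = [], 0
    for n in sizes:
        out.append(flat[start:start + n])
        start += n
    return out

def solve_fixed_point(func, z_init, sizes, tol=1e-8, max_iter=200):
    """Iterate z <- func(z) until the relative change drops below tol."""
    z = list(z_init)
    info = {"steps": 0, "residual": float("inf")}
    for step in range(1, max_iter + 1):
        z_next, _ = pack(*func(*unpack(z, sizes)))
        diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(z_next, z)))
        norm = math.sqrt(sum(a * a for a in z_next)) + 1e-12
        z = z_next
        info = {"steps": step, "residual": diff / norm}
        if info["residual"] < tol:
            break
    return z, info

def func(z1, z2):
    # Toy coupled system (an assumption for illustration): a contraction,
    # so plain iteration converges. Output shapes match input shapes.
    new_z1 = [0.4 * math.tanh(a + z2[0]) + 0.1 for a in z1]
    new_z2 = [0.25 * sum(z1) + 0.1]
    return new_z1, new_z2

z_init, sizes = pack([0.0, 0.0], [0.0])
z_out, info = solve_fixed_point(func, z_init, sizes)
z1_star, z2_star = unpack(z_out, sizes)
```

After the solve, z1_star and z2_star can be post-processed like any other values, which mirrors the "post-process z_out" remark in the list above.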

Requirements

The code in this repo has been tested on PyTorch v1.10.0. Install the required environment with the following commands.

conda create --name deq python==3.6.10
conda activate deq
conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
conda install tensorboard scipy opencv matplotlib einops termcolor -c conda-forge
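
After running the commands above, a quick check that the packages can be found may save a failed run later. The snippet below is a small sketch using only the standard library. The module names are assumptions based on the usual import names of the packages listed above (for example, opencv is imported as cv2), and missing_packages is a hypothetical helper, not part of this repo.

```python
import importlib.util

# Import names assumed for the conda packages installed above.
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio", "tensorboard",
    "scipy", "cv2", "matplotlib", "einops", "termcolor",
]

def missing_packages(modules=REQUIRED_MODULES):
    """Return the top-level modules that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```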

Download the following datasets into the datasets directory.


The following README doc is for version 1.0, i.e., code.v.1.0. You can follow this to reproduce all the results.

Inference

Download the pretrained checkpoints into the checkpoints directory. Run the following command to infer over the Sintel train set and the KITTI train set.

bash val.sh

You may expect the following performance statistics of given checkpoints. This is a reference log.

Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all
DEQ-Flow-B      | 1.43           | 2.79           | 5.43       | 16.67
DEQ-Flow-H-1    | 1.45           | 2.58           | 3.97       | 13.41
DEQ-Flow-H-2    | 1.37           | 2.62           | 3.97       | 13.62
DEQ-Flow-H-3    | 1.36           | 2.62           | 4.02       | 13.92
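
For readers unfamiliar with the two metrics in the table, the sketch below computes them on a handful of flow vectors. It assumes the standard definitions: AEPE is the mean Euclidean distance between predicted and ground-truth flow, and F1-all is the percentage of pixels whose error exceeds both 3 px and 5% of the ground-truth magnitude, as in the KITTI 2015 benchmark. flow_metrics is a hypothetical helper for illustration. The evaluation code in this repo works on tensors and also applies validity masks, which are omitted here.

```python
import math

def flow_metrics(pred, gt):
    """Return (AEPE, F1-all in percent) for paired lists of (u, v) flow vectors."""
    epes, outliers = [], []
    for (pu, pv), (gu, gv) in zip(pred, gt):
        epe = math.hypot(pu - gu, pv - gv)   # end-point error for this pixel
        mag = math.hypot(gu, gv)             # ground-truth flow magnitude
        epes.append(epe)
        # Outlier: error above 3 px AND above 5% of the true magnitude.
        outliers.append(epe > 3.0 and epe > 0.05 * mag)
    aepe = sum(epes) / len(epes)
    f1_all = 100.0 * sum(outliers) / len(outliers)
    return aepe, f1_all

# Four toy pixels: a perfect match, two outliers, and one large but
# relatively small error (4 px on a 104 px motion).
pred = [(1.0, 0.0), (0.0, 0.0), (10.0, 0.0), (100.0, 0.0)]
gt = [(1.0, 0.0), (3.0, 4.0), (10.0, 4.0), (104.0, 0.0)]
aepe, f1_all = flow_metrics(pred, gt)
```

Lower is better for both numbers. The last toy pixel shows why F1-all uses a relative threshold: a 4 px error on a very large motion is not counted as an outlier.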

Visualization

Download the pretrained checkpoints into the checkpoints directory. Run the following command to visualize the optical flow estimation over the KITTI test set.

bash viz.sh

Training

Download FlyingChairs-pretrained checkpoints into the checkpoints directory.

For the efficiency mode, you can run 1-step gradient to train DEQ-Flow-B via the following command. Memory overhead per GPU is about 5800 MB.

You may expect best results of about 1.46 (AEPE) on Sintel (clean), 2.85 (AEPE) on Sintel (final), 5.29 (AEPE) and 16.24 (F1-all) on KITTI. This is a reference log.

bash train_B_demo.sh
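
As a rough intuition for the 1-step gradient mentioned above, the scalar sketch below compares it with the exact implicit gradient. It rests on our reading of the term as used in the DEQ literature: the fixed point is treated as a constant and only the last application of the function is differentiated, which drops the inverse-Jacobian factor of the implicit function theorem. The toy function f and all names are assumptions for illustration. The actual training code uses PyTorch autograd rather than hand-written derivatives.

```python
import math

def f(z, w, x):
    """Toy equilibrium layer: the fixed point satisfies z = tanh(w * z + x)."""
    return math.tanh(w * z + x)

def solve(w, x, tol=1e-12, max_iter=1000):
    """Plain fixed-point iteration; converges here because |w| < 1."""
    z = 0.0
    for _ in range(max_iter):
        z_next = f(z, w, x)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    return z

w, x = 0.5, 0.3
z_star = solve(w, x)

# Partial derivatives of f at the fixed point, using tanh'(u) = 1 - tanh(u)^2
# and tanh(w * z_star + x) = z_star.
df_dw = (1.0 - z_star ** 2) * z_star
df_dz = (1.0 - z_star ** 2) * w

# Exact gradient from the implicit function theorem.
exact_grad = df_dw / (1.0 - df_dz)

# 1-step gradient: differentiate one application of f, holding z_star fixed.
one_step_grad = df_dw

# Finite-difference reference for the exact gradient.
eps = 1e-6
fd_grad = (solve(w + eps, x) - solve(w - eps, x)) / (2.0 * eps)
```

In this toy case the 1-step gradient has the same sign as the exact one but a smaller magnitude. It avoids solving the backward linear system, which is consistent with the lower memory overhead quoted above.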

To train a demo of DEQ-Flow-H, you can run this command. Memory overhead per GPU is about 6300 MB. It can be further reduced to about 4200 MB per GPU when combined with --mixed-precision. You can reduce the memory cost even further if you employ the CUDA implementation of the cost volume from RAFT.

You may expect best results of about 1.41 (AEPE) on Sintel (clean), 2.76 (AEPE) on Sintel (final), 4.44 (AEPE) and 14.81 (F1-all) on KITTI. This is a reference log.

bash train_H_demo.sh

To train DEQ-Flow-B on Chairs and Things, use the following command.

bash train_B.sh

For the performance mode, you can run this command to train DEQ-Flow-H using the C+T and C+T+S+K+H schedule. You may expect a performance of below 1.40 (AEPE) on Sintel (clean), around 2.60 (AEPE) on Sintel (final), and around 4.00 (AEPE) and 13.6 (F1-all) on KITTI. DEQ-Flow-H-1,2,3 are checkpoints from three runs.

Currently, this training protocol may require slightly more memory than two 11 GB GPUs provide. In the near future, we will upload an implementation revision (of the DEQ models) that should reduce this overhead to fit within two 11 GB GPUs.

bash train_H_full.sh

Code Usage

Under construction. We will provide more detailed instructions on the code usage (e.g., argparse flags, fixed-point solvers, backward IFT modes) in the coming days.


A Tutorial on DEQ

If you would like to learn more about DEQ models, here is an official NeurIPS tutorial on implicit deep learning. Enjoy!

Reference

If you find our work helpful to your research, please consider citing this paper. :)

@inproceedings{deq-flow,
    author = {Bai, Shaojie and Geng, Zhengyang and Savani, Yash and Kolter, J. Zico},
    title = {Deep Equilibrium Optical Flow Estimation},
    booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    year = {2022}
}

Credit

A lot of the utility code in this repo was adapted from the RAFT repo and the DEQ repo.

Contact

Feel free to contact us if you have additional questions. Please drop an email through [email protected] (or Twitter).
