Fast Vision Transformers with HiLo Attention 👋 (NeurIPS 2022 Spotlight)

This is the official PyTorch implementation of Fast Vision Transformers with HiLo Attention.

By Zizheng Pan, Jianfei Cai, and Bohan Zhuang.

News

  • 20/04/2023. Updated training scripts for PyTorch 2.0. Added support for ONNX and TensorRT model conversion, see here (a minimal export sketch follows this list).

  • 15/12/2022. Released ImageNet pretrained weights trained with different values of alpha.

  • 11/11/2022. LITv2 will be presented as a Spotlight at NeurIPS 2022!

  • 13/10/2022. Updated code for newer versions of timm. Compatible with PyTorch 1.12.1 + CUDA 11.3 + timm 0.6.11.

  • 30/09/2022. Added benchmarking results for a single attention layer. HiLo is super fast on both CPU and GPU!

  • 15/09/2022. LITv2 has been accepted by NeurIPS 2022! 🔥🔥🔥

  • 16/06/2022. We released the source code for classification/detection/segmentation, along with the pretrained weights. Issues and feedback are welcome!
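
As a rough illustration of the ONNX route mentioned in the news above, the sketch below exports the standalone HiLo layer (see "A Simple Demo" further down) with torch.onnx.export. This is only a hedged example under assumptions: the output filename is arbitrary, the repo's own conversion scripts linked in the news item are the reference, and whether a full LITv2 model exports cleanly depends on your PyTorch/ONNX versions.

import torch
from hilo import HiLo

# Build the standalone HiLo layer used in the demo below and switch to eval mode.
model = HiLo(dim=384, num_heads=12, window_size=2, alpha=0.5).eval()
x = torch.randn(1, 196, 384)  # batch_size x num_tokens x hidden_dimension

# H and W are plain Python ints, so they are baked into the exported graph as constants.
torch.onnx.export(
    model, (x, 14, 14), "hilo.onnx",  # "hilo.onnx" is an arbitrary output path
    input_names=["x"], output_names=["out"], opset_version=13,
)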

A Gentle Introduction

We introduce LITv2, a simple and effective ViT that performs favourably against existing state-of-the-art methods across a spectrum of model sizes while running faster.

[Figure: HiLo attention]

The core of LITv2: HiLo attention. HiLo is motivated by the observation that high frequencies in an image capture local fine details while low frequencies capture global structures, whereas a standard multi-head self-attention layer ignores this distinction between frequencies. We therefore disentangle the high- and low-frequency patterns within an attention layer by splitting the heads into two groups: one group (Hi-Fi) encodes high frequencies via self-attention within each local window, and the other group (Lo-Fi) models the global relationship between the average-pooled low-frequency keys from each window and every query position in the input feature map.

A Simple Demo

To quickly understand HiLo attention, you only need PyTorch installed; then run the following code from the root directory of this repo.

from hilo import HiLo
import torch

model = HiLo(dim=384, num_heads=12, window_size=2, alpha=0.5)

x = torch.randn(64, 196, 384) # batch_size x num_tokens x hidden_dimension
out = model(x, 14, 14)
print(out.shape)
print(model.flops(14, 14)) # the number of FLOPs

Output:

torch.Size([64, 196, 384])
83467776
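
To see how the layer slots into a ViT-style architecture, below is a minimal sketch of a pre-norm transformer block that uses HiLo as its token mixer. The LayerNorm/MLP wiring here is a generic illustration, not the exact block definition used in LITv2.

import torch
import torch.nn as nn
from hilo import HiLo

class HiLoBlock(nn.Module):
    """A generic pre-norm transformer block with HiLo as the attention layer.
    Illustrative sketch only; not the exact LITv2 block."""
    def __init__(self, dim=384, num_heads=12, window_size=2, alpha=0.5, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = HiLo(dim=dim, num_heads=num_heads, window_size=window_size, alpha=alpha)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x, H, W):
        x = x + self.attn(self.norm1(x), H, W)  # HiLo attention + residual
        x = x + self.mlp(self.norm2(x))         # feed-forward + residual
        return x

block = HiLoBlock()
x = torch.randn(64, 196, 384)
print(block(x, 14, 14).shape)  # torch.Size([64, 196, 384])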

Installation

Requirements

  • Linux with Python >= 3.6
  • PyTorch >= 1.8.1
  • timm >= 0.3.2
  • CUDA 11.1
  • An NVIDIA GPU

Conda environment setup

Note: You can use the same environment as for LITv1. Otherwise, create a new Python virtual environment with the following script.

conda create -n lit python=3.7
conda activate lit

# Install PyTorch and TorchVision
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

pip install timm
pip install ninja
pip install tensorboard

# Install NVIDIA apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
rm -rf apex/

# Build Deformable Convolution
cd mm_modules/DCN
python setup.py build install

pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
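
After installation, a quick sanity check of the environment (nothing repo-specific, just versions and CUDA visibility):

import torch
import timm

print(torch.__version__)          # e.g. 1.8.1 for the wheel installed above
print(torch.version.cuda)         # e.g. 11.1
print(torch.cuda.is_available())  # should be True on a machine with an NVIDIA GPU
print(timm.__version__)           # timm >= 0.3.2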

Getting Started

For image classification on ImageNet, please refer to classification.

For object detection on COCO 2017, please refer to detection.

For semantic segmentation on ADE20K, please refer to segmentation.

Results and Model Zoo

Note: For your convenience, all models and logs can be found on Google Drive (4.8G in total). Alternatively, we also provide download links on GitHub.

Image Classification on ImageNet-1K

All models are trained for 300 epochs with a total batch size of 1024 on 8 V100 GPUs.

Model | Resolution | Params (M) | FLOPs (G) | Throughput (imgs/s) | Train Mem (GB) | Test Mem (GB) | Top-1 (%) | Download
LITv2-S | 224 | 28 | 3.7 | 1,471 | 5.1 | 1.2 | 82.0 | model & log
LITv2-M | 224 | 49 | 7.5 | 812 | 8.8 | 1.4 | 83.3 | model & log
LITv2-B | 224 | 87 | 13.2 | 602 | 12.2 | 2.1 | 83.6 | model & log
LITv2-B | 384 | 87 | 39.7 | 198 | 35.8 | 4.6 | 84.7 | model

By default, throughput and memory footprint are measured on one RTX 3090 with a batch size of 64. Memory is the peak usage reported by torch.cuda.max_memory_allocated(). Throughput is averaged over 30 runs.
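
For reference, the measurement protocol above can be approximated with a short benchmarking helper like the sketch below. The litv2_s constructor in the comment is a placeholder for however you build the model (e.g. via the classification code); exact numbers will vary with hardware and software versions.

import time
import torch

def benchmark(model, batch_size=64, resolution=224, runs=30, warmup=10):
    """Measure throughput (imgs/s) and peak GPU memory (GB) for a classification model,
    following the protocol described above (batch size 64, averaged over 30 runs)."""
    device = torch.device("cuda")
    model = model.to(device).eval()
    x = torch.randn(batch_size, 3, resolution, resolution, device=device)
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        for _ in range(warmup):       # warm-up iterations, not timed
            model(x)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            model(x)
        torch.cuda.synchronize()
    throughput = runs * batch_size / (time.time() - start)
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    return throughput, peak_mem_gb

# model = litv2_s()  # hypothetical constructor; build LITv2-S via the classification code
# print(benchmark(model))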

Pretrained LITv2-S with Different Values of Alpha

Alpha | Params (M) | Lo-Fi Heads | Hi-Fi Heads | FLOPs (G) | ImageNet Top-1 (%) | Download
0.0 | 28 | 0 | 12 | 3.97 | 81.16 | github
0.2 | 28 | 2 | 10 | 3.88 | 81.89 | github
0.4 | 28 | 4 | 8 | 3.82 | 81.81 | github
0.5 | 28 | 6 | 6 | 3.77 | 81.88 | github
0.7 | 28 | 8 | 4 | 3.74 | 81.94 | github
0.9 | 28 | 10 | 2 | 3.73 | 82.03 | github
1.0 | 28 | 12 | 0 | 3.70 | 81.89 | github

Pretrained weights from the experiments in Figure 4 of the paper: effect of α, based on LITv2-S.
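
The head split in the table appears to follow directly from alpha: with 12 heads, the Lo-Fi branch gets int(alpha * 12) heads and the Hi-Fi branch gets the rest. The sketch below reproduces that allocation and prints single-layer FLOPs via the flops() method from the demo; note these are per-layer counts at a 14x14 feature map, not the whole-model GFLOPs listed in the table.

from hilo import HiLo

num_heads = 12
for alpha in [0.0, 0.2, 0.4, 0.5, 0.7, 0.9, 1.0]:
    lo_fi = int(alpha * num_heads)  # heads assigned to the Lo-Fi (global) branch
    hi_fi = num_heads - lo_fi       # remaining heads go to the Hi-Fi (local window) branch
    layer = HiLo(dim=384, num_heads=num_heads, window_size=2, alpha=alpha)
    print(f"alpha={alpha}: Lo-Fi heads={lo_fi}, Hi-Fi heads={hi_fi}, "
          f"layer FLOPs={layer.flops(14, 14):,}")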

Object Detection on COCO 2017

All models are trained with the 1x schedule (12 epochs) and a total batch size of 16 on 8 V100 GPUs.

RetinaNet

Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | Config | Download
LITv2-S | 2 | 38 | 242 | 18.7 | 44.0 | config | model & log
LITv2-S | 4 | 38 | 230 | 20.4 | 43.7 | config | model & log
LITv2-M | 2 | 59 | 348 | 12.2 | 46.0 | config | model & log
LITv2-M | 4 | 59 | 312 | 14.8 | 45.8 | config | model & log
LITv2-B | 2 | 97 | 481 | 9.5 | 46.7 | config | model & log
LITv2-B | 4 | 97 | 430 | 11.8 | 46.3 | config | model & log

Mask R-CNN

Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | mask AP | Config | Download
LITv2-S | 2 | 47 | 261 | 18.7 | 44.9 | 40.8 | config | model & log
LITv2-S | 4 | 47 | 249 | 21.9 | 44.7 | 40.7 | config | model & log
LITv2-M | 2 | 68 | 367 | 12.6 | 46.8 | 42.3 | config | model & log
LITv2-M | 4 | 68 | 315 | 16.0 | 46.5 | 42.0 | config | model & log
LITv2-B | 2 | 106 | 500 | 9.3 | 47.3 | 42.6 | config | model & log
LITv2-B | 4 | 106 | 449 | 11.5 | 46.8 | 42.3 | config | model & log

Semantic Segmentation on ADE20K

All models are trained for 80K iterations with a total batch size of 16 on 8 V100 GPUs.

Backbone | Params (M) | FLOPs (G) | FPS | mIoU | Config | Download
LITv2-S | 31 | 41 | 42.6 | 44.3 | config | model & log
LITv2-M | 52 | 63 | 28.5 | 45.7 | config | model & log
LITv2-B | 90 | 93 | 27.5 | 47.2 | config | model & log

Benchmarking Throughput on More GPUs

Model | Params (M) | FLOPs (G) | A100 (imgs/s) | V100 (imgs/s) | RTX 6000 (imgs/s) | RTX 3090 (imgs/s) | Top-1 (%)
ResNet-50 | 26 | 4.1 | 1,424 | 1,123 | 877 | 1,279 | 80.4
PVT-S | 25 | 3.8 | 1,460 | 798 | 548 | 1,007 | 79.8
Twins-PCPVT-S | 24 | 3.8 | 1,455 | 792 | 529 | 998 | 81.2
Swin-Ti | 28 | 4.5 | 1,564 | 1,039 | 710 | 961 | 81.3
TNT-S | 24 | 5.2 | 802 | 431 | 298 | 534 | 81.3
CvT-13 | 20 | 4.5 | 1,595 | 716 | 379 | 947 | 81.6
CoAtNet-0 | 25 | 4.2 | 1,538 | 962 | 643 | 1,151 | 81.6
CaiT-XS24 | 27 | 5.4 | 991 | 484 | 299 | 623 | 81.8
PVTv2-B2 | 25 | 4.0 | 1,175 | 670 | 451 | 854 | 82.0
XCiT-S12 | 26 | 4.8 | 1,727 | 761 | 504 | 1,068 | 82.0
ConvNext-Ti | 28 | 4.5 | 1,654 | 762 | 571 | 1,079 | 82.1
Focal-Tiny | 29 | 4.9 | 471 | 372 | 261 | 384 | 82.2
LITv2-S | 28 | 3.7 | 1,874 | 1,304 | 928 | 1,471 | 82.0

Single Attention Layer Benchmark

For the code to reproduce the following visualization, please refer to vit-attention-benchmark.

[Figure: single HiLo attention layer benchmark on CPU and GPU]
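
To reproduce a comparison of this kind locally, the hedged sketch below times a single HiLo layer on CPU and, if available, GPU; absolute numbers depend on your hardware and will not match the figure exactly.

import time
import torch
from hilo import HiLo

def time_layer(device, runs=30, warmup=10):
    """Average forward latency (seconds) of one HiLo layer on the given device."""
    layer = HiLo(dim=384, num_heads=12, window_size=2, alpha=0.5).to(device).eval()
    x = torch.randn(64, 196, 384, device=device)
    with torch.no_grad():
        for _ in range(warmup):
            layer(x, 14, 14)
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            layer(x, 14, 14)
        if device.type == "cuda":
            torch.cuda.synchronize()
    return (time.time() - start) / runs

print("CPU:", time_layer(torch.device("cpu")))
if torch.cuda.is_available():
    print("GPU:", time_layer(torch.device("cuda")))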

Citation

If you use LITv2 in your research, please consider citing the following BibTeX entry and giving us a star 🌟.

@inproceedings{pan2022hilo,
  title={Fast Vision Transformers with HiLo Attention},
  author={Pan, Zizheng and Cai, Jianfei and Zhuang, Bohan},
  booktitle={NeurIPS},
  year={2022}
}

If you find the code useful, please also consider citing the following BibTeX entry:

@inproceedings{pan2022litv1,
  title={Less is More: Pay Less Attention in Vision Transformers},
  author={Pan, Zizheng and Zhuang, Bohan and He, Haoyu and Liu, Jing and Cai, Jianfei},
  booktitle={AAAI},
  year={2022}
}

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository is built upon DeiT, Swin, and LIT. We thank the authors for their open-source code.
