Mesa: A Memory-saving Training Framework for Transformers

This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

By Zizheng Pan, Peng Chen, Haoyu He, Jing Liu, Jianfei Cai and Bohan Zhuang.

Installation

  1. Create a virtual environment with Anaconda.

    conda create -n mesa python=3.7 -y
    conda activate mesa
    
    # Install PyTorch; we use PyTorch 1.7.1 with CUDA 10.1
    pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
    # Install ninja
    pip install ninja
  2. Build and install Mesa.

    # clone this repo
    git clone https://github.com/zhuang-group/Mesa
    # build
    cd Mesa/
    # You need to have an NVIDIA GPU
    python setup.py develop
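
A quick way to check that the build succeeded is to instantiate a Mesa layer and run a forward and backward pass on the GPU. This is only a smoke test: the tensor shape is arbitrary, and it assumes a CUDA device is visible and that ms.GELU behaves like a standard nn.Module (see Usage below).

    import torch
    import mesa as ms

    # Smoke test: run a quantized GELU layer from Mesa on a random tensor.
    act = ms.GELU(quant_groups=3).cuda()
    x = torch.randn(8, 196, 384, device='cuda', requires_grad=True)
    y = act(x)
    y.sum().backward()
    print('Mesa build looks OK, input grad shape:', x.grad.shape)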

Usage

  1. Prepare your policy and save it as a text file, e.g. policy.txt.

    on gelu: # layer tag, choices: fc, conv, gelu, bn, relu, softmax, matmul, layernorm
        by_index: all # layer index
        enable: True # enable compression for these layers
        level: 256 # we adopt 8-bit quantization by default
        ema_decay: 0.9 # the decay rate for running estimates

        by_index: 1 2 # e.g. exclude the GELU layers indexed 1 and 2
        enable: False
  2. Next, you can wrap your model with Mesa by:

    import mesa as ms
    ms.policy.convert_by_num_groups(model, 3)
    # or convert by group size with ms.policy.convert_by_group_size(model, 64)
    
    # setup compression policy
    ms.policy.deploy_on_init(model, '[path to policy.txt]', verbose=print, override_verbose=False)

    That's all you need to use Mesa for memory saving.

    Note that convert_by_num_groups and convert_by_group_size only recognize nn.XXX modules; if your code uses functional operations, such as Q@K and F.softmax, you may need to set these layers up manually (a fuller attention-block sketch appears after this list). For example:

    import mesa as ms
    # matrix multiplication (before)
    out = q @ k.transpose(-2, -1)
    # with Mesa
    self.mm = ms.MatMul(quant_groups=3)
    out = self.mm(q, k.transpose(-2, -1))

    # softmax (before)
    attn = attn.softmax(dim=-1)
    # with Mesa
    self.softmax = ms.Softmax(dim=-1, quant_groups=3)
    attn = self.softmax(attn)
  3. You can also target one layer by:

    import mesa as ms
    # previous 
    self.act = nn.GELU()
    # with Mesa
    self.act = ms.GELU(quant_groups=[num of quantization groups])
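
For reference, the sketch below shows how a typical attention block might look once its functional operations are rewritten with ms.MatMul and ms.Softmax as described in step 2. Only the Mesa layers come from this repository; the surrounding module layout and names (MesaAttention, attn_mm, out_mm) are illustrative assumptions, not part of Mesa's API.

    import torch.nn as nn
    import mesa as ms

    class MesaAttention(nn.Module):
        """A minimal self-attention block whose functional ops are replaced by Mesa layers."""
        def __init__(self, dim, num_heads=8):
            super().__init__()
            self.num_heads = num_heads
            self.scale = (dim // num_heads) ** -0.5
            self.qkv = nn.Linear(dim, dim * 3)
            self.proj = nn.Linear(dim, dim)
            # Functional ops (q @ k, softmax, attn @ v) become explicit modules
            # so that Mesa's policy deployment can find and compress them.
            self.attn_mm = ms.MatMul(quant_groups=3)
            self.out_mm = ms.MatMul(quant_groups=3)
            self.softmax = ms.Softmax(dim=-1, quant_groups=3)

        def forward(self, x):
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
            q, k, v = qkv.permute(2, 0, 3, 1, 4)
            attn = self.softmax(self.attn_mm(q, k.transpose(-2, -1)) * self.scale)
            x = self.out_mm(attn, v).transpose(1, 2).reshape(B, N, C)
            return self.proj(x)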

Demo projects for DeiT and Swin

We provide demo projects to replicate our results of training DeiT and Swin with Mesa; please refer to DeiT-Mesa and Swin-Mesa.

Results on ImageNet

Model             Param (M)   FLOPs (G)   Train Memory (MB)   Top-1 (%)
DeiT-Ti           5           1.3         4,171               71.9
DeiT-Ti w/ Mesa   5           1.3         1,858               72.1
DeiT-S            22          4.6         8,459               79.8
DeiT-S w/ Mesa    22          4.6         3,840               80.0
DeiT-B            86          17.5        17,691              81.8
DeiT-B w/ Mesa    86          17.5        8,616               81.8
Swin-Ti           29          4.5         11,812              81.3
Swin-Ti w/ Mesa   29          4.5         5,371               81.3
PVT-Ti            13          1.9         7,800               75.1
PVT-Ti w/ Mesa    13          1.9         3,782               74.9

Memory footprint at training time is measured with a batch size of 128 and an image resolution of 224x224 on a single GPU.
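
Peak training memory of this kind can be recorded along the lines of the sketch below, where model, criterion, images and targets are placeholders for your own training setup rather than parts of this repository.

    import torch

    # Reset the peak-memory counter, run one training step, then read the peak.
    torch.cuda.reset_peak_memory_stats()
    outputs = model(images.cuda())            # images: a 128 x 3 x 224 x 224 batch
    loss = criterion(outputs, targets.cuda())
    loss.backward()
    print(f'Peak train memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.0f} MB')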

Citation

If you find our work interesting or helpful to your research, please consider citing Mesa.

@article{pan2021mesa,
      title={Mesa: A Memory-saving Training Framework for Transformers},
      author={Zizheng Pan and Peng Chen and Haoyu He and Jing Liu and Jianfei Cai and Bohan Zhuang},
      journal={arXiv preprint arXiv:2111.11124},
      year={2021}
}

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgments

This repository has adopted part of the quantization code from ActNN; we thank the authors for their open-sourced code.

More Repositories

  1. SN-Net (Python, 238 stars): [CVPR 2023 Highlight] This is the official implementation of "Stitchable Neural Networks".
  2. LITv2 (Python, 233 stars): [NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "Fast Vision Transformers with HiLo Attention".
  3. SPViT (Python, 104 stars): [TPAMI 2024] This is the official repository for our paper "Pruning Self-attentions into Convolutional Layers in Single Path".
  4. LIT (Python, 88 stars): [AAAI 2022] This is the official PyTorch implementation of "Less is More: Pay Less Attention in Vision Transformers".
  5. PTQD (Jupyter Notebook, 85 stars): The official implementation of "PTQD: Accurate Post-Training Quantization for Diffusion Models".
  6. QTool (Python, 68 stars): Collections of model quantization algorithms. For any issues, please contact Peng Chen ([email protected]).
  7. EcoFormer (Python, 66 stars): [NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "EcoFormer: Energy-Saving Attention with Linear Complexity".
  8. SPT (Python, 60 stars): [ICCV 2023 oral] This is the official repository for our paper "Sensitivity-Aware Visual Parameter-Efficient Fine-Tuning".
  9. FASeg (Python, 54 stars): [CVPR 2023] This is the official PyTorch implementation for "Dynamic Focus-aware Positional Queries for Semantic Segmentation".
  10. SAQ (Python, 40 stars): This is the official PyTorch implementation for "Sharpness-aware Quantization for Deep Neural Networks".
  11. LongVLM (Python, 38 stars)
  12. HVT (Python, 30 stars): [ICCV 2021] Official implementation of "Scalable Vision Transformers with Hierarchical Pooling".
  13. MPVSS (Python, 25 stars)
  14. SN-Netv2 (Python, 22 stars): [ECCV 2024] This is the official implementation of "Stitched ViTs are Flexible Vision Backbones".
  15. QLLM (Python, 19 stars): [ICLR 2024] This is the official PyTorch implementation of "QLLM: Accurate and Efficient Low-Bitwidth Quantization for Large Language Models".
  16. efficient-stable-diffusion (16 stars)
  17. Stitched_LLaMA (8 stars): [CVPR 2024] A framework to fine-tune LLaMAs on instruction-following tasks and get many Stitched LLaMAs with customized numbers of parameters, e.g., Stitched LLaMA 8B, 9B, and 10B.
  18. STPT (3 stars)
  19. ZipLLM (1 star)