
Stitchable Neural Networks πŸͺ‘ (CVPR 2023 Highlight)


This is the official PyTorch implementation of Stitchable Neural Networks.

By Zizheng Pan, Jianfei Cai, and Bohan Zhuang.

News

  • 28/03/2023. Code for stitching LeViTs has been released.
  • 27/03/2023. We release the code and checkpoints for stitching ResNets and Swin Transformers.
  • 22/03/2023. SN-Net was selected as a highlight at CVPR 2023!πŸ”₯
  • 02/03/2023. We have released the source code! Issues and feedback are welcome!
  • 28/02/2023. SN-Net was accepted by CVPR 2023! πŸŽ‰πŸŽ‰πŸŽ‰

A Gentle Introduction

Stitchable Neural Network (SN-Net) is a novel, scalable, and efficient framework for model deployment. Given a family of pretrained neural networks, which we call anchors, it cheaply produces numerous networks with different complexity-performance trade-offs. Specifically, SN-Net splits the anchors at the block/layer level and then stitches them back together with simple stitching layers, which map the activations from one anchor to another.
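As a rough illustration of the idea above, a stitching layer can be as simple as a linear projection over the channel dimension. The class below is a minimal sketch with made-up names and dimensions, not the repo's actual API:

```python
import torch
import torch.nn as nn

class StitchingLayer(nn.Module):
    """Maps activations from one anchor's feature space to another's.

    Illustrative sketch: for token-based models (e.g., ViTs) a single
    linear projection over the channel dimension is often enough.
    """
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, in_dim) -> (batch, tokens, out_dim)
        return self.proj(x)

# Route activations from a small anchor (dim 384) into a larger anchor (dim 768).
stitch = StitchingLayer(384, 768)
tokens = torch.randn(2, 197, 384)
out = stitch(tokens)
```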

With only a few epochs of training, SN-Net effectively interpolates between the performance of anchors with varying scales. At runtime, SN-Net can instantly adapt to dynamic resource constraints by switching the stitching positions.
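The runtime behaviour can be sketched as follows: choosing a different stitching position routes more or fewer blocks through each anchor. This is a toy sketch with made-up block counts and dimensions, not the repo's implementation:

```python
import torch
import torch.nn as nn

# Toy anchors: a "small" stack (dim 64) and a "large" stack (dim 128).
small = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
large = nn.ModuleList(nn.Linear(128, 128) for _ in range(4))
stitch = nn.Linear(64, 128)  # in practice, one stitching layer per position

def forward_stitched(x: torch.Tensor, pos: int) -> torch.Tensor:
    # Run the first `pos` blocks of the small anchor, map features with the
    # stitching layer, then finish with the large anchor's remaining blocks.
    # Changing `pos` at runtime moves along the speed/accuracy curve
    # without any retraining.
    for blk in small[:pos]:
        x = torch.relu(blk(x))
    x = stitch(x)
    for blk in large[pos:]:
        x = torch.relu(blk(x))
    return x

x = torch.randn(2, 64)
fast = forward_stitched(x, pos=3)      # mostly small anchor: cheaper
accurate = forward_stitched(x, pos=1)  # mostly large anchor: stronger
```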

Getting Started

SN-Net is a general framework. However, since different model families are trained differently, we build our stitching experiments on each family's own training code. In this repo, we provide examples for several model families, such as plain ViTs, hierarchical ViTs, CNNs, CNN-ViT, and lightweight ViTs.

To use our repo, we suggest creating a Python virtual environment.

conda create -n snnet python=3.9
conda activate snnet
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install fvcore
pip install timm==0.6.12

Next, feel free to experiment with the different settings.

For experiments with plain ViTs, please refer to stitching_deit.

For experiments with hierarchical ViTs, please refer to stitching_swin.

For experiments with CNNs and CNN-ViT, please refer to stitching_resnet_swin.

For experiments with lightweight ViTs, please refer to stitching_levit.

Best Practice for Extension

Please feel free to extend SN-Net to other model families. The following tips may help with your experiments.

For Better Stitching

  1. For paired stitching (equal depth) such as on plain ViTs, using a small sliding window for stitching usually achieves a smoother performance curve.
  2. For unpaired stitching (unequal depth) such as on hierarchical ViTs, split the architecture into different stages and stitch within the same stage.
  3. Note that many existing models allocate most blocks/layers into the 3rd stage, thus stitching at the 3rd stage can help to obtain more stitches.
  4. Remember to initialize your stitching layers (e.g., with a least-squares solution over paired activations); a few samples are usually enough.
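Tip 4 can be sketched with a least-squares initialization: collect matched activations from the two anchors on a few samples and solve for the projection directly. The helper below is a hypothetical illustration, not part of this repo:

```python
import torch

def lstsq_init(acts_from: torch.Tensor, acts_to: torch.Tensor) -> torch.Tensor:
    """Solve acts_from @ W ~= acts_to and return W to initialize a linear
    stitching layer. Inputs are flattened over batch/token dimensions:
    acts_from: (n, in_dim), acts_to: (n, out_dim)."""
    return torch.linalg.lstsq(acts_from, acts_to).solution  # (in_dim, out_dim)

# Synthetic example: activations from a handful of samples.
n, d_in, d_out = 256, 384, 768
a = torch.randn(n, d_in, dtype=torch.float64)
true_w = torch.randn(d_in, d_out, dtype=torch.float64)
b = a @ true_w  # pretend these are the target anchor's activations
w = lstsq_init(a, b)
```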

For Better Training

  1. Uniformly decreasing the learning rate by 10x (relative to the rate used when training the anchors) can serve as a good starting point. See our settings in the DeiT-based experiments.
  2. If the above is not good, try to decrease the learning rate for anchors while using a relatively larger learning rate for stitching layers. See our Swin-based experiments.
  3. Training with more epochs (e.g., 100) can be better, but it also comes at a higher computational cost.
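Tip 2 above can be expressed with PyTorch optimizer parameter groups; the module names and learning rates below are illustrative, not taken from this repo:

```python
import torch
import torch.nn as nn

# Toy stand-ins: pretrained anchors plus freshly added stitching layers.
model = nn.ModuleDict({
    "anchors": nn.Linear(64, 64),
    "stitch_layers": nn.Linear(64, 128),
})

# Give the pretrained anchors a small learning rate and the new stitching
# layers a relatively larger one via separate parameter groups.
optimizer = torch.optim.AdamW([
    {"params": model["anchors"].parameters(), "lr": 1e-5},
    {"params": model["stitch_layers"].parameters(), "lr": 1e-4},
], weight_decay=0.05)
```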

Citation

If you use SN-Net in your research, please consider citing the following BibTeX entry and giving us a star 🌟!

@inproceedings{pan2023snnet,
  title={Stitchable Neural Networks},
  author={Pan, Zizheng and Cai, Jianfei and Zhuang, Bohan},
  booktitle={CVPR},
  year={2023}
}

Acknowledgement

This implementation is built upon DeiT and Swin. We thank the authors for their released code.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.
