

Exploring whether attention is necessary for vision transformers

Do You Even Need Attention?
A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet

Paper/Report

TL;DR

We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.

Abstract

The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear. In this short report, we ask: is the attention layer even necessary? Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively. These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.
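To make the phrase "feed-forward layers applied over the patch and feature dimensions in an alternating fashion" concrete, here is a minimal sketch of one such block in plain Python lists, so it runs without PyTorch. It is an illustration under assumptions, not the code in vision_transformer_linear.py: the helper names, the toy sizes, the residual connections, and the omission of biases and normalization are all choices made for this example.

```python
import math
import random

def gelu(v):
    # Exact GELU via the Gaussian error function.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def make_weights(n_in, n_out, rng):
    # Small random weight matrix of shape (n_in, n_out); biases omitted for brevity.
    return [[rng.uniform(-0.5, 0.5) for _ in range(n_out)] for _ in range(n_in)]

def matmul(a, b):
    # (r x k) times (k x c) -> (r x c)
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def feed_forward(rows, w1, w2):
    # Linear -> GELU -> Linear, applied independently to every row.
    hidden = [[gelu(v) for v in row] for row in matmul(rows, w1)]
    return matmul(hidden, w2)

def transpose(m):
    return [list(col) for col in zip(*m)]

def add(a, b):
    # Element-wise sum, used here for the (assumed) residual connections.
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def ff_only_block(x, patch_w, feat_w):
    # x has shape (num_patches, dim).
    # 1) Feed-forward over the patch dimension: transpose so each row holds one
    #    feature channel across all patches, mix, then transpose back.
    #    This is the step that stands in for attention.
    x = add(x, transpose(feed_forward(transpose(x), *patch_w)))
    # 2) Feed-forward over the feature dimension, one patch at a time.
    x = add(x, feed_forward(x, *feat_w))
    return x

rng = random.Random(0)
num_patches, dim, hidden = 4, 3, 8  # toy sizes, not the paper's
patch_w = (make_weights(num_patches, hidden, rng), make_weights(hidden, num_patches, rng))
feat_w = (make_weights(dim, hidden, rng), make_weights(hidden, dim, rng))
x = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(num_patches)]

y = ff_only_block(x, patch_w, feat_w)
print(len(y), len(y[0]))  # prints: 4 3 (the shape is preserved)
```

The only place where information moves between patches is step 1. Step 2 treats every patch independently, exactly as the feed-forward layer in a standard transformer block does.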

Note

This is concurrent research with MLP-Mixer from Google Research. The ideas are exactly the same; the one difference is that they use (a lot) more compute.

Pretrained models and logs

Here is a Weights and Biases report with the expected training trajectory: W&B

name       acc@1   #params   url
FF-tiny    61.4    7.7M      model
FF-base    74.9    62M       model
FF-large   71.4    206M      -

Note: I haven't uploaded the FF-large model because (1) it's over GitHub's file storage limit, and (2) I don't see why anyone would want it, given that it performs worse than the base model. That said, if you want it, reach out to me and I'll send it to you.

How to train

The model definition in vision_transformer_linear.py is designed to be run with the DeiT repo, which is itself based on the wonderful timm package.

Steps:

  • Clone the DeiT repo and move the file into it:
        git clone https://github.com/facebookresearch/deit
        mv vision_transformer_linear.py deit
        cd deit
  • Add a line to import vision_transformer_linear in main.py. For example, add the following after the import statements (around line 27):
        + import vision_transformer_linear
  • Train:
        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
            --nproc_per_node=8 \
            --master_port 10490 \
            --use_env main.py \
            --model linear_tiny \
            --batch-size 128 \
            --drop 0.1 \
            --output_dir outputs/linear-tiny \
            --data-path your/path/to/imagenet

Citation

If you build upon this idea, feel free to drop a citation (and also cite MLP-Mixer).

@article{melaskyriazi2021doyoueven,
  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
  author={Luke Melas-Kyriazi},
  journal={arXiv},
  year={2021}
}
