LocalViT: Bringing Locality to Vision Transformers
This repository contains the PyTorch training and evaluation code for LocalViT.
LocalViT consistently improves the performance of current vision transformers.
If you use this code in a paper, please cite:
@article{li2021localvit,
title={LocalViT: Bringing Locality to Vision Transformers},
author={Li, Yawei and Zhang, Kai and Cao, Jiezhang and Timofte, Radu and Van Gool, Luc},
journal={arXiv preprint arXiv:2104.05707},
year={2021}
}
This repository is based on the timm package by Ross Wightman and DeiT by Hugo Touvron.
Update
Swin Transformer has been added to the experiments. When trained under the same protocol, LocalViT-Swin outperforms Swin-T by 1.0% top-1 accuracy.
1. Model Zoo
Models pre-trained on ImageNet 2012 are provided below.
Model | Top-1 acc (%) | Top-5 acc (%) | #Params | Download |
---|---|---|---|---|
LocalViT-T | 74.84 | 92.61 | 5.7M | model |
LocalViT-T-SE4 | 75.74 | 93.05 | 9.4M | model |
LocalViT-S | 80.78 | 95.38 | 22.4M | model |
LocalViT-PVT | 78.14 | 94.24 | 13.5M | model |
LocalViT-TNT | 75.90 | 92.90 | 6.3M | model |
LocalViT-Swin | 81.86 | 95.72 | 29.1M | model |
SE4 means that the hidden dimension of the SE (squeeze-and-excitation) module is reduced by a factor of 4. See Table 2 in the paper.
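For readers unfamiliar with squeeze-and-excitation, the sketch below shows a generic SE block in PyTorch whose hidden width is channels // reduction, so a reduction of 4 shrinks it by a factor of 4. The class name SqueezeExcite and the ReLU/sigmoid choice are assumptions for illustration; the SE variant used inside this repository may differ in activation and placement.

```python
import torch
import torch.nn as nn


class SqueezeExcite(nn.Module):
    """Generic squeeze-and-excitation block (illustrative sketch, not the repo's module)."""

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        hidden = channels // reduction  # reduction=4 -> hidden width is channels / 4
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: B x C x H x W -> B x C x 1 x 1
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1),  # shrink to the hidden width
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),  # expand back to C channels
            nn.Sigmoid(),  # one gate in (0, 1) per channel
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Excite: rescale each channel of x by its gate (broadcast over H and W).
        return x * self.fc(self.pool(x))


se = SqueezeExcite(channels=64, reduction=4)
x = torch.randn(2, 64, 14, 14)
print(se(x).shape)            # torch.Size([2, 64, 14, 14])
print(se.fc[0].out_channels)  # 16
```

The gates only rescale channels, so the output keeps the input shape; the reduction ratio only controls the width (and parameter count) of the bottleneck.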
2. Usage
I. Clone the repository locally:
git clone https://github.com/ofsoundof/LocalViT.git
II. Install pytorch-image-models 0.3.2:
pip install timm==0.3.2
Data preparation
Download and extract ImageNet train and val images.
The directory structure follows the standard layout expected by torchvision's datasets.ImageFolder, with one subfolder per class in both train/ and val/:
```
imagenet/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_18.JPEG
│   │   └── ......
│   └── ......
└── val/
    ├── n01440764/
    │   ├── ILSVRC2012_val_00000293.JPEG
    │   └── ......
    └── ......
```
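A quick way to catch layout mistakes before launching a long run is to check that both splits exist and contain class folders. The helper check_imagenet_layout below is hypothetical (it is not part of this repository) and uses only the Python standard library; it is a sketch under the assumption that the layout above is the one main.py expects.

```python
from pathlib import Path


def check_imagenet_layout(root: str) -> dict:
    """Return {split: number_of_class_folders} for the train/ and val/ splits.

    Raises FileNotFoundError if a split directory is missing or has no class
    folders, which an ImageFolder-style loader would also fail on.
    """
    counts = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each immediate subdirectory (e.g. n01440764) is treated as one class.
        classes = [d for d in split_dir.iterdir() if d.is_dir()]
        if not classes:
            raise FileNotFoundError(f"no class folders under {split_dir}")
        counts[split] = len(classes)
    return counts
```

For example, check_imagenet_layout("/path/to/imagenet") should report 1000 classes for each split on the full ImageNet 2012 classification data.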
3. Evaluation
To evaluate LocalViT-T pre-trained on ImageNet with a single GPU:
python main.py --model localvit_tiny_mlp4_act3_r192 --eval --resume /path/to/localvit_t.pth --data-path /path/to/imagenet
This should give
* Acc@1 74.838 Acc@5 92.610 loss 1.211
To evaluate the other models:
python main.py --model localvit_tiny_mlp4_act3_r4 --eval --resume /path/to/localvit_t_se4.pth --data-path /path/to/imagenet
This should give
* Acc@1 75.738 Acc@5 93.048 loss 1.330
python main.py --model localvit_small --eval --resume /path/to/localvit_s.pth --data-path /path/to/imagenet
This should give
* Acc@1 80.780 Acc@5 95.376 loss 1.019
python main.py --model localvit_tnt_t_patch16_224 --eval --resume /path/to/localvit_tnt.pth --data-path /path/to/imagenet
This should give
* Acc@1 75.896 Acc@5 92.898 loss 1.229
python main.py --model localvit_pvt_tiny --eval --resume /path/to/localvit_pvt.pth --data-path /path/to/imagenet
This should give
* Acc@1 78.144 Acc@5 94.238 loss 1.058
python main.py --model localvit_swin_tiny_patch4_window7_224 --eval --resume /path/to/localvit_swin.pth --data-path /path/to/imagenet
This should give
* Acc@1 81.860 Acc@5 95.720 loss 1.109
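To compare an evaluation run against the numbers above programmatically, the summary line can be parsed from the captured output. This is a sketch that assumes the line has exactly the "* Acc@1 ... Acc@5 ... loss ..." form shown above; parse_summary and matches_expected are hypothetical helpers, not part of main.py.

```python
import re

# Matches summary lines such as "* Acc@1 74.838 Acc@5 92.610 loss 1.211".
SUMMARY = re.compile(r"\*\s*Acc@1\s+([\d.]+)\s+Acc@5\s+([\d.]+)\s+loss\s+([\d.]+)")


def parse_summary(log_text: str) -> dict:
    """Extract top-1, top-5 and loss from the last summary line in log_text."""
    matches = SUMMARY.findall(log_text)
    if not matches:
        raise ValueError("no '* Acc@1 ... Acc@5 ... loss ...' line found")
    acc1, acc5, loss = map(float, matches[-1])
    return {"acc1": acc1, "acc5": acc5, "loss": loss}


def matches_expected(result: dict, acc1: float, acc5: float, tol: float = 0.01) -> bool:
    """Check parsed accuracies (in %) against expected values within tol."""
    return abs(result["acc1"] - acc1) <= tol and abs(result["acc5"] - acc5) <= tol


r = parse_summary("* Acc@1 74.838 Acc@5 92.610 loss 1.211")
print(r)                                # {'acc1': 74.838, 'acc5': 92.61, 'loss': 1.211}
print(matches_expected(r, 74.84, 92.61))  # True
```

The Model Zoo table reports the same accuracies rounded to two decimals, which is why a tolerance of 0.01 is used for the comparison.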
4. Training
Train LocalViT-T on ImageNet on a single node with 8 GPUs for 300 epochs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model localvit_tiny_mlp4_act3_r192 --batch-size 128 --data-path /path/to/imagenet --output_dir /path/to/save
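With torch.distributed.launch, --nproc_per_node=8 starts one process per GPU, and (as in DeiT) --batch-size is the per-process batch, so the command above trains with 8 × 128 = 1024 images per step. Assuming main.py keeps DeiT's linear learning-rate scaling (base LR × total batch / 512, with a base LR of 5e-4), the arithmetic works out as below; this is a sketch of that assumed rule, so check main.py before relying on the values.

```python
def effective_batch_size(per_gpu_batch: int, num_gpus: int) -> int:
    # One process per GPU, each loading --batch-size images per step.
    return per_gpu_batch * num_gpus


def linear_scaled_lr(base_lr: float, total_batch: int, reference_batch: int = 512) -> float:
    # Assumed DeiT-style rule: the LR grows in proportion to the total batch size.
    return base_lr * total_batch / reference_batch


total = effective_batch_size(per_gpu_batch=128, num_gpus=8)
lr = linear_scaled_lr(base_lr=5e-4, total_batch=total)
print(total)  # 1024
print(lr)     # 0.001
```

Under this assumption, training on fewer GPUs with the same --batch-size also lowers the learning rate proportionally.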
5. How to introduce the locality mechanism
Introducing the locality mechanism into an existing vision transformer takes two steps.
I. Replace the MLP layer with LocalityFeedForward.
II. Change the computation procedure accordingly.
1) Split the class token from the image tokens.
2) Reshape the image tokens into a 2D feature map and update them.
3) Concatenate the class token and the updated image tokens.
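The three steps can be tried in isolation with plain PyTorch. The sketch below uses SimpleLocalityFF, a hypothetical, simplified inverted-residual stand-in for LocalityFeedForward (1×1 expansion, 3×3 depth-wise convolution, 1×1 projection, no SE and no normalization), plus a hypothetical helper update_image_tokens that performs the split, reshape/update and concatenation. It assumes a single class token and a square patch grid. The residual connection is added inside the stand-in because the block's forward pass below adds none around the feed-forward.

```python
import math

import torch
import torch.nn as nn


class SimpleLocalityFF(nn.Module):
    """Hypothetical simplified stand-in for LocalityFeedForward (not the repo's module)."""

    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden, kernel_size=1),  # expand channels
            nn.Hardswish(),
            # Depth-wise 3x3 conv: each channel mixes only with its spatial neighbours.
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=hidden),
            nn.Hardswish(),
            nn.Conv2d(hidden, dim, kernel_size=1),  # project back to dim
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # B x C x H x W -> B x C x H x W, with the residual added here.
        return x + self.net(x)


def update_image_tokens(x: torch.Tensor, local_ff: nn.Module) -> torch.Tensor:
    """Apply the three steps to a B x (1 + H*W) x C token sequence."""
    batch_size, num_token, embed_dim = x.shape
    grid = math.isqrt(num_token - 1)  # side length of the square patch grid
    assert grid * grid == num_token - 1, "expects one class token and a square grid"
    # 1. Split the class token from the image tokens.
    cls_token, img = torch.split(x, [1, num_token - 1], dim=1)
    # 2. Tokens -> feature map, local update, feature map -> tokens.
    img = img.transpose(1, 2).reshape(batch_size, embed_dim, grid, grid)
    img = local_ff(img).flatten(2).transpose(1, 2)
    # 3. Concatenate the untouched class token with the updated image tokens.
    return torch.cat([cls_token, img], dim=1)


tokens = torch.randn(2, 197, 192)  # e.g. 14 x 14 image tokens + 1 class token
out = update_image_tokens(tokens, SimpleLocalityFF(192))
print(out.shape)  # torch.Size([2, 197, 192])
```

Passing nn.Identity() as local_ff returns the input unchanged, which is a convenient check that the split, reshape and concatenation are exact inverses of each other.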
The following example shows how to introduce the locality mechanism into the original Transformer block by Ross Wightman.
```python
import math
import torch
import torch.nn as nn
from timm.models.layers import DropPath
from timm.models.vision_transformer import Attention
from models.localvit import LocalityFeedForward


class TransformerLayer(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        #########################################
        # Original implementation
        # self.norm2 = norm_layer(dim)
        # mlp_hidden_dim = int(dim * mlp_ratio)
        # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        #########################################
        # Replace the MLP layer with LocalityFeedForward.
        self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, act='hs+se', reduction=dim//4)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        #########################################
        # Original implementation
        # x = x + self.drop_path(self.mlp(self.norm2(x)))
        #########################################
        # Change the computation accordingly in three steps.
        batch_size, num_token, embed_dim = x.shape
        # Side length of the square patch grid; num_token includes the class token.
        grid_size = int(math.sqrt(num_token - 1))
        # 1. Split the class token and the image tokens.
        cls_token, x = torch.split(x, [1, num_token - 1], dim=1)
        # 2. Reshape the image tokens into a feature map and update them.
        x = x.transpose(1, 2).view(batch_size, embed_dim, grid_size, grid_size)
        x = self.conv(x).flatten(2).transpose(1, 2)
        # 3. Concatenate the class token and the newly computed image tokens.
        x = torch.cat([cls_token, x], dim=1)
        return x
```