ViT PyTorch
Quickstart
Install with pip install pytorch_pretrained_vit and load a pretrained ViT with:
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
A Google Colab example is also available.
Overview
This repository contains an op-for-op PyTorch reimplementation of the Vision Transformer (ViT) architecture from Google, along with pretrained models and examples.
The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.
At the moment, you can easily:
- Load pretrained ViT models
- Evaluate on ImageNet or your own data
- Finetune ViT on your own dataset
Coming soon:
- Train ViT from scratch on ImageNet (1K)
- Export to ONNX for efficient inference
About ViT
Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.
The ViT architecture works as follows: (1) it splits an image into fixed-size patches and treats them as a 1-dimensional sequence of patch embeddings, (2) it prepends a learnable classification token to the sequence, (3) it passes this sequence through a transformer encoder (like BERT), and (4) it passes the first output token of the transformer through a small MLP head to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a large amount of compute.
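The four steps above can be sketched in a few lines of plain PyTorch. This is a minimal illustration under stated assumptions, not the code in this repository: the class name MiniViT and its tiny default sizes are made up for the example, it uses torch.nn.TransformerEncoder instead of a hand-written encoder, and it uses a single linear layer as the classification head.

```python
import torch
from torch import nn


class MiniViT(nn.Module):
    """Hypothetical minimal ViT: patchify -> prepend [CLS] -> encoder -> head on [CLS]."""

    def __init__(self, image_size=32, patch_size=8, in_channels=3, dim=64,
                 depth=2, heads=4, mlp_dim=128, num_classes=10):
        super().__init__()
        assert image_size % patch_size == 0, "image must divide evenly into patches"
        num_patches = (image_size // patch_size) ** 2
        # (1) Patch embedding: a conv with kernel = stride = patch size is the same as
        #     flattening each non-overlapping patch and applying one shared linear layer.
        self.patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        # (2) Learnable classification token and learned 1D position embeddings
        #     (zero-initialised here for simplicity).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # (3) Transformer encoder: pre-norm layers with GELU activations.
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim,
                                           activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(dim)
        # (4) Classification head applied to the first ([CLS]) output token.
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)                  # (B, dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)         # (B, num_patches, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.norm(self.encoder(x))
        return self.head(x[:, 0])                # logits from the [CLS] token


model = MiniViT()
logits = model(torch.randn(2, 3, 32, 32))        # two 32x32 RGB images
```

With these defaults a 32×32 image becomes 4 × 4 = 16 patches plus one class token, so the encoder sees a sequence of 17 tokens.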
About ViT-PyTorch
ViT-PyTorch is a PyTorch reimplementation of ViT. It mirrors the original JAX implementation, which makes it easy to load the JAX-pretrained weights.
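Much of the work in loading JAX weights is reordering tensor axes, because Flax and PyTorch store the same layers in different layouts. The sketch below shows the two standard conversions, assuming the checkpoint stores Flax-style Dense and Conv kernels; the helper names are hypothetical and this is not the repository's actual conversion code.

```python
import numpy as np


def dense_kernel_to_linear_weight(kernel: np.ndarray) -> np.ndarray:
    # Flax Dense kernels are (in_features, out_features);
    # torch.nn.Linear.weight is (out_features, in_features).
    return kernel.T


def conv_kernel_to_conv2d_weight(kernel: np.ndarray) -> np.ndarray:
    # Flax Conv kernels are (H, W, in_channels, out_channels);
    # torch.nn.Conv2d.weight is (out_channels, in_channels, H, W).
    return kernel.transpose(3, 2, 0, 1)


# Example: a ViT-B/16 patch-embedding kernel (16x16 patches, 3 -> 768 channels).
patch_kernel = np.zeros((16, 16, 3, 768), dtype=np.float32)
print(conv_kernel_to_conv2d_weight(patch_kernel).shape)  # (768, 3, 16, 16)
```

The converted arrays can then be wrapped with torch.from_numpy and copied into the matching PyTorch parameters.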
At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.
Installation
Install with pip:
pip install pytorch_pretrained_vit
Or from source:
git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .
Usage
Loading pretrained models
Loading a pretrained model is easy:
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
Details about the models are below:
Name | Pretrained on | Finetuned on | Available?
---|---|---|---
B_16 | ImageNet-21k | - | ✓
B_32 | ImageNet-21k | - | ✓
L_16 | ImageNet-21k | - | -
L_32 | ImageNet-21k | - | ✓
B_16_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
B_32_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
L_16_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
L_32_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
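The model names combine a model size, B (Base) or L (Large), with the patch side length in pixels, plus an optional suffix naming the finetuning dataset. The sketch below decodes a name using the Base and Large sizes listed in the ViT paper and computes how many tokens the encoder processes; parse_name and num_tokens are hypothetical helpers, not part of pytorch_pretrained_vit.

```python
# Encoder sizes for ViT-Base and ViT-Large as listed in the ViT paper.
VARIANTS = {
    "B": dict(hidden_size=768, mlp_size=3072, num_heads=12, num_layers=12),
    "L": dict(hidden_size=1024, mlp_size=4096, num_heads=16, num_layers=24),
}


def parse_name(name: str) -> dict:
    """Hypothetical helper: split e.g. 'B_16_imagenet1k' into sizes and patch size."""
    variant, patch, *rest = name.split("_")
    cfg = dict(VARIANTS[variant], patch_size=int(patch))
    cfg["finetuned_on"] = rest[0] if rest else None
    return cfg


def num_tokens(name: str, image_size: int) -> int:
    """Sequence length seen by the encoder: one token per patch plus the class token."""
    p = parse_name(name)["patch_size"]
    if image_size % p:
        raise ValueError(f"image_size {image_size} is not divisible by patch size {p}")
    return (image_size // p) ** 2 + 1


print(num_tokens("B_16_imagenet1k", 384))  # 577
```

At a 384×384 input, 16-pixel patches give 24 × 24 + 1 = 577 tokens, while 32-pixel patches give 12 × 12 + 1 = 145. Larger patches mean shorter sequences and cheaper attention, at some cost in spatial detail.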
Custom ViT
Loading custom configurations is just as easy:
from pytorch_pretrained_vit import ViT
# A smaller custom configuration (for comparison, ViT('B_16') is ViT-Base/16: 768 hidden units, 12 heads, 12 layers)
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)
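To see how hidden_size, num_heads and num_layers affect model size, the sketch below counts the parameters of a stack of standard transformer encoder layers. It is an estimate under stated assumptions: each layer has biased Q/K/V/output projections, a two-layer MLP whose width defaults to 4 × hidden_size (true of ViT-Base and ViT-Large), and two LayerNorms; patch embedding, position embeddings, class token and head are excluded. The function name is hypothetical.

```python
from typing import Optional


def encoder_param_count(hidden_size: int, num_heads: int, num_layers: int,
                        mlp_size: Optional[int] = None) -> int:
    """Parameters in num_layers standard transformer encoder layers (with biases)."""
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    d = hidden_size
    f = mlp_size if mlp_size is not None else 4 * d  # assumption: MLP width = 4x hidden
    attn = 4 * d * d + 4 * d   # Q, K, V and output projections, each with a bias
    mlp = 2 * d * f + f + d    # two linear layers with biases
    norms = 2 * 2 * d          # two LayerNorms, each with a scale and a bias
    # num_heads only splits d into d // num_heads dims per head; it adds no parameters.
    return num_layers * (attn + mlp + norms)


print(encoder_param_count(768, 12, 12))  # 85054464  (ViT-Base encoder)
print(encoder_param_count(512, 8, 6))    # 18914304  (the custom config above)
```

The roughly 85M encoder parameters of ViT-Base account for almost all of the 86M total reported for that model in the ViT paper; the custom configuration above is about 4.5 times smaller.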
Example: Classification
Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab notebook.
from PIL import Image
import torch
from torchvision import transforms
# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()
# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
    transforms.Resize((384, 384)),    # the imagenet1k models take 384x384 inputs
    transforms.ToTensor(),            # PIL image -> float tensor in [0, 1]
    transforms.Normalize(0.5, 0.5),   # rescale pixels from [0, 1] to [-1, 1]
])(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # add a batch dimension
print(img.shape)  # torch.Size([1, 3, 384, 384])
# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
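The model returns raw logits. A common next step is to turn them into probabilities and list the most likely classes. The sketch below assumes you have a Python list `labels` of the 1,000 ImageNet class names in the model's class order (not provided here); top_k_predictions is a hypothetical helper.

```python
import torch


def top_k_predictions(logits: torch.Tensor, labels: list, k: int = 5):
    """Return (label, probability) pairs for the k most likely classes of one image."""
    probs = logits.softmax(dim=-1).squeeze(0)  # (1, num_classes) -> (num_classes,)
    values, indices = torch.topk(probs, k)     # sorted from most to least likely
    return [(labels[i], p) for p, i in zip(values.tolist(), indices.tolist())]


# Toy example with three classes instead of the real 1,000.
toy_logits = torch.tensor([[0.0, 2.0, 1.0]])
print(top_k_predictions(toy_logits, ["cat", "dog", "fox"], k=2))
```

With the classification example above, this would be called as top_k_predictions(outputs, labels).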
ImageNet
See examples/imagenet for details about evaluating on ImageNet.
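Evaluation on ImageNet is usually reported as top-1 and top-5 accuracy. The sketch below computes both for any model and any iterable of (images, targets) batches; it is a generic loop, not the script in examples/imagenet, and the function name is hypothetical.

```python
import torch


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Top-1 and top-5 accuracy of `model` over (images, targets) batches."""
    model.eval()
    correct1 = correct5 = total = 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        logits = model(images)
        top5 = logits.topk(5, dim=-1).indices  # (B, 5), highest score first
        correct1 += (top5[:, 0] == targets).sum().item()
        correct5 += (top5 == targets[:, None]).any(dim=-1).sum().item()
        total += targets.numel()
    return correct1 / total, correct5 / total
```

For the pretrained models, the loader could wrap torchvision.datasets.ImageFolder over a local copy of the validation set, using the same 384×384 resize and normalization as the classification example, assuming the folder-to-index mapping matches the model's class order.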
Credit
Other great repositories with this model include:
Contributing
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
I look forward to seeing what the community does with these models!