

Vision Transformer (ViT) in PyTorch

ViT PyTorch

Quickstart

Install with pip install pytorch_pretrained_vit and load a pretrained ViT with:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Or find a Google Colab example here.

Overview

This repository contains an op-for-op PyTorch reimplementation of the Vision Transformer (ViT) architecture from Google, along with pretrained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.

At the moment, you can easily:

  • Load pretrained ViT models
  • Evaluate on ImageNet or your own data
  • Finetune ViT on your own dataset

Coming soon:

  • Train ViT from scratch on ImageNet (1K)
  • Export to ONNX for efficient inference

Table of contents

  1. About ViT
  2. About ViT-PyTorch
  3. Installation
  4. Usage
  5. Contributing

About ViT

Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.

The ViT architecture works as follows: (1) it treats an image as a 1-dimensional sequence of patches, (2) it prepends a classification token to the sequence, (3) it passes the resulting sequence through a transformer encoder (like BERT), and (4) it passes the first token of the encoder output through a small MLP to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a large amount of compute.
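The four steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the code in this repository: the class name `TinyViT` and all sizes are hypothetical, `nn.TransformerEncoder` stands in for the repository's own encoder, a single linear layer stands in for the small MLP head, and the learnable position embeddings follow the original ViT design rather than anything stated above.

```python
import torch
from torch import nn


class TinyViT(nn.Module):
    """Illustrative ViT-style classifier; the name and sizes are hypothetical."""

    def __init__(self, image_size=32, patch_size=8, dim=64, num_heads=4,
                 num_layers=2, ff_dim=128, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # (1) A strided convolution cuts the image into non-overlapping
        #     patches and projects each patch to a `dim`-sized vector.
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size,
                                         stride=patch_size)
        # (2) Learnable classification token, plus one position embedding
        #     per token (patches + class token).
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # (3) Transformer encoder (stand-in for the repository's encoder).
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, dim_feedforward=ff_dim,
            activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # (4) Head applied to the first output token only.
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embedding(x)        # (B, dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, dim)
        cls = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat([cls, x], dim=1)     # (B, num_patches + 1, dim)
        x = self.encoder(x + self.pos_embedding)
        return self.head(self.norm(x[:, 0]))  # logits from the class token


model = TinyViT().eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

With 32x32 inputs and 8x8 patches the encoder sees 16 patch tokens plus the class token, so the sequence length is 17.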

About ViT-PyTorch

ViT-PyTorch is a PyTorch reimplementation of ViT. It is consistent with the original JAX implementation, which makes it easy to load JAX-pretrained weights.

At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

Installation

Install with pip:

pip install pytorch_pretrained_vit

Or from source:

git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .

Usage

Loading pretrained models

Loading a pretrained model is easy:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Details about the models are below:

Name              Pretrained on   Finetuned on   Available?
B_16              ImageNet-21k    -              ✓
B_32              ImageNet-21k    -              ✓
L_16              ImageNet-21k    -              -
L_32              ImageNet-21k    -              ✓
B_16_imagenet1k   ImageNet-21k    ImageNet-1k    ✓
B_32_imagenet1k   ImageNet-21k    ImageNet-1k    ✓
L_16_imagenet1k   ImageNet-21k    ImageNet-1k    ✓
L_32_imagenet1k   ImageNet-21k    ImageNet-1k    ✓
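The model names in the table follow a simple pattern: a size variant (B for base, L for large), the patch size in pixels, and an optional suffix naming the finetuning dataset. The sketch below decodes a name and computes how many tokens the encoder sees for a given input size. The helpers `parse_model_name` and `sequence_length` are hypothetical and are not part of `pytorch_pretrained_vit`.

```python
from typing import NamedTuple, Optional


class ModelSpec(NamedTuple):
    variant: str                 # "B" (base) or "L" (large)
    patch_size: int              # side length of each square patch, in pixels
    finetuned_on: Optional[str]  # e.g. "imagenet1k", or None if not finetuned


def parse_model_name(name: str) -> ModelSpec:
    """Split a name such as 'B_16_imagenet1k' into its parts (hypothetical helper)."""
    parts = name.split("_")
    if len(parts) not in (2, 3) or parts[0] not in ("B", "L") or not parts[1].isdigit():
        raise ValueError(f"unrecognized model name: {name!r}")
    finetuned_on = parts[2] if len(parts) == 3 else None
    return ModelSpec(parts[0], int(parts[1]), finetuned_on)


def sequence_length(image_size: int, patch_size: int) -> int:
    """Number of encoder tokens: one per patch, plus the classification token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    return (image_size // patch_size) ** 2 + 1


spec = parse_model_name("B_16_imagenet1k")
print(spec.variant, spec.patch_size, spec.finetuned_on)  # B 16 imagenet1k
print(sequence_length(384, spec.patch_size))             # 577
```

A 384x384 input with 16x16 patches gives a 24x24 grid, so 576 patch tokens plus one classification token.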

Custom ViT

Loading custom configurations is just as easy:

from pytorch_pretrained_vit import ViT
# A small custom configuration. (ViT-B/16 itself uses hidden_size=768, num_heads=12, num_layers=12.)
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)

Example: Classification

Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab Notebook.

import json
from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),  # maps [0, 1] pixel values to [-1, 1]
])(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # force 3 channels, add batch dimension
print(img.shape) # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
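The `outputs` tensor holds raw logits, one per ImageNet-1k class. A common next step is to turn them into probabilities and pick the most likely classes. The sketch below uses a random tensor as a stand-in for `model(img)` so that it runs without the pretrained weights. Mapping class indices to readable names requires a label file, which is not assumed here.

```python
import torch

torch.manual_seed(0)
outputs = torch.randn(1, 1000)  # stand-in for model(img); shape (batch, classes)

# Softmax over the class dimension turns logits into probabilities.
probs = torch.softmax(outputs, dim=-1)

# Keep the five most probable classes, sorted from most to least likely.
top_probs, top_idx = torch.topk(probs, k=5, dim=-1)

for p, i in zip(top_probs[0].tolist(), top_idx[0].tolist()):
    print(f"class {i}: {p * 100:.2f}%")
```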

ImageNet

See examples/imagenet for details about evaluating on ImageNet.

Credit

Other great repositories with this model include:

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
