  • Language: Python
  • License: Apache License 2.0
  • Created over 4 years ago
  • Updated about 1 year ago


PyTorch Implementation of OpenAI's Image GPT

Image GPT

PyTorch implementation of Image GPT, based on the paper Generative Pretraining from Pixels (Chen et al.) and its accompanying code.


Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.


iGPT-S pretrained on CIFAR10. Completions are fairly poor because the model was trained only on CIFAR10, not on all of ImageNet.

WIP

  • Batched k-means on GPU for quantization of larger datasets (currently using sklearn.cluster.MiniBatchKMeans).
  • BERT-style pretraining (currently only generative pretraining is supported).
  • Load pretrained models from OpenAI.
  • Reproduce at least iGPT-S results.

According to OpenAI's blog post, iGPT-L (1.4 B parameters) was trained for 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (the cost of self-attention grows quadratically with the sequence length, i.e. the number of pixels), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA RTX 2070 in less than 2 hours.
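A back-of-the-envelope check of these scales, as a sketch: a decoder-only transformer block holds roughly 12·d² weights (4·d² in the query, key, value and output projections, 8·d² in an MLP with a 4x hidden width), and flattening an H×W image gives H·W tokens, so each attention head builds an (H·W)² matrix. The 48-layer, width-1536 shape used for iGPT-L is taken from the paper's model description; the helper names are hypothetical, and embeddings, biases and layer norms are ignored.

```python
def approx_transformer_params(n_layer: int, d_model: int) -> int:
    """Rough weight count of a decoder-only transformer.

    Per block: 4 * d^2 for the Q, K, V and output projections plus
    8 * d^2 for an MLP with hidden width 4 * d. Embeddings, biases and
    layer norms are ignored.
    """
    return 12 * n_layer * d_model ** 2


def attention_entries_per_head(height: int, width: int) -> int:
    """Size of one head's attention matrix for a flattened image."""
    seq_len = height * width  # one token per (quantized) pixel
    return seq_len ** 2


# Assumed iGPT-L shape (48 layers, width 1536): about 1.36 billion weights.
print(approx_transformer_params(48, 1536))  # 1358954496
# Halving the image side divides the sequence length by 4 and each
# attention matrix by 16.
print(attention_entries_per_head(32, 32) // attention_entries_per_head(16, 16))  # 16
```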

Usage

Pre-trained Models

Some pre-trained models are located in the models directory. Run ./download.sh to download the CIFAR10-pretrained iGPT-S model.

Compute Centroids

Images are downloaded, and centroids are computed using k-means with num_clusters clusters. These centroids are used to quantize the images before they are fed into the model.

# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8

# creates data/<dataset>_centroids.npy

Note: Use the same num_clusters as num_vocab in your model.
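The centroid and quantization step can be sketched in NumPy as below. This is a minimal stand-in, not the repository's code: it runs plain Lloyd's k-means instead of sklearn.cluster.MiniBatchKMeans, builds an (N, K, C) distance array and so only suits a small sample of pixels, and the function names are hypothetical. Each pixel (C = 1 for grayscale, C = 3 for RGB) is replaced by the index of its nearest centroid, so num_clusters is the size of the token vocabulary, which is why it has to match num_vocab.

```python
import numpy as np


def quantize(pixels: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Map each pixel vector in (N, C) to the index of its nearest centroid."""
    dists = ((pixels[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)


def kmeans_centroids(pixels: np.ndarray, num_clusters: int,
                     num_iters: int = 20, seed: int = 0) -> np.ndarray:
    """Lloyd's k-means over pixel vectors of shape (N, C)."""
    rng = np.random.default_rng(seed)
    pixels = pixels.astype(np.float64)
    # Initialize from randomly chosen pixels (fancy indexing returns a copy).
    centroids = pixels[rng.choice(len(pixels), num_clusters, replace=False)]
    for _ in range(num_iters):
        labels = quantize(pixels, centroids)
        for k in range(num_clusters):
            members = pixels[labels == k]
            if len(members):  # an empty cluster keeps its previous centroid
                centroids[k] = members.mean(axis=0)
    return centroids


# Toy data: 100 random 8x8 grayscale "images" with values in [0, 1].
images = np.random.default_rng(1).random((100, 8, 8))
pixels = images.reshape(-1, 1)                            # (6400, 1): one row per pixel
centroids = kmeans_centroids(pixels, num_clusters=8)
tokens = quantize(pixels, centroids).reshape(100, 64)     # one token per pixel
reconstructed = centroids[tokens, 0].reshape(100, 8, 8)   # dequantized images
```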

Training

Models can be trained using src/run.py with the train subcommand.

Generative Pre-training

Models can be pretrained by specifying a dataset and a model config. configs/s_gen.yml corresponds to iGPT-S from the paper; configs/xxs_gen.yml is an extra-small model for experimenting on toy datasets with limited compute.

python src/run.py --dataset mnist train configs/xxs_gen.yml
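The generative (autoregressive) objective from the paper is next-pixel prediction: quantized pixels are read in raster order and the model maximizes the likelihood of each token given the ones before it. A minimal NumPy sketch of that objective, assuming a causal model has already produced one row of logits per position (row i computed from tokens[:i] only); the model itself and the function names here are hypothetical, not this repository's code.

```python
import numpy as np


def causal_mask(seq_len: int) -> np.ndarray:
    """True where position i may attend to position j (j <= i)."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def autoregressive_nll(logits: np.ndarray, tokens: np.ndarray) -> float:
    """Mean next-pixel negative log-likelihood, in nats.

    logits: (seq_len, num_vocab); row i must depend only on tokens[:i].
    tokens: (seq_len,) quantized pixel indices in raster order.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(tokens)), tokens].mean())


tokens = np.random.default_rng(0).integers(0, 8, size=64)  # one 8x8 image, 8 clusters
uniform = np.zeros((64, 8))  # an uninformed model predicts every token equally
print(autoregressive_nll(uniform, tokens))  # log(8), about 2.079
```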

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the config file and dataset.

python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt
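For classification, the paper average-pools the final-layer activations over the sequence dimension and passes the pooled vector through a linear head to get class logits; it also discusses fine-tuning on the sum of the generative and classification losses. The NumPy sketch below shows only that head and the loss combination under those assumptions; the shapes, names and random stand-in activations are hypothetical, and it is not a description of how this repository weights the losses.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def classifier_logits(hidden: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """hidden: (batch, seq_len, d_model) final-layer activations.

    Average-pool over the sequence, then apply a linear head.
    """
    pooled = hidden.mean(axis=1)   # (batch, d_model)
    return pooled @ weight + bias  # (batch, num_classes)


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy for logits (N, K) and integer labels (N,)."""
    return float(-log_softmax(logits)[np.arange(len(labels)), labels].mean())


rng = np.random.default_rng(0)
batch, seq_len, d_model, num_classes, num_vocab = 4, 64, 16, 10, 8
hidden = rng.normal(size=(batch, seq_len, d_model))  # stand-in activations
weight = rng.normal(size=(d_model, num_classes)) * 0.01
bias = np.zeros(num_classes)
labels = rng.integers(0, num_classes, size=batch)

clf_logits = classifier_logits(hidden, weight, bias)
clf_loss = cross_entropy(clf_logits, labels)

# Stand-in next-pixel logits (random here; a real model computes them causally).
gen_logits = rng.normal(size=(batch * seq_len, num_vocab))
tokens = rng.integers(0, num_vocab, size=batch * seq_len)
gen_loss = cross_entropy(gen_logits, tokens)

total_loss = gen_loss + clf_loss  # joint fine-tuning objective
```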

Sampling

Figures like those seen above can be created using random images from the test set:

# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
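The half-image completions keep the top half of an image's tokens as a prompt and sample the remaining tokens one at a time, each from the model's softmax over the vocabulary given everything generated so far; the tokens are then mapped back to pixel values through the centroids. A NumPy sketch under stated assumptions: next_token_logits stands in for a trained model and is hypothetical, as are the other names, and temperature sampling is one common choice, not necessarily what sample.py does.

```python
from typing import Callable, Sequence

import numpy as np


def complete_image(prompt: Sequence[int], seq_len: int,
                   next_token_logits: Callable[[list], np.ndarray],
                   rng: np.random.Generator, temperature: float = 1.0) -> np.ndarray:
    """Extend a token prompt to seq_len tokens by sampling one token at a time."""
    tokens = list(prompt)
    while len(tokens) < seq_len:
        logits = np.asarray(next_token_logits(tokens), dtype=np.float64) / temperature
        probs = np.exp(logits - logits.max())  # softmax, shifted for stability
        probs /= probs.sum()
        tokens.append(int(rng.choice(len(probs), p=probs)))
    return np.array(tokens)


# Toy stand-in "model" over an 8-token vocabulary that strongly prefers token 3.
def toy_model(tokens: list) -> np.ndarray:
    logits = np.zeros(8)
    logits[3] = 50.0
    return logits


side = 8
rng = np.random.default_rng(0)
prompt = rng.integers(0, 8, size=side * side // 2)   # top half of an 8x8 image
completed = complete_image(prompt, side * side, toy_model, rng)
centroids = np.linspace(0.0, 1.0, 8)[:, None]          # hypothetical grayscale centroids
image = centroids[completed, 0].reshape(side, side)    # tokens back to pixel values
```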

GIFs like the one in my tweet can be made as follows:

# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt
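An animation of a completion can be viewed as a sequence of frames in which more of the image appears at each step. The NumPy sketch below builds such frames from a token sequence; it is an assumption about how such an animation could be made, not a description of gif.py, and the function name, the step size and the gray fill for not-yet-generated pixels are hypothetical choices.

```python
import numpy as np


def reveal_frames(tokens: np.ndarray, centroids: np.ndarray, height: int, width: int,
                  step: int, fill: float = 0.5) -> list:
    """Frames showing the first t tokens (raster order), t = step, 2*step, ...

    tokens: (height * width,) token indices; centroids: (num_clusters, C) values in [0, 1].
    Returns (height, width, C) uint8 arrays; unrevealed pixels are set to `fill`.
    """
    pixels = centroids[tokens]  # (height * width, C)
    frames = []
    for t in range(step, len(tokens) + step, step):
        frame = np.full_like(pixels, fill, dtype=np.float64)
        frame[:t] = pixels[:t]
        frames.append((frame.reshape(height, width, -1) * 255).round().astype(np.uint8))
    return frames


tokens = np.arange(64) % 8                      # toy 8x8 token image
centroids = np.linspace(0.0, 1.0, 8)[:, None]   # hypothetical grayscale centroids
frames = reveal_frames(tokens, centroids, 8, 8, step=16)
```

To write the frames out, Pillow could be used, for example by converting each frame with Image.fromarray(frame.squeeze()) and saving the first image with save_all=True and append_images set to the rest.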
