
PyTorch Pretrained GANs

Pretrained GANs in PyTorch: StyleGAN2, BigGAN, BigBiGAN, SAGAN, SNGAN, SelfCondGAN, and more.

Quick Start

This repository provides a standardized interface for pretrained GANs in PyTorch. You can install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

It is then easy to generate an image with a GAN:

import torch
from pytorch_pretrained_gans import make_gan

# Sample a class-conditional image from BigGAN with default resolution 256
G = make_gan(gan_type='biggan')  # -> nn.Module
y = G.sample_class(batch_size=1)  # -> torch.Size([1, 1000])
z = G.sample_latent(batch_size=1)  # -> torch.Size([1, 128])
x = G(z=z, y=y)  # -> torch.Size([1, 3, 256, 256])
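
The wrapped generators are ordinary torch.nn.Modules, so they can be moved to a GPU and sampled in batches as usual. A minimal sketch (assuming a CUDA device is available; shapes correspond to the default BigGAN above):

import torch
from pytorch_pretrained_gans import make_gan

device = 'cuda' if torch.cuda.is_available() else 'cpu'
G = make_gan(gan_type='biggan').to(device).eval()

with torch.no_grad():                                  # no gradients needed for sampling
    y = G.sample_class(batch_size=8, device=device)    # -> torch.Size([8, 1000])
    z = G.sample_latent(batch_size=8, device=device)   # -> torch.Size([8, 128])
    x = G(z=z, y=y)                                    # -> torch.Size([8, 3, 256, 256])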

Motivation

Over the past few years, great progress has been made in generative modeling using GANs. As a result, a large body of research has emerged that uses GANs and explores/interprets their latent spaces. I recently worked on a project in which I wanted to apply the same technique to a bunch of different GANs (here's the paper if you're interested). This was quite a pain because all the pretrained GANs out there are in completely different formats. So I decided to standardize them, and here's the result. I hope you find it useful.

Installation

Install with pip directly from GitHub:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

Available GANs

The following GANs are available. If you would like to add a new GAN to the repo, please submit a pull request; I would love to add to this list:

  • BigGAN (class-conditional)
  • BigBiGAN
  • StudioGAN models (e.g. SAGAN, SNGAN, ContraGAN, ACGAN)
  • Self-Conditioned GAN (self_conditioned and unconditional variants)
  • StyleGAN2-ADA (stylegan2_ada_pytorch)

Structure

This repo supports both conditional and unconditional GANs. The standard GAN interface is as follows:

class GeneratorWrapper(torch.nn.Module):
    """A wrapper to put the GAN in a standard format."""

    def __init__(self, G, num_classes=None):
        super().__init__()
        self.G: nn.Module = G         # the wrapped GAN generator
        self.dim_z: int = ...         # dimensionality of the latent space
        self.conditional: bool = ...  # True for class-conditional GANs

    def forward(self, z, y=None):     # y is used by conditional GANs only
        x = ...  # generate an image from the latent z (and class y) with self.G
        return x  # image tensor of shape (batch_size, 3, H, W)

    def sample_latent(self, batch_size, device='cpu'):
        z = ...  # sample a latent vector of size self.dim_z
        return z

    def sample_class(self, batch_size, device='cpu'):
        y = ...  # sample a class vector y (conditional GANs only)
        return y

Each type of GAN is contained in its own folder and has a make_GAN_TYPE function. For example, make_bigbigan creates a BigBiGAN with the format of the GeneratorWrapper above.
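
To make the interface concrete, here is a minimal sketch of how a new, hypothetical unconditional generator could be wrapped before contributing it to the repo (MyGeneratorWrapper, the wrapped G, and its 128-dimensional latent are illustrative assumptions, not part of this repository):

import torch
import torch.nn as nn

class MyGeneratorWrapper(nn.Module):
    """Wraps a hypothetical unconditional generator in the standard format."""

    def __init__(self, G):
        super().__init__()
        self.G = G                  # assumed: maps a latent z to an image
        self.dim_z = 128            # assumed latent dimensionality
        self.conditional = False    # unconditional, so sample_class is unused

    def forward(self, z, y=None):
        return self.G(z)            # image tensor of shape (batch_size, 3, H, W)

    def sample_latent(self, batch_size, device='cpu'):
        return torch.randn(batch_size, self.dim_z, device=device)

    def sample_class(self, batch_size, device='cpu'):
        return None                 # no class labels for an unconditional GAN

A conditional GAN would instead set self.conditional = True and have sample_class return class vectors (e.g. the [batch_size, 1000] tensors that BigGAN uses).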

The weights of all GANs except those in PyTorch-StudioGAN are downloaded automatically. To download the PyTorch-StudioGAN weights, use the download.sh scripts in the corresponding folders (see the file structure below).
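
For example, the StudioGAN ImageNet weights can be fetched from a clone of the repository with the script shown in the tree below (the path is taken from that tree):

bash pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh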

Code Structure

The structure of the repo is shown below. Each type of GAN has an __init__.py file that defines its GeneratorWrapper and its make_GAN_TYPE function.

pytorch_pretrained_gans
├── __init__.py
├── BigBiGAN
│   ├── __init__.py
│   ├── ...
│   └── weights
│       └── download.sh   # (use this to download pretrained weights)
├── BigGAN
│   ├── __init__.py   # (pretrained weights are auto-downloaded)
│   ├── ...
├── StudioGAN
│   ├── __init__.py
│   ├── ...
│   ├── configs
│   │   ├── ImageNet
│   │   │   ├── BigGAN2048
│   │   │   │   └── ...
│   │   │   └── download.sh  # (use this to download pretrained weights)
│   │   └── TinyImageNet
│   │       ├── ACGAN
│   │       │   └── ACGAN.json
│   │       ├── ...
│   │       └── download.sh  # (use this to download pretrained weights)
├── self_conditioned
│   ├── __init__.py   # (pretrained weights are auto-downloaded)
│   └── ...
└── stylegan2_ada_pytorch
    ├── __init__.py   # (pretrained weights are auto-downloaded)
    └── ...

GAN-Specific Details

Naturally, there are some details that are specific to certain GANs.

BigGAN: For BigGAN, you should specify a resolution with model_name. For example:

  • G = make_gan(gan_type='biggan', model_name='biggan-deep-512')

StudioGAN: For StudioGAN, you should specify a model with model_name. For example:

  • G = make_gan(gan_type='studiogan', model_name='SAGAN')
  • G = make_gan(gan_type='studiogan', model_name='ContraGAN256')

Self-Conditioned GAN: For Self-Conditioned GAN, you should specify a model (either self_conditioned or unconditional) with model_name. For example:

  • G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')

StyleGAN2:

  • StyleGAN2's sample_latent method returns w, not z, because this is usually what is desired. w has shape torch.Size([1, 18, 512]).
  • StyleGAN2 is currently not implemented on CPU, so the sketch below assumes a GPU.
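
Because sample_latent already returns w, sampling looks the same as for the other GANs. A minimal sketch, assuming the gan_type key for the stylegan2_ada_pytorch models is 'stylegan2' (check the folder's make_GAN_TYPE function for the exact name) and that a CUDA device is available:

import torch
from pytorch_pretrained_gans import make_gan

G = make_gan(gan_type='stylegan2')                 # assumed gan_type key
w = G.sample_latent(batch_size=1, device='cuda')   # -> torch.Size([1, 18, 512])
x = G(z=w)                                         # unconditional, so no y is passed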

Citation

Please cite the following if you use this repo in a research paper:

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = {arXiv},
  year      = {2021}
}

More Repositories

 1. EfficientNet-PyTorch - A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!) (Python, 7,614 stars)
 2. PyTorch-Pretrained-ViT - Vision Transformer (ViT) in PyTorch (Python, 681 stars)
 3. realfusion (Python, 504 stars)
 4. do-you-even-need-attention - Exploring whether attention is necessary for vision transformers (Python, 476 stars)
 5. deep-spectral-segmentation - [CVPR 2022] Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization (Python, 217 stars)
 6. projection-conditioned-point-cloud-diffusion - Official code for "Projection-Conditioned Point Cloud Diffusion for Single-Image 3D Reconstruction" (Python, 119 stars)
 7. Automatic-Image-Colorization - Automatic image colorization with a deep convolutional neural network (Python, 113 stars)
 8. image-paragraph-captioning - [EMNLP 2018] Training for Diversity in Image Paragraph Captioning (Python, 85 stars)
 9. unsupervised-image-segmentation - [ICLR 2022] Finding an Unsupervised Image Segmenter in each of your Deep Generative Models (Python, 73 stars)
10. simple-bert - A simple PyTorch implementation of BERT, complete with pretrained models and training scripts (Python, 37 stars)
11. pixmatch (Python, 35 stars)
12. mtob (Shell, 24 stars)
13. Poker-Bot-with-Genetic-Algorithms - A final project for Math 153 (Evolutionary Dynamics) at Harvard University (Python, 23 stars)
14. Machine-Translation - Machine translation with recurrent neural networks (Python, 7 stars)
15. lukemelas.github.io (HTML, 3 stars)
16. 1000-Genomes-Project-Analysis - An analysis of the 1000 Genomes Project in PyTorch (Jupyter Notebook, 2 stars)
17. lukemelas.github.io-src (HTML, 2 stars)
18. CS-282-Project - CS 282 Project (Jupyter Notebook, 2 stars)
19. Language-Modeling - Language modeling on the Penn Treebank dataset (Python, 2 stars)
20. CS-222-Fall-2018 - Problem sets from CS 222 Fall 2018 (Jupyter Notebook, 1 star)
21. CS-282-Fall-2018 - Problem sets from CS 282 Fall 2018 (Jupyter Notebook, 1 star)
22. CS-244-Project - Like blockchain but deeper (1 star)
23. CS-222-Project - CS 222 Project (Jupyter Notebook, 1 star)
24. tada - A repository for translation as data augmentation (1 star)