PyTorch Implementation of InfoGAN

InfoGAN-PyTorch

PyTorch implementation of InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets, with results of experiments on the MNIST, FashionMNIST, SVHN and CelebA datasets.

Introduction

InfoGAN is an information-theoretic extension of the standard Generative Adversarial Network (GAN) that learns disentangled representations in a completely unsupervised manner. For example, InfoGAN disentangles writing styles from digit shapes on the MNIST dataset and discovers visual concepts such as hair style and gender on the CelebA dataset. To achieve this, an information-theoretic regularization term is added to the loss function; it maximizes the mutual information between the latent codes c and the generator distribution G(z, c).
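The mutual information I(c; G(z, c)) cannot be computed directly, so the InfoGAN paper maximizes a variational lower bound L_I(G, Q) = E[log Q(c|x)] + H(c) using an auxiliary distribution Q(c|x) predicted from the generated image x, and trains with min over G, Q of max over D of V(D, G) - λ L_I(G, Q). The sketch below shows the information term in plain Python for clarity: a categorical code scored with a negative log-likelihood, and continuous codes scored with a factored Gaussian, as the paper suggests. The function names are hypothetical and not taken from this repository; in the actual PyTorch code these would typically be tensor operations (for example `nn.CrossEntropyLoss` for the categorical part).

```python
import math


def categorical_info_loss(q_probs, codes):
    """Mean negative log-likelihood -E[log Q(c|x)] for one categorical code.

    q_probs: per-sample probability vectors predicted by the Q network.
    codes:   the integer categories c that were fed to the generator.
    """
    eps = 1e-12  # guards against log(0)
    return -sum(math.log(p[c] + eps) for p, c in zip(q_probs, codes)) / len(codes)


def gaussian_info_loss(mu, log_var, codes):
    """Mean negative log-likelihood of continuous codes under a Gaussian
    Q(c|x) = N(mu, exp(log_var)), one scalar code per sample."""
    total = 0.0
    for m, lv, c in zip(mu, log_var, codes):
        total += 0.5 * (math.log(2 * math.pi) + lv + (c - m) ** 2 / math.exp(lv))
    return total / len(codes)


def mi_lower_bound(q_probs, codes, k):
    """L_I = E[log Q(c|x)] + H(c), where H(c) = log k for a uniform prior
    over k categories. It reaches log k when Q recovers c perfectly."""
    return -categorical_info_loss(q_probs, codes) + math.log(k)
```

A Q network that always recovers the sampled category attains the maximum value log k, while one that outputs a uniform distribution gives a bound of 0, i.e. no information about c is recovered from the image.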

Folder structure

The following shows the basic folder structure.

├── train.py # train script
├── data
│   ├── mnist # mnist data (not included in this repo)
│   ├── ...
│   ├── ...
│   └── fashion-mnist # fashion-mnist data (not included in this repo)
│
├── config.py # hyperparameters for training
├── utils.py # utils
├── dataloader.py # dataloader
├── models # infoGAN networks for different datasets
│   ├── mnist_model.py
│   ├── svhn_model.py
│   └── celeba_model.py
└── results # generation results to be saved here

Development Environment

  • Ubuntu 16.04 LTS
  • NVIDIA GeForce GTX 1060
  • CUDA 9.0
  • Python 3.6.5
  • PyTorch 1.0.0
  • torchvision 0.2.1
  • numpy 1.14.3
  • matplotlib 2.2.2

Usage

Edit config.py to set the training parameters and select the dataset to use, choosing one of ['MNIST', 'FashionMNIST', 'SVHN', 'CelebA'].
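As an illustration of what such a configuration might contain, here is a hypothetical sketch; the actual names in config.py may differ. It validates the dataset choice and derives the generator input size from a per-dataset latent layout. The layout values (noise size, number of 10-way categorical codes, number of continuous codes) are assumptions based on our reading of the InfoGAN paper's experimental setup, not values read from this repository.

```python
# Hypothetical config sketch; the real config.py may use different names.
VALID_DATASETS = ("MNIST", "FashionMNIST", "SVHN", "CelebA")

# Assumed latent layout per dataset:
# (noise dims, number of categorical codes, categories per code, continuous codes)
LATENT_LAYOUT = {
    "MNIST": (62, 1, 10, 2),
    "FashionMNIST": (62, 1, 10, 2),
    "SVHN": (124, 4, 10, 4),
    "CelebA": (128, 10, 10, 0),
}

# Illustrative training parameters.
params = {
    "dataset": "MNIST",
    "num_epochs": 100,
}


def latent_dim(dataset):
    """Total generator input size: noise + one-hot categorical codes + continuous codes."""
    if dataset not in VALID_DATASETS:
        raise ValueError(f"dataset must be one of {VALID_DATASETS}, got {dataset!r}")
    num_z, num_cat, cat_dim, num_cont = LATENT_LAYOUT[dataset]
    return num_z + num_cat * cat_dim + num_cont
```

Failing fast on an unknown dataset name is cheaper than discovering the typo after the data loader or model construction fails.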

To train the model, run train.py:

python3 train.py

After training, run mnist_generate.py to experiment with the latent code on the MNIST dataset:

python3 mnist_generate.py --load_path /path/to/pth/checkpoint
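Only the `--load_path` flag is documented above. The following is a hypothetical sketch of how a script might accept it with the standard `argparse` module; the rest of mnist_generate.py is not reconstructed here. In the real script, the path would then typically be passed to `torch.load` to restore the trained generator weights.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical CLI sketch: only --load_path is documented in this README.
    parser = argparse.ArgumentParser(description="Explore InfoGAN latent codes on MNIST")
    parser.add_argument(
        "--load_path",
        required=True,
        help="path to the .pth checkpoint produced by training",
    )
    return parser.parse_args(argv)


# Equivalent to: python3 mnist_generate.py --load_path /path/to/pth/checkpoint
args = parse_args(["--load_path", "/path/to/pth/checkpoint"])
```

Marking the flag as required makes `argparse` exit with a usage message if the checkpoint path is omitted.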

Results

MNIST

Training data generation (GIF): Epoch 1, Epoch 50, Epoch 100

Training Loss Curve:

Manipulating Latent Code

Rotation of digits.
Each row corresponds to one value of the categorical code, from K = 0 to K = 9 (top to bottom), which characterizes the digit. Each column corresponds to a value of the continuous code, varying from -2 to 2 (left to right).
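A minimal sketch of how the 10 x 10 grid of generator inputs behind such a figure could be assembled, assuming an MNIST latent layout of 62 noise dimensions, one 10-way categorical code and two continuous codes (our assumption, following the InfoGAN paper's MNIST setup), with the shared noise fixed and the unused continuous code held at 0. `latent_code_grid` is a hypothetical helper, not a function from this repository; each resulting vector would be converted to a tensor and fed to the trained generator.

```python
def one_hot(index, size):
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def linspace(start, stop, num):
    # Evenly spaced values that hit both endpoints exactly.
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


def latent_code_grid(noise, num_categories=10, num_cont=2, cont_index=0,
                     low=-2.0, high=2.0, num_cols=10):
    """Return row-major latent vectors laid out as [z, c_cat, c_cont] (assumed layout).

    Row k fixes the categorical code to category k. Column j sets continuous
    code `cont_index` to the j-th value in [low, high]. Other continuous codes
    stay at 0, and the same noise z is shared by every cell.
    """
    grid = []
    for k in range(num_categories):
        for value in linspace(low, high, num_cols):
            cont = [0.0] * num_cont
            cont[cont_index] = value
            grid.append(list(noise) + one_hot(k, num_categories) + cont)
    return grid
```

Setting `cont_index=1` sweeps the other continuous code instead; presumably this is how a second figure, such as the width variation, is produced. The range -2 to 2 deliberately extends beyond the [-1, 1] interval that the paper samples continuous codes from during training, to show how the learned factor extrapolates.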

Variation in Width
Each row corresponds to one value of the categorical code, from K = 0 to K = 9 (top to bottom), which characterizes the digit. Each column corresponds to a value of the continuous code, varying from -2 to 2 (left to right).

FashionMNIST

Training data generation (GIF): Epoch 1, Epoch 50, Epoch 100

Training Loss Curve:

Manipulating Latent Code

Thickness of items.
Each row corresponds to one value of the categorical code, from K = 0 to K = 9 (top to bottom), which characterizes the item. Each column corresponds to a value of the continuous code, varying from -2 to 2 (left to right).

SVHN

Training data generation (GIF): Epoch 1, Epoch 50, Epoch 100

Training Loss Curve:

Manipulating Latent Code

Continuous Variation: Lighting

Discrete Variation: Plate Context
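Lighting corresponds to a continuous code and plate context to a categorical one, and SVHN is typically trained with several codes of each kind. The sketch below shows how one generator input might be sampled during training under assumed priors: standard normal noise, uniform one-hot categorical codes, and continuous codes drawn from Uniform(-1, 1). The default sizes in the usage note (124 noise dimensions, four 10-way categorical codes, four continuous codes) follow our reading of the InfoGAN paper's SVHN setup and are assumptions. `sample_latent` is a hypothetical helper, not part of this repository.

```python
import random


def sample_latent(num_z, num_cat, cat_dim, num_cont, rng=random):
    """Sample one generator input laid out as [z, c_cat..., c_cont] (assumed layout)."""
    # Assumed prior: standard normal noise.
    z = [rng.gauss(0.0, 1.0) for _ in range(num_z)]
    # Each categorical code is an independent, uniformly chosen one-hot vector.
    cat = []
    for _ in range(num_cat):
        vec = [0.0] * cat_dim
        vec[rng.randrange(cat_dim)] = 1.0
        cat.extend(vec)
    # Assumed prior: continuous codes drawn from Uniform(-1, 1).
    cont = [rng.uniform(-1.0, 1.0) for _ in range(num_cont)]
    return z + cat + cont
```

For example, `sample_latent(124, 4, 10, 4, random.Random(0))` returns a 168-dimensional vector. Passing an explicit `random.Random` instance makes sampling reproducible. The sampled codes are also the targets that the Q network must recover from the generated image.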

CelebA

Training data generation (GIF): Epoch 1, Epoch 50, Epoch 100

Training Loss Curve:

Manipulating Latent Code

Azimuth (pose)

Gender: Roughly ordered from male to female (left to right)

Emotion

Hair Style and Color

Hair Quantity: Roughly ordered from less hair to more hair (left to right)
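The CelebA factors above are presumably obtained by sweeping a single categorical code across its values while the noise and all other codes stay fixed. Because categories carry no inherent order, a semantic ordering such as male to female can only be approximate. The sketch below builds such a sweep under an assumed layout of 128 noise dimensions followed by ten 10-way categorical codes and no continuous codes (our assumption, following the InfoGAN paper's CelebA setup). `categorical_traversal` is a hypothetical helper, not part of this repository.

```python
def categorical_traversal(noise, num_codes=10, cat_dim=10, code_index=0, base=None):
    """Return cat_dim latent vectors that differ only in categorical code
    `code_index`, which takes each of its cat_dim values in turn (left to right).

    base: the category held fixed for every code (defaults to 0 for all codes).
    """
    if base is None:
        base = [0] * num_codes
    vectors = []
    for value in range(cat_dim):
        choice = list(base)
        choice[code_index] = value  # only this code changes across the row
        codes = []
        for c in choice:
            vec = [0.0] * cat_dim
            vec[c] = 1.0
            codes.extend(vec)
        vectors.append(list(noise) + codes)
    return vectors
```

Repeating the traversal with several different noise vectors, one per row, separates the effect of the chosen code from sample-specific variation.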

References

  1. Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel. InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets. arXiv:1606.03657, 2016.
  2. pianomania/infoGAN-pytorch (GitHub repository).