• Stars: 152
• Rank: 243,091 (Top 5 %)
• Language: Python
• License: MIT License
• Created almost 6 years ago
• Updated almost 6 years ago


Implementation (with some experimentation) of the paper titled "VARIATIONAL DISCRIMINATOR BOTTLENECK: IMPROVING IMITATION LEARNING, INVERSE RL, AND GANS BY CONSTRAINING INFORMATION FLOW" (arxiv -> https://arxiv.org/pdf/1810.00821.pdf)

Variational_Discriminator_Bottleneck

Implementation uses the PyTorch framework.

VGAN architecture:

[Figure: detailed VGAN architecture]


The core concept proposed by the paper is to enforce an information bottleneck between the input images and the Discriminator’s internal representation of them.
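For reference, the paper writes the constrained objective with the standard GAN loss as follows, where E(z|x) is the encoder distribution, r(z) = N(0, I) is the prior, p̃ = ½(p* + G) is the mixture of real and generated data, and I_c is the information constraint. The constraint is enforced through a Lagrange multiplier β ≥ 0, which is updated by dual gradient ascent with step size α_β:

```latex
\begin{aligned}
\min_{D,\,E}\;\max_{\beta \ge 0}\;
  & \mathbb{E}_{x \sim p^*(x)}\Big[\mathbb{E}_{z \sim E(z \mid x)}\big[-\log D(z)\big]\Big]
  + \mathbb{E}_{x \sim G(x)}\Big[\mathbb{E}_{z \sim E(z \mid x)}\big[-\log\big(1 - D(z)\big)\big]\Big] \\
  & + \beta\Big(\mathbb{E}_{x \sim \tilde{p}(x)}\big[\mathrm{KL}\big[E(z \mid x)\,\|\,r(z)\big]\big] - I_c\Big) \\[6pt]
\beta \leftarrow {}
  & \max\Big(0,\; \beta + \alpha_\beta\Big(\mathbb{E}_{x \sim \tilde{p}(x)}\big[\mathrm{KL}\big[E(z \mid x)\,\|\,r(z)\big]\big] - I_c\Big)\Big)
\end{aligned}
```

When the average KL exceeds I_c, β grows and the encoder is pushed to discard information about its input; once the constraint is met, β shrinks back toward zero.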

As shown in the diagram, the Discriminator is now split into two parts: an Encoder and the actual Discriminator. The Generator is unchanged. The Encoder is modelled as a ResNet similar in architecture to the Generator, while the Discriminator is a simple linear classifier. Note that the Encoder doesn't output the internal codes of the images directly; like a VAE’s encoder, it outputs the means and standard deviations of the distributions from which the codes are sampled and then fed to the Discriminator.
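A minimal sketch of that sampling step and of the KL term that the bottleneck constrains, assuming a diagonal Gaussian encoder and a standard normal prior N(0, I), as in the paper. Plain Python lists stand in for PyTorch tensors, and the function names are hypothetical, not taken from the repository.

```python
import math
import random


def sample_code(mu, std, rng=random):
    """Reparameterization trick: z = mu + std * eps with eps ~ N(0, 1) per dimension.

    Because the randomness lives in eps, an autograd framework can
    backpropagate through mu and std.
    """
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]


def kl_to_standard_normal(mu, std):
    """KL( N(mu, diag(std^2)) || N(0, I) ), summed over dimensions.

    Per dimension: 0.5 * (std^2 + mu^2 - 1) - log(std).
    """
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mu, std))
```

In PyTorch the sampling line would typically read `mu + std * torch.randn_like(std)`, and the per-sample KL values would be averaged over the batch to estimate the expectation that is compared against I_c.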

CelebA 128x128 Experiment

I trained the VGAN-GP (the VGAN with the normal GAN loss replaced by the WGAN-GP loss) on the CelebA dataset; the results are shown below.

[Figure: generated samples]


I used I_c = 0.2, as described in the paper, and the architectures for G and D also follow the paper. The authors trained the model for 300K iterations, but the results shown here are at 62K iterations, which took 22.5 hours to train. I will train the model further, but I would really like readers and enthusiasts to take this forward, since the code is open-source.
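A minimal sketch of how the bottleneck could fit into a WGAN-GP discriminator update, assuming the penalty β·(E[KL] - I_c) is simply added to the critic loss and β follows the paper's dual gradient ascent rule. The gradient-penalty weight of 10 (the usual WGAN-GP default) and the dual step size α_β = 1e-5 are assumptions, not values stated in this README. The function names are hypothetical; in a real PyTorch implementation autograd would compute the gradient norms.

```python
def mean(values):
    """Arithmetic mean of a non-empty list of floats."""
    return sum(values) / len(values)


def vgan_gp_critic_loss(real_scores, fake_scores, grad_norms, mean_kl, beta,
                        i_c=0.2, gp_weight=10.0):
    """Critic loss = Wasserstein term + gradient penalty + bottleneck penalty.

    real_scores / fake_scores: discriminator outputs for real / generated images
                               (computed on codes sampled from the encoder).
    grad_norms: norms of the critic's input gradients at random interpolates
                between real and fake images (assumed precomputed here).
    mean_kl: batch-average KL between the encoder distribution and N(0, I).
    """
    wasserstein_term = mean(fake_scores) - mean(real_scores)
    gradient_penalty = gp_weight * mean([(g - 1.0) ** 2 for g in grad_norms])
    bottleneck_term = beta * (mean_kl - i_c)
    return wasserstein_term + gradient_penalty + bottleneck_term


def update_beta(beta, mean_kl, i_c=0.2, alpha_beta=1e-5):
    """Dual gradient ascent on beta, clipped so that beta never goes negative."""
    return max(0.0, beta + alpha_beta * (mean_kl - i_c))


# Made-up numbers: the KL (0.5) exceeds I_c (0.2), so the penalty is positive
# and beta increases slightly.
loss = vgan_gp_critic_loss([0.9, 1.1], [-0.2, 0.2], [1.0, 1.0], mean_kl=0.5, beta=1.0)
beta = update_beta(1.0, mean_kl=0.5)
```

The `-beta * i_c` part of the penalty is constant with respect to the networks, so it only matters for the β update, not for the encoder or discriminator gradients.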

Loss plot:

[Figure: loss plot]


Running the Code

Running the training is very simple: just run the train.py script in the source/ directory. The test/ directory contains the unit tests, which are useful if you would like to change anything about the implementation. Refer to the following parameters to tweak the training for your own use:

-h, --help            show this help message and exit
--generator_file GENERATOR_FILE
                    pretrained weights file for generator
--gen_optim_file GEN_OPTIM_FILE
                    previously saved state of Generator Optimizer
--discriminator_file DISCRIMINATOR_FILE
                    pretrained weights file for discriminator
--dis_optim_file DIS_OPTIM_FILE
                    previously saved state of Discriminator Optimizer
--images_dir IMAGES_DIR
                    path for the images directory
--folder_distributed_dataset FOLDER_DISTRIBUTED_DATASET
                    path for the images directory
--sample_dir SAMPLE_DIR
                    path for the generated samples directory
--model_dir MODEL_DIR
                    path for saved models directory
--loss_function LOSS_FUNCTION
                    loss function to be used: 'hinge', 'relativistic-
                    hinge', 'standard-gan', 'standard-gan_with-sigmoid',
                    'wgan-gp', 'lsgan'
--size SIZE           Size of the generated image (must be a power of 2 and
                    >= 4)
--latent_distrib LATENT_DISTRIB
                    Type of latent distribution to be used 'uniform' or
                    'gaussian'
--latent_size LATENT_SIZE
                    latent size for the generator
--final_channels FINAL_CHANNELS
                    starting number of channels in the networks
--max_channels MAX_CHANNELS
                    maximum number of channels in the network
--init_beta INIT_BETA
                    initial value of beta
--i_c I_C             value of information bottleneck
--batch_size BATCH_SIZE
                    batch_size for training
--start START         starting epoch number
--num_epochs NUM_EPOCHS
                    number of epochs for training
--feedback_factor FEEDBACK_FACTOR
                    number of logs to generate per epoch
--num_samples NUM_SAMPLES
                    number of samples to generate for creating the grid
                    (preferably a square number)
--checkpoint_factor CHECKPOINT_FACTOR
                    save model per n epochs
--g_lr G_LR           learning rate for generator
--d_lr D_LR           learning rate for discriminator
--data_percentage DATA_PERCENTAGE
                    percentage of data to use
--num_workers NUM_WORKERS
                    number of parallel workers for reading files
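A minimal sketch for launching a fresh run from Python, assuming you prefer that to the shell. It builds the train.py argument list from a few of the flags above, checks the documented --size constraint (a power of 2 and >= 4) before launching, and runs the script from source/. The helpers is_valid_size, build_train_command and launch are hypothetical, not part of the repository; --i_c 0.2 and --loss_function wgan-gp mirror the CelebA experiment described earlier.

```python
import subprocess
import sys


def is_valid_size(size):
    """--size must be a power of 2 and >= 4 (a power of 2 has exactly one bit set)."""
    return size >= 4 and (size & (size - 1)) == 0


def build_train_command(images_dir, size=128, i_c=0.2, loss_function="wgan-gp",
                        extra_args=()):
    """Assemble the argument list for a fresh training run of train.py."""
    if not is_valid_size(size):
        raise ValueError(f"--size must be a power of 2 and >= 4, got {size}")
    return [
        sys.executable, "train.py",
        "--images_dir", images_dir,
        "--size", str(size),
        "--i_c", str(i_c),
        "--loss_function", loss_function,
        *extra_args,  # any other flag/value pairs from the list above
    ]


def launch(cmd, source_dir="source"):
    """Run the command from inside the repository's source/ directory."""
    subprocess.run(cmd, check=True, cwd=source_dir)
```

For example, `launch(build_train_command("/path/to/celeba", extra_args=["--batch_size", "16"]))` would start training with a hypothetical dataset path and batch size; every flag not passed keeps its default.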

Please note that all the default values are tuned for the CelebA 128x128 experiment. For the CIFAR-10 and CelebA-HQ experiments, please refer to the paper.

Trained weights for generating cool faces / resuming the training :)

Please refer to the shared drive for the saved weights for this model in PyTorch format.
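To resume training, the four weight/optimizer flags can be combined with --start ("starting epoch number"). A minimal sketch, assuming both the network weights and the optimizer states have been downloaded as local files; the helper build_resume_command, the file names and the epoch number below are hypothetical, so substitute the actual names from the shared drive.

```python
import sys


def build_resume_command(images_dir, generator_file, gen_optim_file,
                         discriminator_file, dis_optim_file, start_epoch):
    """Arguments for resuming train.py from saved weights and optimizer states."""
    return [
        sys.executable, "train.py",
        "--images_dir", images_dir,
        "--generator_file", generator_file,
        "--gen_optim_file", gen_optim_file,
        "--discriminator_file", discriminator_file,
        "--dis_optim_file", dis_optim_file,
        "--start", str(start_epoch),
    ]


# Hypothetical file names and epoch; replace them with the real ones.
cmd = build_resume_command("/path/to/celeba", "GAN_GEN.pth", "GAN_GEN_OPTIM.pth",
                           "GAN_DIS.pth", "GAN_DIS_OPTIM.pth", start_epoch=41)
```

Run the resulting list from the source/ directory, for example with `subprocess.run(cmd, check=True, cwd="source")`.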

Other links

Medium blog -> https://medium.com/@animeshsk3/v-gan-variational-discriminator-bottleneck-an-unfair-fight-between-generator-and-discriminator-972563532dcc
Generated samples video -> https://www.youtube.com/watch?v=-0lBw9z8Ds0
My slack group -> https://join.slack.com/t/amlrldl/shared_invite/enQtNDcyMTIxODg3NjIzLTA3MTlmMDg0YmExYjY5OTgyZTg4MTg5ZGE1YzRlYjljZmM4MzI0MTg1OTcxOTc5NDQ4ZTcwMGVkZjBjZmU5ZWM

Thanks

Please feel free to open Issues / PRs here

Cheers 🍻!
@akanimax :)
