• Stars: 630
  • Rank: 71,328 (Top 2 %)
  • Language: Python
  • License: MIT License
  • Created: almost 6 years ago
  • Updated: over 2 years ago


Repository Details

[MSG-GAN] Anybody can GAN! A highly stable and robust architecture that requires little to no hyperparameter tuning. PyTorch implementation.

BMSG-GAN

PyTorch implementation of [MSG-GAN].

Please note that this is not the repository for the MSG-GAN research paper. For the official code and trained models for the MSG-GAN paper, please head over to the msg-stylegan-tf repository.

SageMaker

Training is now supported on AWS SageMaker. Please read https://docs.aws.amazon.com/sagemaker/latest/dg/pytorch.html

Flagship Diagram

MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis

Abstract:
While Generative Adversarial Networks (GANs) have seen huge successes in image synthesis tasks, they are notoriously difficult to use, in part due to instability during training. One commonly accepted reason for this instability is that gradients passing from the discriminator to the generator can quickly become uninformative, due to a learning imbalance during training. In this work, we propose the Multi-Scale Gradient Generative Adversarial Network (MSG-GAN), a simple but effective technique for addressing this problem which allows the flow of gradients from the discriminator to the generator at multiple scales. This technique provides a stable approach for generating synchronized multi-scale images. We present a very intuitive implementation of the mathematical MSG-GAN framework which uses the concatenation operation in the discriminator computations. We empirically validate the effect of our MSG-GAN approach through experiments on the CIFAR10 and Oxford102 flowers datasets and compare it with other relevant techniques which perform multi-scale image synthesis. In addition, we also provide details of our experiment on CelebA-HQ dataset for synthesizing 1024 x 1024 high resolution images.

Training time-lapse gif

An explanatory training time-lapse video/gif of MSG-GAN. The higher-resolution layers initially display plain colour blocks, but very soon the training penetrates all layers, and they then work in unison to produce better samples. Please observe the first few seconds of the training, where face-like blobs appear sequentially from the lowest resolution to the highest resolution.

Multi-Scale Gradients architecture

proposed MSG-GAN architecture

The above figure describes the architecture of MSG-GAN for generating synchronized multi-scale images. Our method is based on the architecture proposed in ProGAN. Instead of a progressive-growing training scheme, however, it includes connections from the intermediate layers of the generator to the intermediate layers of the discriminator. The multi-scale images input to the discriminator are converted into spatial volumes, which are then concatenated with the corresponding activation volumes obtained from the main path of convolutional layers.
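The concatenation step can be illustrated with a minimal, framework-free sketch. A real PyTorch model would typically use a learned 1x1 convolution followed by `torch.cat(..., dim=1)`. The pure-Python stand-in below only shows the shape logic. The helper names `from_rgb_1x1` and `concat_channels` are hypothetical, not this repository's API, and bias and activation are omitted by assumption.

```python
from typing import List

# A "volume" is a list of channels; each channel is an H x W grid of floats.
Volume = List[List[List[float]]]


def from_rgb_1x1(image: Volume, weights: List[List[float]]) -> Volume:
    """Project an image into a spatial volume with a 1x1 convolution.

    weights[k][c] is the weight from input channel c to output channel k.
    (Hypothetical helper: no bias or activation, to keep the sketch minimal.)
    """
    height, width = len(image[0]), len(image[0][0])
    return [
        [
            [sum(w_kc * image[c][y][x] for c, w_kc in enumerate(w_k))
             for x in range(width)]
            for y in range(height)
        ]
        for w_k in weights
    ]


def concat_channels(activations: Volume, image_volume: Volume) -> Volume:
    """Concatenate two volumes along the channel axis (spatial sizes must match)."""
    same_height = len(activations[0]) == len(image_volume[0])
    same_width = len(activations[0][0]) == len(image_volume[0][0])
    if not (same_height and same_width):
        raise ValueError("spatial sizes differ; downsample the image first")
    return activations + image_volume
```

In the example below, 4 activation channels plus a 1-channel projection of a 2x2 RGB image give a 5-channel volume for the next discriminator block:

```python
rgb = [[[1, 2], [3, 4]], [[0, 0], [0, 0]], [[1, 1], [1, 1]]]
acts = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(4)]
combined = concat_channels(acts, from_rgb_1x1(rgb, [[1, 1, 1]]))
```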


For the discrimination process, appropriately downsampled versions of the real images are fed to the corresponding layers of the discriminator, as shown in the diagram above.
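A minimal sketch of building those downsampled real images follows, assuming power-of-two image sizes and repeated 2x2 average pooling. The repository may use a different resampling operation, so treat this as an illustration only. The functions `avg_pool_2x2` and `image_pyramid` are hypothetical names. They work on a single channel; an RGB image would be pooled channel by channel.

```python
from typing import List

Grid = List[List[float]]


def avg_pool_2x2(grid: Grid) -> Grid:
    """Halve height and width by averaging non-overlapping 2x2 blocks."""
    return [
        [
            (grid[y][x] + grid[y][x + 1] + grid[y + 1][x] + grid[y + 1][x + 1]) / 4
            for x in range(0, len(grid[0]) - 1, 2)
        ]
        for y in range(0, len(grid) - 1, 2)
    ]


def image_pyramid(grid: Grid, num_scales: int) -> List[Grid]:
    """Return [full resolution, 1/2, 1/4, ...] versions of a single-channel image."""
    scales = [grid]
    for _ in range(num_scales - 1):
        scales.append(avg_pool_2x2(scales[-1]))
    return scales
```

The full-resolution image enters the first discriminator block. Each smaller version is concatenated in at the deeper block whose activations have the same spatial size.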


synchronization explanation


The figure above shows how, during training, all the layers of MSG-GAN first synchronize colour-wise and subsequently improve the generated images at their various scales. Eventually, the brightness of the images synchronizes across all layers (scales).

Running the Code

For best results, please use learning_rate=0.003 for both G and D in all experiments (--g_lr and --d_lr). That said, the model is quite robust: even with different learning-rate settings, it converges very quickly to a very similar FID or IS. Please use relativistic-hinge (the default) as the loss function for training.
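For reference, the relativistic average hinge loss can be sketched on plain lists of discriminator scores. This follows the standard relativistic average hinge formulation. The repository's own implementation, which operates on PyTorch tensors, may differ in detail. The function names `rahinge_d_loss` and `rahinge_g_loss` are hypothetical.

```python
from typing import Sequence


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def _relu(x: float) -> float:
    return max(0.0, x)


def rahinge_d_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    """Discriminator loss: zero once every real score exceeds the mean fake score
    by at least 1 and every fake score falls at least 1 below the mean real score."""
    mean_real, mean_fake = _mean(real_scores), _mean(fake_scores)
    real_term = _mean([_relu(1.0 - (r - mean_fake)) for r in real_scores])
    fake_term = _mean([_relu(1.0 + (f - mean_real)) for f in fake_scores])
    return real_term + fake_term


def rahinge_g_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    """Generator loss: the same hinge terms with the roles of real and fake swapped."""
    mean_real, mean_fake = _mean(real_scores), _mean(fake_scores)
    fake_term = _mean([_relu(1.0 - (f - mean_real)) for f in fake_scores])
    real_term = _mean([_relu(1.0 + (r - mean_fake)) for r in real_scores])
    return fake_term + real_term
```

Because each sample is compared against the *average* score of the opposite batch, the generator receives gradient through both its own samples and the real ones.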

Start training by running the train.py script in the sourcecode/ directory. The following parameters can be tweaked for your own use:

  -h, --help            show this help message and exit
  --generator_file GENERATOR_FILE
                        pretrained weights file for generator
  --generator_optim_file GENERATOR_OPTIM_FILE
                        saved state for generator optimizer
  --shadow_generator_file SHADOW_GENERATOR_FILE
                        pretrained weights file for the shadow generator
  --discriminator_file DISCRIMINATOR_FILE
                        pretrained_weights file for discriminator
  --discriminator_optim_file DISCRIMINATOR_OPTIM_FILE
                        saved state for discriminator optimizer
  --images_dir IMAGES_DIR
                        path for the images directory
  --folder_distributed FOLDER_DISTRIBUTED
                        whether the images directory contains folders or not
  --flip_augment FLIP_AUGMENT
                        whether to randomly mirror the images during training
  --sample_dir SAMPLE_DIR
                        path for the generated samples directory
  --model_dir MODEL_DIR
                        path for saved models directory
  --loss_function LOSS_FUNCTION
                        loss function to be used: standard-gan, wgan-gp,
                        lsgan,lsgan-sigmoid,hinge, relativistic-hinge
  --depth DEPTH         Depth of the GAN
  --latent_size LATENT_SIZE
                        latent size for the generator
  --batch_size BATCH_SIZE
                        batch_size for training
  --start START         starting epoch number
  --num_epochs NUM_EPOCHS
                        number of epochs for training
  --feedback_factor FEEDBACK_FACTOR
                        number of logs to generate per epoch
  --num_samples NUM_SAMPLES
                        number of samples to generate for creating the grid
                        should be a square number preferably
  --checkpoint_factor CHECKPOINT_FACTOR
                        save model per n epochs
  --g_lr G_LR           learning rate for generator
  --d_lr D_LR           learning rate for discriminator
  --adam_beta1 ADAM_BETA1
                        value of beta_1 for adam optimizer
  --adam_beta2 ADAM_BETA2
                        value of beta_2 for adam optimizer
  --use_eql USE_EQL     Whether to use equalized learning rate or not
  --use_ema USE_EMA     Whether to use exponential moving averages or not
  --ema_decay EMA_DECAY
                        decay value for the ema
  --data_percentage DATA_PERCENTAGE
                        percentage of data to use
  --num_workers NUM_WORKERS
                        number of parallel workers for reading files
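
The --use_ema and --ema_decay options presumably maintain the shadow generator, whose weights can be restored with --shadow_generator_file, as an exponential moving average of the trained generator's weights. A minimal sketch follows, assuming the standard EMA update. Plain floats in a dict stand in for parameter tensors, and `ema_update` is a hypothetical name. In PyTorch, the same update would typically run over the two models' parameters inside `torch.no_grad()`.

```python
from typing import Dict


def ema_update(shadow: Dict[str, float], params: Dict[str, float], decay: float) -> None:
    """In-place update: shadow <- decay * shadow + (1 - decay) * params."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value
```

A decay close to 1 makes the shadow weights change slowly, which tends to smooth out step-to-step noise in the generator. The shadow copy is then used for sampling.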

Sample Training Run

For training a network at resolution 256 x 256, use the following arguments:

$ python train.py --depth=7 \
                  --latent_size=512 \
                  --images_dir=<path to images> \
                  --sample_dir=samples/exp_1 \
                  --model_dir=models/exp_1

Set batch_size, feedback_factor and checkpoint_factor to suit your hardware and dataset. We used 2 Tesla V100 GPUs of the DGX-1 machine for our experiments.
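The --depth value determines the output resolution. Since --depth=7 gives 256 x 256, one consistent reading is that the first generator block outputs 4 x 4 images (as in ProGAN) and each further block doubles the resolution. The 4 x 4 base is an assumption, not something stated here. Under that assumption, the helpers below (hypothetical names) list the synchronized output scales and recover the depth for a target resolution.

```python
from typing import List


def output_resolutions(depth: int, base: int = 4) -> List[int]:
    """Square resolutions produced by each generator block, lowest first.

    Assumes the first block outputs base x base images and every later
    block doubles the resolution.
    """
    return [base * 2 ** i for i in range(depth)]


def depth_for_resolution(resolution: int, base: int = 4) -> int:
    """Inverse of output_resolutions(...)[-1]; resolution must be base * 2**k."""
    ratio, remainder = divmod(resolution, base)
    if remainder or ratio == 0 or ratio & (ratio - 1):
        raise ValueError("resolution must be base * 2**k")
    return ratio.bit_length()
```

Under this assumption, the 1024 x 1024 CelebA-HQ samples would correspond to depth 9.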

Generated samples on different datasets

[NEW] CelebA HQ [1024 x 1024] (30K dataset)
CelebA-HQ


[NEW] Oxford Flowers (improved samples) [256 x 256] (8K dataset)
oxford_big oxford_variety


CelebA HQ [256 x 256] (30K dataset)
CelebA-HQ


LSUN Bedrooms [128 x 128] (3M dataset)
lsun_bedrooms


CelebA [128 x 128] (200K dataset)
CelebA


Synchronized all-res generated samples

CIFAR-10 [32 x 32] (50K dataset)
cifar_allres


Oxford-102 Flowers [256 x 256] (8K dataset)
flowers_allres


Cite our work

@article{karnewar2019msg,
  title={MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis},
  author={Karnewar, Animesh and Wang, Oliver and Iyengar, Raghu Sesha},
  journal={arXiv preprint arXiv:1903.06048},
  year={2019}
}

Other Contributors 😄

Cartoon Set [128 x 128] (10K dataset) by @huangzh13
Cartoon_Set


Thanks

Please feel free to open PRs here if you train on other datasets using this architecture.

Best regards,
@akanimax :)
