  • Stars: 264
  • Rank: 155,103 (Top 4%)
  • Language: Python
  • License: Other
  • Created: about 5 years ago
  • Updated: almost 5 years ago

Repository Details

MSG StyleGAN in tensorflow

MSG-STYLEGAN-TF

Official code repository for the paper "MSG-GAN: Multi-Scale Gradients for Generative Adversarial Networks" [arXiv]

Teaser Diagram

Why this repository?

Our previous work released the BMSG-GAN code in PyTorch, which applied our proposed multi-scale connections to the basic ProGAN architecture (i.e. the DCGAN architecture) instead of using progressive growing. This repository applies the multi-scale gradient connections to StyleGAN, replacing the progressive growing used to train the original StyleGAN. The switch to TensorFlow was made primarily to ensure an apples-to-apples comparison with StyleGAN.

Due Credit

This code heavily uses NVIDIA's original StyleGAN code. We credit and acknowledge their work here. The original license is located in the base directory (file named LICENSE_ORIGINAL.txt).

Abstract

While Generative Adversarial Networks (GANs) have seen huge successes in image synthesis tasks, they are notoriously difficult to adapt to different datasets, in part due to instability during training and sensitivity to hyperparameters. One commonly accepted reason for this instability is that gradients passing from the discriminator to the generator become uninformative when there isn't enough overlap in the supports of the real and fake distributions. In this work, we propose the Multi-Scale Gradient Generative Adversarial Network (MSG-GAN), a simple but effective technique for addressing this by allowing the flow of gradients from the discriminator to the generator at multiple scales. This technique provides a stable approach for high resolution image synthesis, and serves as an alternative to the commonly used progressive growing technique. We show that MSG-GAN converges stably on a variety of image datasets of different sizes, resolutions and domains, as well as different types of loss functions and architectures, all with the same set of fixed hyperparameters. When compared to state-of-the-art GANs, our approach matches or exceeds the performance in most of the cases we tried.

Method overview

Architecture diagram

Architecture of MSG-GAN, shown here on the base model proposed in ProGANs. Our architecture includes connections from the intermediate layers of the generator to the intermediate layers of the discriminator. Multi-scale images sent to the discriminator are concatenated with the corresponding activation volumes obtained from the main path of convolutional layers followed by a combine function (shown in yellow).

StyleGAN Modifications:

The MSG-StyleGAN model (in this repository) uses all the modifications proposed by StyleGAN to the ProGANs architecture except the mixing regularization. Similar to MSG-ProGAN (diagram above), we use a 1 x 1 conv layer to obtain the RGB image output from every block of the StyleGAN generator, leaving everything else (mapping network, non-traditional input and style AdaIN) untouched. The discriminator architecture is the same as the ProGANs (and consequently MSG-ProGAN) discriminator.
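
The multi-scale connections described above can be sketched at the level of shapes. This is only an illustration, not code from this repository: the starting resolution of 4 x 4, the helper names and the channel counts are assumptions, and the combine step is shown as plain channel concatenation.

```python
def is_power_of_two(n):
    """True for 1, 2, 4, 8, ..."""
    return n > 0 and (n & (n - 1)) == 0


def msg_resolutions(final_res, start_res=4):
    """Resolutions at which the generator emits an RGB image, coarse to fine.

    start_res=4 is an assumption borrowed from ProGAN/StyleGAN.
    """
    if not (is_power_of_two(start_res) and is_power_of_two(final_res)):
        raise ValueError("resolutions must be powers of two")
    if final_res < start_res:
        raise ValueError("final_res must be >= start_res")
    resolutions = []
    res = start_res
    while res <= final_res:
        resolutions.append(res)
        res *= 2
    return resolutions


def to_rgb_pixel(features, weights, bias):
    """A 1 x 1 convolution at a single pixel: a linear map from C channels to 3.

    weights has shape [3][C]; bias has length 3 (hypothetical values).
    """
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, bias)]


def channels_after_concat(feature_channels, rgb_channels=3):
    """Channel count once an RGB image is concatenated onto an activation volume."""
    return feature_channels + rgb_channels


if __name__ == "__main__":
    print(msg_resolutions(1024))       # one RGB output per generator block
    print(channels_after_concat(512))  # 512 is a made-up feature width
```

Under these assumptions a 1024 x 1024 model hands the discriminator nine images, one per scale, and each discriminator block sees three extra input channels.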

System requirements

The code was built and tested for:

  • 64-bit Python 3.6.7
  • TensorFlow 1.13.1 with GPU support.
  • NVIDIA GPUs with at least 16GB of DRAM. We used variants of the Tesla V100 GPUs.
  • NVIDIA driver 418.56, CUDA toolkit 10.1, cuDNN 7.3.1.

How to run the code (Training)

Training can be run in the following steps:

Step 1: Data formatting

The MSG-StyleGAN training pipeline expects the dataset to be in tfrecord format, which sped up training considerably. Use the dataset_tool.py tool to generate these tfrecords from your raw dataset. To use the tool, either select one of the datasets it already supports or use the create_from_images option if you have a new dataset in the form of images. For full options and more information, run:

(your_virtual_env)$ python dataset_tool.py --help
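
Before converting a folder of images, it can help to check their sizes. The sketch below assumes that this repository's dataset_tool.py inherits the constraints of the original StyleGAN tool (square images, power-of-two resolution, identical size across the dataset); that is an assumption, so confirm it against the --help output. The function names are hypothetical.

```python
def is_power_of_two(n):
    """True for 1, 2, 4, 8, ..."""
    return n > 0 and (n & (n - 1)) == 0


def check_image_shapes(shapes):
    """Return a list of problems for a list of (width, height) tuples.

    An empty list means every image passed the assumed constraints.
    """
    if not shapes:
        return ["no images found"]
    problems = []
    first = shapes[0]
    for index, (width, height) in enumerate(shapes):
        if width != height:
            problems.append(f"image {index}: not square ({width} x {height})")
        elif not is_power_of_two(width):
            problems.append(f"image {index}: {width} is not a power of two")
        elif (width, height) != first:
            problems.append(f"image {index}: size differs from image 0")
    return problems


if __name__ == "__main__":
    # Sizes would normally be read from the image files themselves.
    for problem in check_image_shapes([(256, 256), (300, 300), (128, 256)]):
        print(problem)
```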

Step 2: Run the training script

The first step is to update the paths in the global configuration located in config.py. For instance:

"""Global configuration."""

# ----------------------------------------------------------------------------
# Paths.
    
result_dir = "/home/karnewar/self_research/msg-stylegan/"
data_dir = "/media/datasets_external/"
cache_dir = "/home/karnewar/self_research/msg-stylegan/cache"
run_dir_ignore = ["results", "datasets", "cache"]

# ----------------------------------------------------------------------------

The result_dir is where all the trained models, training logs and evaluation score logs are written. The data_dir should contain the different datasets used for training under separate subdirectories, while the cache_dir stores objects that are needed repeatedly during training, for instance the mean and std of the real images used when calculating the FID.
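
For reference, the FID is conventionally the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. Assuming this repository follows that standard definition (which uses the feature mean and covariance; the symbols below are not taken from the code), it reads:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Here $\mu_r, \Sigma_r$ are the statistics of the real images, which do not change during training and are therefore worth caching, and $\mu_g, \Sigma_g$ are those of the generated images.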

Following this, download the inception net weights from here and place them in result_dir + "/inception_network/inception_v3_features.pkl".
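
A small pre-flight check can catch a misplaced weights file before a long run starts. This is a sketch with hypothetical helper names; only the relative location inception_network/inception_v3_features.pkl comes from the instructions above.

```python
import os


def inception_weights_path(result_dir):
    """Expected location of the Inception weights inside result_dir."""
    return os.path.join(result_dir, "inception_network",
                        "inception_v3_features.pkl")


def missing_paths(paths):
    """Return the subset of paths that do not exist on disk."""
    return [path for path in paths if not os.path.exists(path)]


if __name__ == "__main__":
    # Example values; substitute the ones from your own config.py.
    result_dir = "/path/to/results"
    data_dir = "/path/to/datasets"
    required = [result_dir, data_dir, inception_weights_path(result_dir)]
    for path in missing_paths(required):
        print("missing:", path)
```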

Finally, modify the configuration in train.py to suit your setup and start training by running the train.py script.

(your_virtual_env)$ python train.py

Pretrained models

Dataset          Size    Resolution     GPUs used     FID score   Link
LSUN Churches    ~150K   256 x 256      8 V100-16GB   5.20        drive link
Oxford Flowers   ~8K     256 x 256      2 V100-32GB   19.60       drive link
Indian Celebs    ~3K     256 x 256      4 V100-32GB   28.44       drive link
CelebA-HQ        30K     1024 x 1024    8 V100-16GB   6.37        drive link
FFHQ             70K     1024 x 1024    4 V100-32GB   5.80        drive link

How to use pretrained models

We provide three scripts, generate_multiscale_samples.py, generate_samples.py and latent_space_interpolation_video.py, which can be used to generate grids of multi-scale generated images, highest-resolution samples and a latent space interpolation video respectively. Please see the example below.

(your_virtual_env)$ python latent_space_interpolation_video.py \
--pickle_file /home/karnewar/msg-stylegan/00004-msg-stylegan-visual_art-4gpu/best_model.pkl \
--output_file /home/karnewar/msg-stylegan/visual_art_interpolation_hd.avi \
--num_points 30 \
--transition_points 30 \
--resize 800 1920
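
To give an idea of what a latent space interpolation involves, here is a minimal sketch. It assumes that --num_points is the number of random latent anchors and --transition_points the number of frames between consecutive anchors, and it uses plain linear interpolation with a latent size of 512; the actual script may define or smooth these differently.

```python
import random


def lerp(a, b, t):
    """Linear interpolation between two latent vectors for t in [0, 1]."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


def interpolation_path(num_points, transition_points, latent_size=512, seed=0):
    """Latent vectors for a walk through num_points random anchors.

    Each pair of consecutive anchors contributes transition_points frames,
    and the final anchor is appended once at the end.
    """
    if num_points < 2 or transition_points < 1:
        raise ValueError("need at least 2 anchors and 1 transition frame")
    rng = random.Random(seed)
    anchors = [[rng.gauss(0.0, 1.0) for _ in range(latent_size)]
               for _ in range(num_points)]
    frames = []
    for start, end in zip(anchors, anchors[1:]):
        for step in range(transition_points):
            frames.append(lerp(start, end, step / transition_points))
    frames.append(anchors[-1])
    return frames


if __name__ == "__main__":
    frames = interpolation_path(num_points=30, transition_points=30)
    print(len(frames), "latent vectors of size", len(frames[0]))
```

Each resulting vector would then be fed to the generator to render one video frame.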

How to run evaluation

The training pipeline already computes the metric during training and writes a tensorboard log. If you wish to evaluate the trained models again for a research baseline (or for some other reason, say fast training and separate evaluation), please use the run_metrics.py script. Modify the following lines according to your situation:

tasks = []
tasks += [
    EasyDict(
        run_func_name="run_metrics.run_pickle",
        network_pkl="/home/karnewar/msg-stylegan/00002-msg-stylegan-indian_celebs-4gpu/network-snapshot.pkl",
        dataset_args=EasyDict(tfrecord_dir="indian_celebs/tfrecords", shuffle_mb=0),
        mirror_augment=True,
    )
]  
# tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)]
# tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)]

# How many GPUs to use?
submit_config.num_gpus = 1
# submit_config.num_gpus = 2
# submit_config.num_gpus = 4
# submit_config.num_gpus = 8

and run:

(your_virtual_env)$ python run_metrics.py

The run_snapshot and run_pickle functions do practically the same thing and differ only in usage. The former needs a snapshot id and a run_id and locates the files automatically, whereas the latter needs the pickle file to be provided. I personally find run_pickle much more useful. The run_all_snapshots function takes the run_id and evaluates all snapshots located in that run_dir.

Stability and Ease of Use :)

Stability and ease of use are not terms you'd usually associate with GANs 😆. With multi-scale gradients, however, training is quite stable. The following experiment compares the two approaches side by side:

progan_stability msggan_stability

We quantify the image stability during training. These plots show the MSE between images generated from the same latent code at the beginning of sequential epochs (averaged over 36 latent samples) on the CelebA-HQ dataset. MSG-GAN converges stably over time while Progressive Growing continues to vary significantly across epochs. Note that Progressive Growing spends the first half of the epochs fading in the new layer, yet the changes remain significant even in the subsequent epochs.
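
The stability measure described above can be sketched as follows. This is an illustration rather than the code used for the plots: images are represented as flat lists of pixel values and the function names are hypothetical.

```python
def mse(image_a, image_b):
    """Mean squared error between two flattened images of equal length."""
    if len(image_a) != len(image_b):
        raise ValueError("images must have the same number of pixels")
    return sum((a - b) ** 2 for a, b in zip(image_a, image_b)) / len(image_a)


def epoch_to_epoch_change(images_by_epoch):
    """Average change between consecutive epochs for fixed latent codes.

    images_by_epoch[e][k] is the image generated from latent code k at the
    start of epoch e. The result has one value per pair of consecutive
    epochs: the MSE averaged over all latent codes.
    """
    changes = []
    for previous, current in zip(images_by_epoch, images_by_epoch[1:]):
        per_latent = [mse(p, c) for p, c in zip(previous, current)]
        changes.append(sum(per_latent) / len(per_latent))
    return changes


if __name__ == "__main__":
    # Two latent codes, three epochs, two-pixel "images" (toy values).
    history = [
        [[0.0, 0.0], [1.0, 1.0]],
        [[1.0, 1.0], [1.0, 1.0]],
        [[1.0, 1.0], [1.0, 1.0]],
    ]
    print(epoch_to_epoch_change(history))
```

A curve that decays towards zero indicates that images for fixed latent codes stop changing, which is the behaviour reported here for MSG-GAN.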

training_explanation

During training, all the layers in the MSG-GAN synchronize across the generated resolutions fairly early in the training and subsequently improve the quality of the generated images at all scales simultaneously. Throughout the training the generator makes only minimal incremental improvements to the images generated from fixed latent points.

Qualitative examples

Simultaneous multi-scale latent space interpolation FFHQ [1024 x 1024]
ffhq_multi_scale

CelebA-HQ [1024 x 1024]
CelebA-HQ

FFHQ [1024 x 1024]
FFHQ

LSUN Churches [256 x 256]
LSUN Churches

Oxford Flowers [256 x 256]
Oxford Flowers

Indian Celebs [256 x 256]
Indian Celebs

More full-resolution CelebA-HQ samples [1024 x 1024]
full_res_hq_sheet

Cite our work

@article{karnewar2019msg,
  title={MSG-GAN: Multi-Scale Gradients for Generative Adversarial Networks},
  author={Karnewar, Animesh and Wang, Oliver},
  journal={arXiv preprint arXiv:1903.06048},
  year={2019}
}

Other contributors

Please feel free to open PRs here if you train on other datasets using this architecture.

Thanks and regards

[โญ New โญ] Please check out my new IG handle @the_GANista. I will be posting fun GAN based visual art here. :).

Thank you all for supporting and encouraging my work. I hope this will be useful for your research / project / work.

As always, any suggestion / feedback / contribution is welcome 😄.

Cheers 🍻!
@akanimax :)

More Repositories

1. BMSG-GAN (Python, 630 stars): [MSG-GAN] Any body can GAN! Highly stable and robust architecture. Requires little to no hyperparameter tuning. Pytorch Implementation
2. T2F (Python, 546 stars): T2F: text to face generation using Deep Learning
3. pro_gan_pytorch (Python, 536 stars): Unofficial PyTorch implementation of the paper titled "Progressive growing of GANs for improved Quality, Stability, and Variation"
4. natural-language-summary-generation-from-structured-data (Python, 183 stars): Implementation of the paper -> https://arxiv.org/abs/1709.00155. For converting information present in the form of structured data into natural language text
5. Variational_Discriminator_Bottleneck (Python, 152 stars): Implementation (with some experimentation) of the paper titled "VARIATIONAL DISCRIMINATOR BOTTLENECK: IMPROVING IMITATION LEARNING, INVERSE RL, AND GANS BY CONSTRAINING INFORMATION FLOW" (arxiv -> https://arxiv.org/pdf/1810.00821.pdf)
6. msg-gan-v1 (Python, 151 stars): MSG-GAN: Multi-Scale Gradients GAN (Architecture inspired from ProGAN but doesn't use layer-wise growing)
7. fagan (Python, 112 stars): A variant of the Self Attention GAN named: FAGAN (Full Attention GAN)
8. thr3ed_atom (Python, 111 stars): ReLU Fields The Little Non-linearity That Could
9. big-discriminator-batch-spoofing-gan (Python, 45 stars): BMSG-GAN with more features
10. pro_gan_pytorch-examples (Python, 38 stars): Examples trained using the python pytorch package pro-gan-pth
11. attn_gan_pytorch (Python, 19 stars): python package for self-attention gan implemented as extension of PyTorch nn.Module. paper -> https://arxiv.org/abs/1805.08318
12. my-deity (Jupyter Notebook, 12 stars): I worship the one true neural network architecture that can autonomously learn everything.
13. NLP2SQL (Jupyter Notebook, 11 stars): A research and review of techniques to provide a natural language interface to RDMS.
14. open-styleganv2-pytorch (Python, 7 stars): Open source + Free for Commercial Use implementation of StyleGANv2 in pytorch
15. capsule-network-TensorFlow (Jupyter Notebook, 6 stars): The impending concept of capsule networks has finally arrived at arXiv. link to the publication -> https://arxiv.org/abs/1710.09829 . In this repository, I'll create an implementation using TensorFlow from scratch as an exercise.
16. 3inGAN (Python, 5 stars)
17. GAN-understanding (Jupyter Notebook, 5 stars): Implements gans on toy datasets and preliminary ML datasets for showing certain aspects of convergence and stability. Tries to cover various loss functions defined over the years.
18. autoencoder-cifar-10 (Jupyter Notebook, 4 stars): Implementing an auto-encoder for the cifar10 dataset
19. Homecoming (Python, 3 stars): repository for mini-projects
20. python_ai_project_template (Python, 3 stars): A lightweight template for building AI-based prototype/research POCs in Python. My poison (DL framework :laugh: ) of choice is PyTorch!
21. some-randon-gan-1 (Python, 2 stars): MSG-GAN with self attention. For MSG-GAN head to -> https://github.com/akanimax/MSG-GAN
22. AI-Literature (2 stars): A repository to store key research works from the past. It is also an attempt to structure and organize these research papers.
23. indian-celeb-gans (Python, 2 stars): Various GANs trained on a dataset containing images of Indian Celebrities (procured by me).
24. some-random-gan-2 (Python, 2 stars): More experimentation with the base MSG-GAN architecture. This includes the coord-conv layers in the architecture. For more info about MSG-GAN, head to -> https://github.com/akanimax/msg-stylegan-tf
25. multithreaded-histogram-equalization-cpp (C++, 2 stars): Explanatory Code for performing Histogram Equalization on Images for contrast improvement. The code uses OpenCV in C++ for image read/write and uses pthread for multithreading
26. deep-reinforcement-learning (Jupyter Notebook, 1 star): Project for studying and implementing the traditional RL algorithms and also the DL variants of the same.
27. dcgan_pytorch (Python, 1 star): GAN example created using the attn_gan_pytorch package -> https://github.com/akanimax/attn_gan_pytorch
28. CL-3_lab_2017 (TeX, 1 star): repository for assignments of Computer Laboratory 3 - 2016
29. SVC2004-deep-learning (Jupyter Notebook, 1 star): A deep learning based solution for the SVC2004 problem.
30. toxic-comment-identification-tensorflow (Python, 1 star): Data -> https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data
31. REST_MOVIE_TICKET_SYSTEM (Scala, 1 star): A restful system implementing token based authentication for allowing users to book movie tickets online. Use of Play framework for scala
32. algorithms (Python, 1 star): A repository for collecting the coding implementations of some of the most famous algorithms
33. energy-preserving-neural-network (Python, 1 star): When a data signal propagates through the Neural Network, it is not mandatory that the energy of the signal will be preserved throughout the neural computations. This research attempts at collecting (perhaps creating) techniques for preserving the Energy throughout the network.