• Stars: 366
• Rank: 116,547 (Top 3%)
• Language: Python
• License: MIT License
• Created: about 4 years ago
• Updated: almost 4 years ago

Repository Details

My implementation of various GAN (generative adversarial networks) architectures like vanilla GAN (Goodfellow et al.), cGAN (Mirza et al.), DCGAN (Radford et al.), etc.

PyTorch GANs 💻 vs 💻 = ❤️

This repo contains PyTorch implementation of various GAN architectures.
It's aimed at making it easy for beginners to start playing and learning about GANs.

All of the repos I found do obscure things like setting bias in some network layer to False without explaining
why certain design decisions were made. This repo makes every design decision transparent.

Table of Contents

  • What are GANs?
  • Setup
  • Implementations (Vanilla GAN, Conditional GAN, DCGAN)
  • Acknowledgements
  • Citation
  • Connect with me
  • Licence

What are GANs?

GANs were originally proposed by Ian Goodfellow et al. in a seminal paper called Generative Adversarial Nets.

GANs are a framework where 2 models (usually neural networks), called generator (G) and discriminator (D), play a minimax game against each other. The generator is trying to learn the distribution of real data and is the network which we're usually interested in. During the game the goal of the generator is to trick the discriminator into "thinking" that the data it generates is real. The goal of the discriminator, on the other hand, is to correctly discriminate between the generated (fake) images and real images coming from some dataset (e.g. MNIST).
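To make the minimax game concrete, here is a minimal sketch of one adversarial training step using binary cross-entropy losses. This is illustrative only and not the exact code from train_vanilla_gan.py; the networks G and D, their optimizers, and real_images are assumptions.

import torch
import torch.nn as nn

# Assumptions: G maps a latent vector z to an image, D maps an image to a single
# logit, and each has its own optimizer. None of this is the repo's exact code.
adversarial_loss = nn.BCEWithLogitsLoss()

def train_step(G, D, opt_G, opt_D, real_images, latent_dim=100):
    batch_size = real_images.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # Discriminator step: push D(real) towards 1 and D(fake) towards 0.
    z = torch.randn(batch_size, latent_dim)
    fake_images = G(z).detach()  # detach so this step doesn't update G
    d_loss = adversarial_loss(D(real_images), real_labels) + \
             adversarial_loss(D(fake_images), fake_labels)
    opt_D.zero_grad(); d_loss.backward(); opt_D.step()

    # Generator step: try to make D classify freshly generated images as real.
    z = torch.randn(batch_size, latent_dim)
    g_loss = adversarial_loss(D(G(z)), real_labels)
    opt_G.zero_grad(); g_loss.backward(); opt_G.step()
    return d_loss.item(), g_loss.item()

The generator loss above is the commonly used non-saturating variant (maximize log D(G(z))) rather than the original minimization of log(1 - D(G(z))).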

Setup

  1. git clone https://github.com/gordicaleksa/pytorch-gans
  2. Open Anaconda console and navigate into the project directory: cd path_to_repo
  3. Run conda env create from the project directory (this will create a brand new conda environment).
  4. Run activate pytorch-gans (for running scripts from your console; or set the interpreter in your IDE).

That's it! It should work out of the box, since the environment.yml file deals with all the dependencies.


The PyTorch package will pull some version of CUDA with it, but it is highly recommended that you install system-wide CUDA beforehand, mostly because of GPU drivers. I also recommend using the Miniconda installer to get conda on your system.

Follow through points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN.

Implementations

Important note: you don't need to train the GANs to use this project; I've checked in pre-trained models.
You can just use the generate_imagery.py script to play with the models.

Vanilla GAN

Vanilla GAN is my implementation of the original GAN paper (Goodfellow et al.) with certain modifications mostly in the model architecture, like the usage of LeakyReLU and 1D batch normalization (it didn't even exist back then) instead of the maxout activation and dropout.
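As a rough illustration of those choices (layer sizes below are made up and this is not the checked-in model), a fully connected generator built around Linear + BatchNorm1d + LeakyReLU could look like this:

import torch.nn as nn

latent_dim, img_size = 100, 28 * 28  # illustrative sizes for MNIST-shaped outputs

generator = nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.BatchNorm1d(256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 512),
    nn.BatchNorm1d(512),
    nn.LeakyReLU(0.2),
    nn.Linear(512, img_size),
    nn.Tanh(),  # outputs in [-1, 1], matching images normalized to that range
)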

Examples

The GAN was trained on data from the MNIST dataset. Here is what the digits from the dataset look like:

You can see how the network is slowly learning to capture the data distribution during training:

After the generator is trained we can use it to generate all 10 digits! Looks like they're coming directly from MNIST, right!?

We can also pick 2 generated numbers that we like, save their latent vectors, and subsequently linearly or spherically
interpolate between them to generate new images and understand how the latent space (z-space) is structured:

We can see how the number 4 is slowly morphing into 9 and then into the number 3.

The idea behind spherical interpolation is super simple: instead of moving over the shortest possible path
(a straight line, i.e. linear interpolation) from the first vector (p0) to the second (p1), you take the sphere's arc path:
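In code, a minimal sketch of both interpolation schemes might look like this (the repo's actual implementation may differ in details; p0 and p1 are the two saved latent vectors):

import numpy as np

def linear_interpolation(t, p0, p1):
    # Walk along the straight line between the two latent vectors.
    return (1.0 - t) * p0 + t * p1

def spherical_interpolation(t, p0, p1):
    # Walk along the arc between p0 and p1 (slerp) instead of the straight line.
    cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if np.isclose(np.sin(omega), 0.0):
        return linear_interpolation(t, p0, p1)  # vectors are (anti)parallel
    return (np.sin((1.0 - t) * omega) * p0 + np.sin(t * omega) * p1) / np.sin(omega)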

Usage

Option 1: Jupyter Notebook

Just run jupyter notebook from your Anaconda console and it will open the session in your default browser.
Open Vanilla GAN (PyTorch).py and you're ready to play!

If you created the env before I added jupyter, just do pip install jupyter==1.0.0 and you're ready.


Note: if you get "DLL load failed while importing win32api: The specified module could not be found",
just do pip uninstall pywin32, and then either pip install pywin32 or conda install pywin32 should fix it!

Option 2: Use your IDE of choice

Training

It's really easy to kick off new training, just run this:
python train_vanilla_gan.py --batch_size <number which won't break your GPU's VRAM>

The code is well commented so you can understand exactly how the training itself works.

The script will:

  • Dump checkpoint *.pth models into models/checkpoints/
  • Dump the final *.pth model into models/binaries/
  • Dump intermediate generated imagery into data/debug_imagery/
  • Download MNIST (~100 MB) the first time you run it and place it into data/MNIST/
  • Dump tensorboard data into runs/; just run tensorboard --logdir=runs from your Anaconda console

And that's it, you can track the training both visually (through the dumped imagery) and through G's and D's loss progress.

Tracking loss can be helpful but I mostly relied on visually analyzing intermediate imagery.

Note1: also make sure to check out the playground.py file if you're having problems understanding the adversarial loss.
Note2: images are dumped both to the file system (data/debug_imagery/) and to tensorboard.

Generating imagery and interpolating

To generate a single image, just run the script with defaults:
python generate_imagery.py

It will display and dump the generated image into data/generated_imagery/ using the checked-in generator model.

Make sure to change the --model_name param to your model's name (once you train your own model).


If you want to play with interpolation, just set the --generation_mode to GenerationMode.INTERPOLATION.
And optionally set --slerp to true if you want to use spherical interpolation.

The first time you run it in this mode, the script will start generating images
and ask you to pick 2 images you like by entering 'y' into the console.

Finally it will start displaying interpolated imagery and dump the results to data/interpolated_imagery.

Conditional GAN

Conditional GAN (cGAN) is my implementation of the cGAN paper (Mirza et al.).
It basically just adds conditioning vectors (one-hot encodings of the digit labels) to the vanilla GAN above.
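In code, the conditioning boils down to concatenating a one-hot label to the generator's noise vector (and, analogously, to the discriminator's input). A hedged sketch, with shapes and names that are illustrative rather than taken from train_cgan.py:

import torch
import torch.nn.functional as F

latent_dim, num_classes, batch_size = 100, 10, 16

z = torch.randn(batch_size, latent_dim)
digit_labels = torch.randint(0, num_classes, (batch_size,))
one_hot = F.one_hot(digit_labels, num_classes).float()

# The generator sees [noise | condition]; the discriminator gets the same one-hot
# vector concatenated to (a flattened view of) the image it is judging.
generator_input = torch.cat([z, one_hot], dim=1)  # shape: (16, 110)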

Examples

In addition to everything that we could do with the original GAN, here we can control exactly which digit we want to generate! We make it dump a 10x10 grid where each column is a single digit, and this is how the learning proceeds:

Usage

For training, just check out the vanilla GAN section (just make sure to use train_cgan.py instead).

Generating imagery

Same as for vanilla GAN, but you can additionally set cgan_digit to a number between 0 and 9 to generate that exact digit! There is no separate interpolation support for cGAN; it works the same as for vanilla GAN, so feel free to use that.

Note: make sure to set --model_name to either CGAN_000000.pth (pre-trained and checked-in) or your own model.

DCGAN

DCGAN is my implementation of the DCGAN paper (Radford et al.).
The main contribution of the paper was that it was the first to make CNNs work successfully in the GAN setup.
Batch normalization had been invented in the meantime, and that is basically what made CNNs work.
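For reference, a typical DCGAN-style generator building block chains a transposed convolution with BatchNorm2d and ReLU. The sketch below follows the general recipe from the paper and is not the exact architecture checked into this repo:

import torch.nn as nn

def dcgan_generator_block(in_channels, out_channels):
    # Upsample 2x with a transposed convolution, then normalize and activate.
    # bias=False because the following BatchNorm2d cancels out any bias anyway.
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )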

Examples

I trained DCGAN on a preprocessed CelebA dataset. Here are some samples from the dataset:

Again, you can see how the network is slowly learning to capture the data distribution during training:

After the generator is trained we can use it to generate new faces! This problem is much harder than generating MNIST digits, so generated faces are not indistinguishable from the real ones.

Some SOTA GAN papers did a much better job at generating faces; currently the best model is StyleGAN2.

Similarly we can explore the structure of the latent space via interpolations:

We can see how the man's face slowly morphs into a woman's face, and the skin tone also changes gradually.

Finally, because the latent space has some nice properties (a linear structure), we can do some interesting things.
Subtracting a neutral woman's latent vector from a smiling woman's latent vector gives us the "smile vector".
Adding that vector to a neutral man's latent vector should give us a smiling man's latent vector. And so it does!

You can also create the "sunglasses vector" and use it to add sunglasses to other faces, etc.
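The arithmetic itself is just addition and subtraction of latent vectors. A hedged sketch, assuming a trained generator G and three latent vectors you picked interactively (in practice each concept vector is often an average over several samples to reduce noise):

import torch

def smile_transfer(G, smiling_woman_z, neutral_woman_z, neutral_man_z):
    # "smile vector" = smiling woman - neutral woman; add it to a neutral man.
    smile_vector = smiling_woman_z - neutral_woman_z
    smiling_man_z = neutral_man_z + smile_vector
    with torch.no_grad():
        return G(smiling_man_z.unsqueeze(0))  # add a batch dimension and generate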

Note: I've created an interactive script so you can play with this; check out GenerationMode.VECTOR_ARITHMETIC.

Usage

For training, just check out the vanilla GAN section (just make sure to use train_dcgan.py instead).
The only difference is that this script will download the pre-processed CelebA dataset instead of MNIST.

Generating imagery

Again just use the generate_imagery.py script.

You have 3 options for generation_mode:

  • GenerationMode.SINGLE_IMAGE <- generate a single face image
  • GenerationMode.INTERPOLATION <- pick 2 face images you like and the script will interpolate between them
  • GenerationMode.VECTOR_ARITHMETIC <- pick 9 images and the script will do vector arithmetic

GenerationMode.VECTOR_ARITHMETIC will give you an interactive matplotlib plot to pick 9 images.

Note: make sure to set --model_name to either DCGAN_000000.pth (pre-trained and checked-in) or your own model.

Acknowledgements

I found these repos useful (while developing this one):

Citation

If you find this code useful for your research, please cite the following:

@misc{Gordić2020PyTorchGANs,
  author = {Gordić, Aleksa},
  title = {pytorch-gans},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-gans}},
}

Connect with me

If you'd love to have some more AI-related content in your life 🤓, consider:

Licence

License: MIT

More Repositories

1. pytorch-GAT (Jupyter Notebook, 2,253 stars): My implementation of the original GAT paper (Veličković et al.). I've additionally included the playground.py file for visualizing the Cora dataset, GAT embeddings, an attention mechanism, and entropy histograms. I've supported both Cora (transductive) and PPI (inductive) examples!
2. pytorch-original-transformer (Jupyter Notebook, 880 stars): My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing otherwise seemingly hard concepts. IWSLT pretrained models are currently included.
3. get-started-with-JAX (Jupyter Notebook, 546 stars): The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
4. Open-NLLB (Python, 364 stars): Effort to open-source NLLB checkpoints.
5. pytorch-deepdream (Jupyter Notebook, 352 stars): PyTorch implementation of the DeepDream algorithm (Mordvintsev et al.). Additionally I've included playground.py to help you better understand the basic concepts behind the algo.
6. pytorch-neural-style-transfer (Python, 343 stars): Reconstruction of the original paper on neural style transfer (Gatys et al.). I've additionally included reconstruction scripts which allow you to reconstruct only the content or the style of the image, for a better understanding of how NST works.
7. stable_diffusion_playground (Python, 203 stars): Playing around with stable diffusion. Generated images are reproducible because I save the metadata and latent information. You can generate and then later interpolate between the images of your choice.
8. pytorch-learn-reinforcement-learning (Python, 140 stars): A collection of various RL algorithms like policy gradients, DQN and PPO. The goal of this repo is to make it a go-to resource for learning about RL: how to visualize, debug and solve RL problems. I've additionally included playground.py for learning more about OpenAI gym, etc.
9. pytorch-neural-style-transfer-johnson (Python, 110 stars): Reconstruction of fast neural style transfer (Johnson et al.). Some portions of the paper have been improved by follow-up work, like instance normalization, etc. Check out transformer_net.py's header for details.
10. serbian-llm-eval (Python, 81 stars): Serbian LLM Eval.
11. pytorch-naive-video-neural-style-transfer (Python, 73 stars): Create naive (no temporal loss) NST for videos with person segmentation. Just place your videos in data/, run, and you get your stylized and segmented videos.
12. OpenGemini (17 stars): Effort to open-source the 10.5 trillion parameter Gemini model.
13. gordicaleksa (12 stars): GitHub's new feature: a repo with the same name as your GitHub name, initialized with a README.md, will show on your landing page!
14. digital-image-processing (MATLAB, 7 stars): Projects I did for the Digital Image Processing course at my university.
15. streamlit_playground (Python, 4 stars): Simple Streamlit app.
16. Open-NLLB-stopes (Python, 3 stars): A library for preparing data for machine translation research (monolingual preprocessing, bitext mining, etc.) for the Open-NLLB effort.
17. MachineLearningMicrosoftPetnica (C++, 3 stars): Problems I solved for the Microsoft ML summer camp in Petnica, Serbia.
18. competitive_programming (C++, 2 stars): Contains algorithms and snippets I found useful when solving problems for TopCoder, Google Code Jam, etc.
19. slovenian-llm-eval (Python, 2 stars): Slovenian LLM Eval.
20. MicrosoftBubbleCup2018 (C++, 1 star): My solutions for Bubble Cup 2018.
21. .dotfiles (Shell, 1 star): Configuration files for my vim editor, bash, etc.
22. GoogleCodeJam2018 (C++, 1 star): My solutions for Google Code Jam 2018.