• Stars: 110
• Rank: 316,770 (Top 7%)
• Language: Python
• License: MIT License
• Created over 4 years ago
• Updated about 4 years ago


Repository Details

Reconstruction of the fast neural style transfer paper (Johnson et al.). Some portions of the paper have been improved with follow-up work, such as instance normalization. Check out transformer_net.py's header for details.

Fast Neural Style Transfer (feed-forward method) 💻 + 🎨 = ❤️

This repo contains a concise PyTorch implementation of the original feed-forward NST paper (🔗 Johnson et al.).

Check out my implementation of the original NST (optimization method) paper (Gatys et al.).

It's an accompanying repo for this video series on YouTube.

NST Intro

Why yet another Fast NST (feed-forward method) repo?

It's the cleanest and most concise NST repo that I know of + it's written in PyTorch! ❤️

My idea 💡 is to make the code so simple and well commented that you can use it as a first step on your NST learning journey, before any other blog, course, book or research paper. 📚

I've included an automatic download script for the pretrained models and the MS COCO dataset, so you can either instantly run it and get the results (🎨 stylized images) using the pretrained models, or start training/experimenting with your own models. 🚀

Examples

Here are some examples with the 4 pretrained models (automatic download enabled; see the Usage section):

Note: keep in mind that I still need to improve these models; 3 of them (the last 3 rows) only saw 33k images from MS COCO.

Setup

  1. Open Anaconda Prompt and navigate into the project directory: cd path_to_repo
  2. Run conda env create from the project directory (this will create a brand new conda environment).
  3. Run activate pytorch-nst-fast (if you want to run scripts from your console; otherwise, set the interpreter in your IDE).

That's it! It should work out of the box, since conda env create executes the environment.yml file, which deals with the dependencies.


The PyTorch package will pull some version of CUDA with it, but it is highly recommended that you install CUDA system-wide beforehand, mostly because of the GPU drivers. I also recommend using the Miniconda installer to get conda on your system.

Follow points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN (I recommend CUDA 10.1 or 10.2, as those are compatible with PyTorch 1.5, which is used in this repo, together with the newest compatible cuDNN).

Usage

Go through this section to run the project, but if you are still having problems, take a look at these accompanying YouTube videos: this one (stylization) and this one (training).

Stylization

  1. Download the pretrained models: run python utils/resource_downloader.py
  2. Run python stylization_script.py (it has a default content image and model set)

That's it! If you want more flexibility (and I guess you do), there are a couple more nuggets of info.

A more expressive command is:
python stylization_script.py --content_input <imgname or directory> --img_width <width> --model_name <name>

If you pass a directory into --content_input, it will perform batch stylization.
You can control the batch size (in case you have VRAM problems) with the batch_size param.
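For intuition on what batch_size controls, here is a minimal, framework-free sketch of splitting a content directory into batches. The helper names (list_content_images, batches) and the accepted extensions are hypothetical; the repo's actual loading code may differ.

```python
import os
from typing import Iterator, List

# Hypothetical set of accepted image extensions (an assumption, not taken from the repo).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_content_images(directory: str) -> List[str]:
    """Return the sorted paths of all image files directly inside `directory`."""
    names = sorted(os.listdir(directory))
    return [os.path.join(directory, n) for n in names if n.lower().endswith(IMG_EXTENSIONS)]


def batches(paths: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `batch_size` paths.

    A smaller batch_size keeps fewer images in VRAM at once,
    at the cost of more forward passes through the transformer net.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(paths), batch_size):
        yield paths[start:start + batch_size]
```

For example, with batch_size=2 a directory holding a.jpg, b.jpg and c.jpg would be processed as [a.jpg, b.jpg] followed by [c.jpg]; lowering batch_size trades speed for a lower peak VRAM usage.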


You just need to specify the names; the repo automatically finds content images and models in the default directories:

  1. content images default dir: /data/content-images/
  2. model binaries default dir: /models/binaries/

So all you have to do is place images and models there, and you can use them. Output will be dumped to /data/output-images/.
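A sketch of the name-to-path convention described above, assuming the leading "/" in these paths means the repo root rather than the filesystem root. resolve(), available() and the constant names are hypothetical, not functions from the repo.

```python
import os
from typing import List

# Default locations from the README, joined relative to the repo root
# (assumption: the leading "/" means "repo root", not the filesystem root).
CONTENT_IMAGES_DIR = os.path.join("data", "content-images")
BINARIES_DIR = os.path.join("models", "binaries")
OUTPUT_IMAGES_DIR = os.path.join("data", "output-images")


def resolve(name: str, default_dir: str, repo_root: str = ".") -> str:
    """Map a bare file or directory name (a content image, a content folder
    for batch stylization, or a model binary) to its default location."""
    return os.path.normpath(os.path.join(repo_root, default_dir, name))


def available(default_dir: str, repo_root: str = ".") -> List[str]:
    """Names you could pass on the command line: whatever currently sits in default_dir."""
    path = os.path.join(repo_root, default_dir)
    return sorted(os.listdir(path)) if os.path.isdir(path) else []
```

For instance, resolve("my_photo.jpg", CONTENT_IMAGES_DIR) (a hypothetical file name) yields data/content-images/my_photo.jpg on POSIX systems.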

After you run the resource_downloader.py script, the binaries dir will be pre-populated with the 4 pretrained models.

Go ahead, play with it and make some art!

Training your own models

  1. Download the MS COCO dataset: run python utils/resource_downloader.py -r mscoco_dataset (it's a 12.5 GB file)
  2. Run python training_script.py --style_img_name <name>

Now that will probably actually work!

It will periodically dump checkpoint models to /models/checkpoints/ and the final model to /models/binaries/ by default.

I strongly recommend playing with these 2 params:

  1. style_weight - I always kept it in the [1e5, 9e5] range; you may have to tweak it a little for your specific style image
  2. subset_size - usually 32k images do the job (that's 8k batches); you'll need to monitor tensorboard to figure out whether your curves are saturating at that point. If they are still going down, set the number higher
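A quick sanity check of the arithmetic above: 32k images in 8k batches implies a batch size of 4 (32,000 / 8,000 = 4). The sketch below assumes that batch size and adds a hypothetical helper for laying out a style_weight grid over the recommended range; neither function exists in the repo.

```python
from typing import List

# Range the author reports keeping style_weight in.
STYLE_WEIGHT_MIN, STYLE_WEIGHT_MAX = 1e5, 9e5


def num_batches(subset_size: int, batch_size: int = 4) -> int:
    """Optimizer steps needed to go once through `subset_size` images.

    batch_size=4 is an assumption inferred from "32k images ... (that's 8k batches)",
    not a value read from training_script.py. Uses ceiling division so a
    partial last batch still counts as a step.
    """
    return (subset_size + batch_size - 1) // batch_size


def style_weight_candidates(n: int) -> List[float]:
    """n evenly spaced style_weight values spanning the recommended range,
    e.g. a hypothetical grid with one short training run per value."""
    if n < 2:
        raise ValueError("need at least 2 candidates to span the range")
    step = (STYLE_WEIGHT_MAX - STYLE_WEIGHT_MIN) / (n - 1)
    return [STYLE_WEIGHT_MIN + i * step for i in range(n)]
```

With these assumptions, num_batches(32_000) gives the 8k batches mentioned above, and style_weight_candidates(5) gives 1e5, 3e5, 5e5, 7e5 and 9e5.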

That brings us to the next section!

Tensorboard Visualizations

To start tensorboard just run: tensorboard --logdir=runs --samples_per_plugin images=50 from your conda console.

samples_per_plugin images=<number> sets the number of images you'll be able to see when moving the image slider.

There are basically 2 things you want to monitor during training (not counting the console output, which is redundant if you use tensorboard):

Monitor your loss/statistics curves

You want to keep content-loss and style-loss going down or at least one of them (style loss usually saturates first).

I usually set the tv weight to 0, which is why you see 0 on the tv-loss curve. You should use it only if you see that your images have smoothness problems (check this out for a visualization of what exactly the tv weight does).
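To show what the tv weight penalizes, here is a framework-free sketch of one common total variation formulation (anisotropic, L1) on a single-channel image. This is an assumption for illustration; the repo's training code may use a different norm or normalization.

```python
from typing import List

Image = List[List[float]]  # a single-channel H x W image as nested lists


def total_variation(img: Image) -> float:
    """Anisotropic L1 total variation: the sum of absolute differences between
    horizontally and vertically adjacent pixels. Noisy, high-frequency images
    score high; smooth images score low."""
    h, w = len(img), len(img[0])
    horizontal = sum(abs(img[i][j + 1] - img[i][j]) for i in range(h) for j in range(w - 1))
    vertical = sum(abs(img[i + 1][j] - img[i][j]) for i in range(h - 1) for j in range(w))
    return horizontal + vertical


def weighted_tv_loss(img: Image, tv_weight: float) -> float:
    """The term added to the total loss; tv_weight=0 switches it off entirely."""
    return tv_weight * total_variation(img)
```

A flat 2x2 image scores 0, while the checkerboard [[0, 1], [1, 0]] scores 4. In PyTorch the same idea is usually written with tensor slicing, e.g. torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) for the horizontal term of an NCHW batch.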

Statistics curves let me understand how the stylized image coming out of the transformer net behaves.

If the max or min intensities start diverging, or the mean/median start drifting too far away from 0, that's a good indicator that your style weight is (probably) not good. You can keep the content weight constant and just tweak the style weight.
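A minimal sketch of these statistics, computed over a flat list of output intensities with Python's statistics module. The README gives no numeric thresholds, so drift_warnings() takes them as required, hypothetical parameters you would pick for your own runs.

```python
import statistics
from typing import Dict, List, Sequence


def intensity_stats(pixels: Sequence[float]) -> Dict[str, float]:
    """Min/max/mean/median of the transformer net's output intensities,
    i.e. the kind of values plotted on the statistics curves."""
    return {
        "min": min(pixels),
        "max": max(pixels),
        "mean": statistics.mean(pixels),
        "median": statistics.median(pixels),
    }


def drift_warnings(stats: Dict[str, float], center_tol: float, range_tol: float) -> List[str]:
    """Flag the two symptoms described above. Both tolerances are
    hypothetical knobs, not values from the repo."""
    warnings = []
    if abs(stats["mean"]) > center_tol or abs(stats["median"]) > center_tol:
        warnings.append("mean/median drifting away from 0")
    if stats["max"] - stats["min"] > range_tol:
        warnings.append("min/max intensities diverging")
    return warnings
```

If a warning keeps firing across logging steps, the advice above applies: hold the content weight fixed and adjust the style weight.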

Monitor your intermediate stylized images

This one helps immensely, as it lets you manually early-stop your training if you don't like the stylized output you see.

In the beginning, stylized images look kinda rubbish, like the one on the left. As the training progresses, you'll get more meaningful images popping out (the one on the right).

Debugging

Q: My style/content loss curves just spiked in the middle of training?
A: 2 options: a) rerun the training (the optimizer got into a bad state); b) if that doesn't work, lower your style weight.

Q: How can I see the exact parameters that you used to train your models?
A: Just run the model with stylization_script.py; the training metadata will be printed to the output console.

Further experimentation (advanced, for researchers)

There are a couple of things you could experiment with (assuming fixed net architectures). Here are some ideas:

  1. Try setting the MSE reduction to sum for the style loss. I used that method here and it gave nice results. You'll have to play with the style weight afterwards to get it running. This will effectively give a bigger weight to deeper style representations, because the Gram matrices coming out of deeper layers are bigger (a layer with C channels yields a C×C Gram matrix, and VGG16's deeper layers have more channels). This means you'll give an advantage to high-level style features (broad spatial characteristics of the style) over low-level style features (smaller neighborhood characteristics like fine brush strokes).
  2. The original paper used a tanh activation at the output; figure out how you can get it to work with that (you may have to add some scaling). There is this magic constant 150 that Johnson originally used to scale tanh. I created this issue as it is unclear how it came to be and whether it was just figured out experimentally.
  3. The PyTorch pretrained VGG16 model was trained on ImageNet-normalized images in the 0..1 range. Try working with 0..255 range, ImageNet mean-only-normalized images instead; that will also work! It worked here, and if you feed such an image into VGG16 (as a classifier), it will give you correct predictions!
  4. This repo uses 0..255 images (no normalization) as the input to the transformer net; play with that. You'll have to normalize the transformer net output before feeding it to VGG16.
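The ideas above can be sketched without any framework. The ImageNet constants are torchvision's standard normalization values (in 0..255 units the mean is 255 × (0.485, 0.456, 0.406) ≈ (123.675, 116.28, 103.53)); the function names are hypothetical, and mse() only mimics the mean/sum semantics of the reduction argument of torch.nn.MSELoss. The demo at the bottom illustrates idea 1: with the same per-entry Gram error, the summed loss grows with C², so a 512-channel layer weighs 512² / 64² = 64 times more than a 64-channel one, while the mean-reduced loss is the same for both.

```python
import math
from typing import List

Matrix = List[List[float]]

# torchvision's standard ImageNet statistics for images in the 0..1 range.
IMAGENET_MEAN_01 = (0.485, 0.456, 0.406)
IMAGENET_STD_01 = (0.229, 0.224, 0.225)
# The same means expressed in 0..255 units (~ 123.675, 116.28, 103.53).
IMAGENET_MEAN_255 = tuple(255.0 * m for m in IMAGENET_MEAN_01)


def normalize_01(value_255: float, channel: int) -> float:
    """Full ImageNet normalization: rescale to 0..1, subtract mean, divide by std."""
    return (value_255 / 255.0 - IMAGENET_MEAN_01[channel]) / IMAGENET_STD_01[channel]


def normalize_mean_only_255(value_255: float, channel: int) -> float:
    """Idea 3: stay in 0..255 and only subtract the rescaled ImageNet mean."""
    return value_255 - IMAGENET_MEAN_255[channel]


def scaled_tanh(x: float, scale: float = 150.0) -> float:
    """Idea 2: tanh output stretched to (-scale, scale); 150 is Johnson's constant."""
    return scale * math.tanh(x)


def mse(a: Matrix, b: Matrix, reduction: str = "mean") -> float:
    """Element-wise squared error between two equally sized matrices,
    averaged ("mean") or summed ("sum")."""
    sq = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    if reduction == "mean":
        return sum(sq) / len(sq)
    if reduction == "sum":
        return sum(sq)
    raise ValueError(f"unknown reduction: {reduction}")


def constant(c: int, value: float) -> Matrix:
    """A c x c matrix filled with `value` (stand-in for a C x C Gram matrix)."""
    return [[value] * c for _ in range(c)]


if __name__ == "__main__":
    # Idea 1: same per-entry Gram error (0.1) at a 64-channel and a 512-channel layer
    # (VGG16 conv channel counts grow from 64 in the first block to 512 in the last ones).
    for c in (64, 512):
        target, current = constant(c, 0.0), constant(c, 0.1)
        print(c, mse(current, target, "mean"), mse(current, target, "sum"))
```

Item 4 then amounts to applying normalize_01 (or normalize_mean_only_255) per channel to the transformer net's 0..255 output before it reaches VGG16.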

Some of these may further improve the visual quality that you get from these models! If you find something interesting I'd like to hear from you!

Acknowledgements

I found these repos useful (while developing this one):

I found some of the content/style images I was using here:

Other images are now already classics in the NST world.

Citation

If you find this code useful for your research, please cite the following:

@misc{Gordić2020nst-fast,
  author = {Gordić, Aleksa},
  title = {pytorch-nst-feedforward},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-nst-feedforward}},
}

Connect with me

If you'd love to have some more AI-related content in your life 🤓, consider:

Licence

License: MIT
