
StyleGAN 2 in PyTorch

Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch

Notice

I have tried to match the official implementation as closely as possible, but there may be some details I missed, so please use this implementation with care.

Requirements

I have tested on:

  • PyTorch 1.3.1
  • CUDA 10.1/10.2

Usage

First create lmdb datasets:

python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH

This will convert the images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list of sizes, in case you want to try other resolutions later.
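
For example, a hypothetical run (the output path, worker count, and dataset path are placeholders) might look like:

python prepare_data.py --out data/ffhq_lmdb --n_worker 8 --size 256,512,1024 /path/to/ffhq_images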

Then you can train the model in a distributed setting:

python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH

train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the command.
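
For example, a single-node run on 4 GPUs with logging enabled (the port, batch size, and dataset path are placeholders) could look like:

python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 train.py --batch 16 --wandb data/ffhq_lmdb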

SWAGAN

This implementation experimentally supports SWAGAN: A Style-based Wavelet-driven Generative Model (https://arxiv.org/abs/2102.06108). You can train SWAGAN with:

python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --arch swagan --batch BATCH_SIZE LMDB_PATH

As noted in the paper, SWAGAN trains much faster (about 2x at 256px).

Convert weights from official checkpoints

You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints.

For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this:

python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl

This will create a converted stylegan2-ffhq-config-f.pt file.
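
If you want to use the converted weights from your own Python code, a minimal sketch could look like the one below. It assumes the Generator constructor from model.py takes (size, style_dim, n_mlp) and that the converted .pt file stores the EMA generator state dict under "g_ema"; check generate.py for the exact arguments used in this repository.

import torch

from model import Generator  # model.py from this repository

# Assumed constructor arguments: image size, style dimension, mapping network depth
g_ema = Generator(1024, 512, 8).cuda()

# Assumed checkpoint layout: the converted file holds the EMA generator weights under "g_ema"
ckpt = torch.load("stylegan2-ffhq-config-f.pt")
g_ema.load_state_dict(ckpt["g_ema"])
g_ema.eval()

# Sample latents and generate a small batch of images
with torch.no_grad():
    z = torch.randn(4, 512).cuda()
    images, _ = g_ema([z])  # forward takes a list of latent codes and returns (images, latents)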

Generate samples

python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT

You should pass the matching size (such as --size 256) if you trained at a different resolution.
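
For example, to generate 10 pictures with 8 faces each from a 256px checkpoint (the checkpoint path is a placeholder):

python generate.py --sample 8 --pics 10 --size 256 --ckpt checkpoint/550000.pt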

Project images to latent spaces

python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...
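
For example, projecting two images into the latent space of a 256px generator (all paths are placeholders):

python projector.py --ckpt checkpoint/550000.pt --size 256 input/img1.png input/img2.png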

Closed-Form Factorization (https://arxiv.org/abs/2007.06600)

You can use closed_form_factorization.py and apply_factor.py to discover meaningful latent semantic factors or directions in an unsupervised manner.

First, you need to extract the eigenvectors of the weight matrices using closed_form_factorization.py:

python closed_form_factorization.py [CHECKPOINT]

This will create a factor file that contains the eigenvectors (default: factor.pt). Then you can use apply_factor.py to test the meaning of the extracted directions:

python apply_factor.py -i [INDEX_OF_EIGENVECTOR] -d [DEGREE_OF_MOVE] -n [NUMBER_OF_SAMPLES] --ckpt [CHECKPOINT] [FACTOR_FILE]

For example,

python apply_factor.py -i 19 -d 5 -n 10 --ckpt [CHECKPOINT] factor.pt

This will generate 10 random samples, together with samples generated from latents moved along the 19th eigenvector with degree ±5.
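
Conceptually, the factorization gathers the style modulation weight matrices from the generator and takes the eigenvectors of their concatenation. A rough sketch of that idea is below; the "g_ema" key and the "modulation" substring used to select layers are assumptions about this repository's naming, so refer to closed_form_factorization.py for the exact filtering.

import torch

# Assumed checkpoint layout: the EMA generator state dict is stored under "g_ema"
ckpt = torch.load("checkpoint.pt", map_location="cpu")
state = ckpt["g_ema"]

# Assumed layer naming: style modulation weights contain "modulation" in their key
weights = [v for k, v in state.items() if "modulation" in k and k.endswith("weight")]
W = torch.cat(weights, dim=0)

# Right singular vectors of the stacked weights give the latent directions
eigvec = torch.svd(W).V
torch.save({"eigvec": eigvec}, "factor.pt")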

Sample of closed form factorization

Pretrained Checkpoints

Link

I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5. Differences in data preprocessing, resolution, or the training loop may account for this, but I currently don't know the exact reason for the FID gap.

Samples

Sample with truncation

Sample from FFHQ at 110,000 iterations (trained on 3.52M images).

MetFaces sample with non-leaking augmentations

Sample from MetFaces with non-leaking augmentations at 150,000 iterations (trained on 4.8M images).

Samples from converted weights

Sample from FFHQ (1024px)

Sample from LSUN Church (256px)

License

Model details and custom CUDA kernel code are from the official repository: https://github.com/NVlabs/stylegan2

Code for Learned Perceptual Image Patch Similarity (LPIPS) came from https://github.com/richzhang/PerceptualSimilarity

To match FID scores more closely to the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid
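
If you just want a quick FID number against a folder of real images, the mseitzer package can also be invoked directly from the command line (directory names are placeholders); this is a generic alternative, not necessarily the exact evaluation pipeline used for the numbers above:

pip install pytorch-fid
python -m pytorch_fid path/to/real_images path/to/generated_images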

More Repositories

1. vq-vae-2-pytorch: Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch (Python, 1,606 stars)
2. style-based-gan-pytorch: Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks in PyTorch (Python, 1,099 stars)
3. glow-pytorch: PyTorch implementation of Glow (Python, 508 stars)
4. alias-free-gan-pytorch: Unofficial implementation of Alias-Free Generative Adversarial Networks (https://arxiv.org/abs/2106.12423) in PyTorch (Python, 507 stars)
5. denoising-diffusion-pytorch: Implementation of Denoising Diffusion Probabilistic Models in PyTorch (Python, 360 stars)
6. ml-papers: My collection of machine learning papers (267 stars)
7. swapping-autoencoder-pytorch: Unofficial implementation of Swapping Autoencoder for Deep Image Manipulation (https://arxiv.org/abs/2007.00653) in PyTorch (Python, 255 stars)
8. mac-network-pytorch: Memory, Attention and Composition (MAC) Network for CLEVR implemented in PyTorch (Python, 85 stars)
9. vision-transformers-pytorch: Implementation of various Vision Transformers I found interesting (Python, 83 stars)
10. adaptive-softmax-pytorch: Adaptive Softmax implementation for PyTorch (Python, 79 stars)
11. sagan-pytorch: Self-Attention Generative Adversarial Networks implementation in PyTorch (Python, 74 stars)
12. igebm-pytorch: Implicit Generation and Generalization in Energy Based Models in PyTorch (Python, 65 stars)
13. depthwise-conv-pytorch: Faster depthwise convolutions for PyTorch (Cuda, 64 stars)
14. ocr-pytorch: Object-Contextual Representations for Semantic Segmentation in PyTorch (Python, 63 stars)
15. progressive-gan-pytorch: Implementation of Progressive Growing of GANs in PyTorch (Python, 62 stars)
16. relation-networks-pytorch: Relation Networks for CLEVR implemented in PyTorch (Python, 61 stars)
17. imputer-pytorch: Implementation of Imputer: Sequence Modelling via Imputation and Dynamic Programming in PyTorch (Python, 58 stars)
18. fcos-pytorch: Re-implementation of FCOS for personal study (Python, 51 stars)
19. knotter: Implementation of the Mapper algorithm for Topological Data Analysis (JavaScript, 45 stars)
20. semantic-pyramid-pytorch: Implementation of Semantic Pyramid for Image Generation (https://arxiv.org/abs/2003.06221) in PyTorch (Python, 39 stars)
21. id-gan-pytorch: Information Distillation Generative Adversarial Network in PyTorch (Python, 27 stars)
22. nerf-pytorch (Python, 21 stars)
23. tensorfn: Weakly opinionated library for implementing ML models. Less boilerplate, more rigor (Python, 20 stars)
24. taming-transformers-pytorch: Implementation of Taming Transformers for High-Resolution Image Synthesis (https://arxiv.org/abs/2012.09841) in PyTorch (16 stars)
25. film-pytorch: Just another implementation of FiLM in PyTorch (Python, 14 stars)
26. melgan-pytorch: MelGAN and Tacotron 2 in PyTorch (Python, 11 stars)
27. instant-ngp-pytorch: Study of Instant Neural Graphics Primitives (unofficial) (11 stars)
28. meshfn: Framework for human alignment learning (Python, 7 stars)
29. nansy-pytorch: Unofficial implementation of Neural Analysis and Synthesis (7 stars)
30. sarigan-pytorch: Unofficial implementation of Learning Semantic-aware Normalization for Generative Adversarial Networks (SariGAN) in PyTorch (7 stars)
31. lvpga-pytorch: Implementation of Perceptual Generative Autoencoders in PyTorch (Python, 5 stars)
32. arxiv-sanity: arXiv feed tool heavily inspired by Arxiv Sanity Preserver (Python, 5 stars)
33. esrgan-pytorch: ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks in PyTorch (3 stars)
34. dockerfiles: Dockerfiles (Dockerfile, 3 stars)
35. sujip: Non-opinionated utility library for PyTorch (Python, 2 stars)
36. rosinality.github.io (HTML, 2 stars)
37. small-logan-pytorch: Small-GAN and LOGAN in PyTorch (2 stars)
38. maskrcnn-pytorch: Re-implementation of Mask R-CNN for personal study (2 stars)
39. synapticmap: Synaptic Map, a simple mind-mapping program with directional connections (JavaScript, 1 star)
40. langfn: A DSL for LLMs (1 star)
41. usrnet-pytorch: Reimplementation of Deep Unfolding Network for Image Super-Resolution for self-study (1 star)
42. fill-blank: Paragraph embedding by solving fill-in-the-blank problems (Python, 1 star)
43. centernet-pytorch: Re-implementation of CenterNet for personal study (1 star)