  • Language
    Python
  • License
    MIT License
  • Created over 4 years ago
  • Updated over 3 years ago

Repository Details

The official PyTorch implementation for NCSNv2 (NeurIPS 2020)

Improved Techniques for Training Score-Based Generative Models

This repo contains the official implementation for the paper Improved Techniques for Training Score-Based Generative Models.

by Yang Song and Stefano Ermon, Stanford AI Lab.

Note: The method has been extended by the subsequent work Score-Based Generative Modeling through Stochastic Differential Equations (code) that allows better sample quality and exact log-likelihood computation.


We significantly improve the method proposed in Generative Modeling by Estimating Gradients of the Data Distribution. Score-based generative models are flexible neural networks trained to capture the score function of an underlying data distribution, a vector field pointing in the directions where the data density increases most rapidly. We present new techniques to improve the performance of score-based generative models, scaling them to high-resolution images that were previously out of reach. Without requiring adversarial training, they can produce sharp and diverse image samples that rival GANs.
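To make the idea of a score function concrete, here is a toy sketch that is not taken from this repository. It assumes a 1-D Gaussian whose score is known in closed form and draws samples with plain Langevin dynamics; the function names, step size, and number of steps are illustrative choices. The models in this repo instead learn the score with a neural network and sample with an annealed variant over several noise levels.

```python
import math
import random


def gaussian_score(x, mu=2.0, sigma=0.5):
    # Score of N(mu, sigma^2): d/dx log p(x) = -(x - mu) / sigma^2.
    # It points toward mu, i.e. toward higher density.
    return -(x - mu) / sigma ** 2


def langevin_sample(score_fn, n_steps=500, step_size=0.02, rng=random):
    # Start from an arbitrary point, then repeatedly take a small step
    # along the score and add Gaussian noise of matching scale.
    x = rng.uniform(-5.0, 5.0)
    for _ in range(n_steps):
        noise = rng.gauss(0.0, 1.0)
        x += 0.5 * step_size * score_fn(x) + math.sqrt(step_size) * noise
    return x


rng = random.Random(0)
samples = [langevin_sample(gaussian_score, rng=rng) for _ in range(2000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"mean={mean:.2f}, std={std:.2f}")  # should land near mu=2.0 and sigma=0.5
```

Only the score is used during sampling; the density itself is never evaluated, which is why a network that estimates the score is enough to generate data.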

samples

(From left to right: Our samples on FFHQ 256px, LSUN bedroom 128px, LSUN tower 128px, LSUN church_outdoor 96px, and CelebA 64px.)

Running Experiments

Dependencies

Run the following to install all Python packages required by our code.

pip install -r requirements.txt

Project structure

main.py is the file that you should run for both training and sampling. Execute python main.py --help to get its usage description:

usage: main.py [-h] --config CONFIG [--seed SEED] [--exp EXP] --doc DOC
               [--comment COMMENT] [--verbose VERBOSE] [--test] [--sample]
               [--fast_fid] [--resume_training] [-i IMAGE_FOLDER] [--ni]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Path to the config file
  --seed SEED           Random seed
  --exp EXP             Path for saving running related data.
  --doc DOC             A string for documentation purpose. Will be the name
                        of the log folder.
  --comment COMMENT     A string for experiment comment
  --verbose VERBOSE     Verbose level: info | debug | warning | critical
  --test                Whether to test the model
  --sample              Whether to produce samples from the model
  --fast_fid            Whether to do fast fid test
  --resume_training     Whether to resume training
  -i IMAGE_FOLDER, --image_folder IMAGE_FOLDER
                        The folder name of samples
  --ni                  No interaction. Suitable for Slurm Job launcher
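For readers who want to script around the command line, the help text above can be mirrored with a small argparse sketch. This is a reconstruction from the usage message only, not the code in main.py: the default values are placeholders chosen for the sketch, and only the flag names, help strings, and required/optional status come from the help output.

```python
import argparse


def build_parser():
    # Mirrors the usage text above; defaults are placeholders, not the repo's.
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--config", type=str, required=True, help="Path to the config file")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--exp", type=str, default="exp", help="Path for saving running related data.")
    parser.add_argument("--doc", type=str, required=True,
                        help="A string for documentation purpose. Will be the name of the log folder.")
    parser.add_argument("--comment", type=str, default="", help="A string for experiment comment")
    parser.add_argument("--verbose", type=str, default="info",
                        help="Verbose level: info | debug | warning | critical")
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument("--sample", action="store_true", help="Whether to produce samples from the model")
    parser.add_argument("--fast_fid", action="store_true", help="Whether to do fast fid test")
    parser.add_argument("--resume_training", action="store_true", help="Whether to resume training")
    parser.add_argument("-i", "--image_folder", type=str, default="images",
                        help="The folder name of samples")
    parser.add_argument("--ni", action="store_true",
                        help="No interaction. Suitable for Slurm Job launcher")
    return parser


args = build_parser().parse_args(["--config", "bedroom.yml", "--doc", "bedroom"])
print(args.config, args.doc, args.sample)  # bedroom.yml bedroom False
```

Note that the usage line lists --config and --doc without brackets, so both are required; a parser built this way exits with an error if either is missing.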

Configuration files are in config/. You don't need to include the prefix config/ when specifying --config. All files generated when running the code are under the directory specified by --exp. They are structured as:

<exp> # a folder named by the argument `--exp` given to main.py
├── datasets # all dataset files
├── logs # contains checkpoints and samples produced during training
│   └── <doc> # a folder named by the argument `--doc` specified to main.py
│      ├── checkpoint_x.pth # the checkpoint file saved at the x-th training iteration
│      ├── config.yml # the configuration file for training this model
│      ├── stdout.txt # all outputs to the console during training
│      └── samples # all samples produced during training
├── fid_samples # contains all samples generated for fast fid computation
│   └── <i> # a folder named by the argument `-i` specified to main.py
│      └── ckpt_x # a folder of image samples generated from checkpoint_x.pth
├── image_samples # contains generated samples
│   └── <i>
│       └── image_grid_x.png # samples generated from checkpoint_x.pth
└── tensorboard # tensorboard files for monitoring training
    └── <doc> # this is the log_dir of tensorboard
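The layout above can be turned into a small path helper when post-processing results. The sketch below is a convenience written for this README and is not part of the repository; the function names `experiment_paths` and `checkpoint_path` are hypothetical, and the folder names are copied from the tree.

```python
from pathlib import Path


def experiment_paths(exp, doc, image_folder):
    # Build the locations described in the directory tree above.
    root = Path(exp)
    log_dir = root / "logs" / doc
    return {
        "datasets": root / "datasets",
        "log_dir": log_dir,
        "config": log_dir / "config.yml",
        "stdout": log_dir / "stdout.txt",
        "train_samples": log_dir / "samples",
        "fid_samples": root / "fid_samples" / image_folder,
        "image_samples": root / "image_samples" / image_folder,
        "tensorboard": root / "tensorboard" / doc,
    }


def checkpoint_path(exp, doc, iteration):
    # checkpoint_x.pth is the checkpoint saved at the x-th training iteration.
    return Path(exp) / "logs" / doc / f"checkpoint_{iteration}.pth"


paths = experiment_paths("exp", "bedroom", "bedroom")
print(paths["tensorboard"].as_posix())
print(checkpoint_path("exp", "bedroom", 5000).as_posix())
```

The iteration number 5000 is only an example; which checkpoints exist depends on how long training ran and how often snapshots were saved.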

Training

For example, we can train an NCSNv2 on LSUN bedroom by running the following:

python main.py --config bedroom.yml --doc bedroom

Log files will be saved in <exp>/logs/bedroom.

Sampling

If we want to sample from NCSNv2 on LSUN bedroom, we can edit bedroom.yml to specify the ckpt_id under the sampling group, and then run the following:

python main.py --sample --config bedroom.yml --doc bedroom -i bedroom

Samples will be saved in <exp>/image_samples/bedroom.

We can interpolate between different samples (see the paper for details): set interpolation to true and choose an appropriate n_interpolations under the sampling group in bedroom.yml. Other tasks such as inpainting are also available; their usage can be worked out from the code and configuration files.

Computing FID values quickly for a range of checkpoints

We can specify begin_ckpt and end_ckpt under the fast_fid group in the configuration file. For example, by running the following command, we can generate a small number of samples per checkpoint within the range begin_ckpt-end_ckpt for a quick (and rough) FID evaluation.

python main.py --fast_fid --config bedroom.yml --doc bedroom -i bedroom

You can find samples in <exp>/fid_samples/bedroom.
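As a rough illustration of which files a fast FID sweep touches, the sketch below lists the checkpoint and output folder for each step in a range. It is not the repository's implementation: the helper name `fast_fid_plan` is hypothetical, the `step` argument (the spacing between saved checkpoints) is an assumption, the range is assumed to include end_ckpt, and the numbers are made up for the example.

```python
from pathlib import Path


def fast_fid_plan(begin_ckpt, end_ckpt, step, exp, doc, image_folder):
    """List (checkpoint file, sample folder) pairs for a fast FID sweep."""
    root = Path(exp)
    plan = []
    # Assumption: the range is inclusive and checkpoints are `step` iterations apart.
    for x in range(begin_ckpt, end_ckpt + 1, step):
        ckpt = root / "logs" / doc / f"checkpoint_{x}.pth"
        out_dir = root / "fid_samples" / image_folder / f"ckpt_{x}"
        plan.append((ckpt, out_dir))
    return plan


plan = fast_fid_plan(5000, 15000, 5000, exp="exp", doc="bedroom", image_folder="bedroom")
for ckpt, out_dir in plan:
    print(ckpt.as_posix(), "->", out_dir.as_posix())
```

Checking that every listed checkpoint file exists before launching the sweep can save a failed run partway through the range.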

Pretrained Checkpoints

Link: https://drive.google.com/drive/folders/1217uhIvLg9ZrYNKOR3XTRFSurt4miQrd?usp=sharing

You can use these checkpoints to produce samples on all datasets we tested in the paper. They assume the --exp argument is set to exp.

References

If you find the code/idea useful for your research, please consider citing

@inproceedings{song2020improved,
  author    = {Yang Song and Stefano Ermon},
  editor    = {Hugo Larochelle and
               Marc'Aurelio Ranzato and
               Raia Hadsell and
               Maria{-}Florina Balcan and
               Hsuan{-}Tien Lin},
  title     = {Improved Techniques for Training Score-Based Generative Models},
  booktitle = {Advances in Neural Information Processing Systems 33: Annual Conference
               on Neural Information Processing Systems 2020, NeurIPS 2020, December
               6-12, 2020, virtual},
  year      = {2020}
}

and/or our previous work

@inproceedings{song2019generative,
  title={Generative Modeling by Estimating Gradients of the Data Distribution},
  author={Song, Yang and Ermon, Stefano},
  booktitle={Advances in Neural Information Processing Systems},
  pages={11895--11907},
  year={2019}
}
