Score-Based Generative Modeling with Critically-Damped Langevin Diffusion

ICLR 2022 (spotlight)

Tim Dockhorn   ·   Arash Vahdat   ·   Karsten Kreis

Paper   Project Page


Requirements

CLD-SGM is built with Python 3.8.0, PyTorch 1.8.1, and CUDA 11.1. Please use the following commands to install the requirements:

pip install --upgrade pip
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html -f https://storage.googleapis.com/jax-releases/jax_releases.html

Optionally, you may also install NVIDIA Apex. The Adam optimizer from this library is faster than PyTorch's native Adam.
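How the training script decides whether to use Apex is not described here. The sketch below only illustrates one common pattern: prefer Apex's fused Adam (apex.optimizers.FusedAdam) when the apex package is importable, otherwise fall back to PyTorch's native Adam. The helper name adam_backend is hypothetical, and it only reports which class would be chosen, so it runs without PyTorch installed.

```python
import importlib.util


def adam_backend() -> str:
    """Hypothetical helper: return the dotted path of the Adam class to use.

    Prefers NVIDIA Apex's FusedAdam when a top-level `apex` package can be
    found, otherwise falls back to PyTorch's native Adam.
    """
    if importlib.util.find_spec("apex") is not None:
        return "apex.optimizers.FusedAdam"
    return "torch.optim.Adam"


if __name__ == "__main__":
    print("Adam implementation:", adam_backend())
```

In real training code the same check is usually written as a try/except ImportError around `from apex.optimizers import FusedAdam`.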

Preparations

CIFAR-10 does not require any data preparation as the data will be downloaded directly. To download CelebA-HQ-256 and prepare the dataset for training models, please run the following lines:

mkdir -p data/celeba/
wget -P data/celeba/ https://openaipublic.azureedge.net/glow-demo/data/celeba-tfr.tar
tar -xvf data/celeba/celeba-tfr.tar -C data/celeba/
python util/convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=data/celeba/celeba-tfr --lmdb_path=data/celeba/celeba-lmdb --split=train
python util/convert_tfrecord_to_lmdb.py --dataset=celeba --tfr_path=data/celeba/celeba-tfr --lmdb_path=data/celeba/celeba-lmdb --split=validation

For multi-node training, the following environment variables need to be set on each node: $IP_ADDR, the IP address of the machine that will host the process with rank 0 during training (see here), and $NODE_RANK, the index of the current node among all the nodes.
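The multi-node commands in this README launch one process per node (-npernode 1) with --n_gpus_per_node 8. Assuming each node then runs one worker per GPU, the usual torch.distributed convention derives each worker's global rank from $NODE_RANK and its local GPU index, with all workers meeting at $IP_ADDR. This is a sketch of that arithmetic under those assumptions, not the repository's code; the function names and the example port are hypothetical.

```python
import os
from typing import Dict


def distributed_layout(node_rank: int, n_nodes: int, n_gpus_per_node: int,
                       local_rank: int) -> Dict[str, int]:
    """Global rank and world size of one GPU worker (common torch.distributed convention)."""
    if not 0 <= node_rank < n_nodes:
        raise ValueError("node_rank must be in [0, n_nodes)")
    if not 0 <= local_rank < n_gpus_per_node:
        raise ValueError("local_rank must be in [0, n_gpus_per_node)")
    return {
        "global_rank": node_rank * n_gpus_per_node + local_rank,
        "world_size": n_nodes * n_gpus_per_node,
    }


def init_method_from_env(port: int) -> str:
    """TCP rendezvous URL built from $IP_ADDR (the host of the rank-0 process)."""
    return f"tcp://{os.environ.get('IP_ADDR', '127.0.0.1')}:{port}"


if __name__ == "__main__":
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    for local_rank in range(8):  # --n_gpus_per_node 8
        print(distributed_layout(node_rank, n_nodes=2, n_gpus_per_node=8, local_rank=local_rank))
    print(init_method_from_env(port=6020))  # arbitrary example port (assumption)
```

With two nodes of 8 GPUs, node 0 hosts global ranks 0 to 7 and node 1 hosts ranks 8 to 15, which is why $IP_ADDR must point at node 0.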

Checkpoints

We provide pre-trained CLD-SGM checkpoints for CIFAR-10 and CelebA-HQ-256 here.

Training and evaluation

CIFAR-10
  • Training our CIFAR-10 model on a single node with one GPU and batch size 64:
python main.py -cc configs/default_cifar10.txt -sc configs/specific_cifar10.txt --root $ROOT --mode train --workdir work_dir/cifar10 --n_gpus_per_node 1 --training_batch_size 64 --testing_batch_size 64 --sampling_batch_size 64

Flags not shown in the command are set in the config files configs/default_cifar10.txt and configs/specific_cifar10.txt. The flags --training_batch_size and --testing_batch_size indicate the batch size per GPU, whereas --sampling_batch_size indicates the total batch size of all GPUs combined; for example, the two-node command below uses --training_batch_size 8 on 2 × 8 = 16 GPUs, i.e., a total training batch size of 128. The script updates a running checkpoint every --snapshot_freq iterations (in this case saved at work_dir/cifar10/checkpoints/checkpoint.pth), starting from --snapshot_threshold. In configs/specific_cifar10.txt, these values are set to 10000 and 1, respectively.
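The snapshot schedule described above can be read as "overwrite checkpoint.pth at every multiple of --snapshot_freq once the iteration count has reached --snapshot_threshold". The sketch below encodes that reading with the CIFAR-10 values (10000 and 1); the function names are hypothetical and the exact condition in the training loop may differ.

```python
from typing import List


def should_snapshot(step: int, snapshot_freq: int, snapshot_threshold: int) -> bool:
    """One reading of the flags: snapshot at multiples of snapshot_freq,
    but only once step has reached snapshot_threshold."""
    return step >= snapshot_threshold and step % snapshot_freq == 0


def snapshot_steps(n_iters: int, snapshot_freq: int, snapshot_threshold: int) -> List[int]:
    """All iterations in 1..n_iters at which checkpoint.pth would be overwritten."""
    return [s for s in range(1, n_iters + 1)
            if should_snapshot(s, snapshot_freq, snapshot_threshold)]


if __name__ == "__main__":
    # Values from configs/specific_cifar10.txt: --snapshot_freq 10000, --snapshot_threshold 1
    print(snapshot_steps(50000, snapshot_freq=10000, snapshot_threshold=1))
```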

  • Training our CIFAR-10 model on two nodes with 8 GPUs each and batch size 128:
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c 'python main.py -cc configs/default_cifar10.txt -sc configs/specific_cifar10.txt --root $ROOT --mode train --workdir work_dir/cifar10 --n_gpus_per_node 8 --training_batch_size 8 --testing_batch_size 8 --sampling_batch_size 128 --node_rank $NODE_RANK --n_nodes 2 --master_address $IP_ADDR'
  • To resume training, we simply change the mode from train to continue (two nodes of 8 GPUs):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c 'python main.py -cc configs/default_cifar10.txt -sc configs/specific_cifar10.txt --root $ROOT --mode continue --workdir work_dir/cifar10 --n_gpus_per_node 8 --training_batch_size 8 --testing_batch_size 8 --sampling_batch_size 128 --cont_nbr 1 --node_rank $NODE_RANK --n_nodes 2 --master_address $IP_ADDR'

Any file within work_dir/cifar10/checkpoints/ can be used to resume training by setting --checkpoint to the particular file name. If --checkpoint is unspecified, the script automatically uses the last snapshot checkpoint (checkpoint.pth) to continue training. The flag --cont_nbr makes sure that a new random seed is used for training continuation; for additional continuation runs --cont_nbr may be incremented by one.
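The resume rules above amount to a small lookup: an explicit --checkpoint file name wins, otherwise checkpoint.pth is used. The sketch below mirrors that; continuation_seed is purely hypothetical, since how --cont_nbr is turned into a new seed is not documented here, and offsetting a base seed is just one simple way to get a different seed per continuation run.

```python
from pathlib import Path
from typing import Optional


def resume_checkpoint_path(checkpoint_dir: Path, checkpoint: Optional[str] = None) -> Path:
    """File to resume from: the explicit --checkpoint name if given,
    otherwise the running snapshot checkpoint.pth."""
    return checkpoint_dir / (checkpoint if checkpoint is not None else "checkpoint.pth")


def continuation_seed(base_seed: int, cont_nbr: int) -> int:
    """Hypothetical seed derivation: a distinct seed for each continuation run."""
    return base_seed + cont_nbr


if __name__ == "__main__":
    path = resume_checkpoint_path(Path("work_dir/cifar10/checkpoints"))
    if not path.is_file():
        print(f"warning: {path} does not exist yet")
    print(path, continuation_seed(base_seed=0, cont_nbr=1))
```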

  • The following command can be used to evaluate the negative ELBO as well as the FID score (two nodes of 8 GPUs):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c 'python main.py -cc configs/default_cifar10.txt -sc configs/specific_cifar10.txt --root $ROOT --mode eval --workdir work_dir/cifar10 --n_gpus_per_node 8 --training_batch_size 8 --testing_batch_size 8 --sampling_batch_size 128 --eval_folder eval_elbo_and_fid --ckpt_file checkpoint_file --eval_likelihood --eval_fid --node_rank $NODE_RANK --n_nodes 2 --master_address $IP_ADDR'

Before running this, you need to download the FID stats file from here and place it into $ROOT/assets/stats/.

To evaluate our provided CIFAR-10 model, download the checkpoint here, create the directory work_dir/cifar10_pretrained_seed_0/checkpoints, place the checkpoint in it, and set --ckpt_file checkpoint_800000.pth as well as --workdir work_dir/cifar10_pretrained.
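The _seed_0 suffix in the directory above (and in the TensorBoard path further down) suggests that the run directory is built from --workdir plus the seed, so work_dir/cifar10_pretrained_seed_0 corresponds to a --workdir of work_dir/cifar10_pretrained. Assuming that layout, and that --workdir is resolved relative to $ROOT (both inferred from the paths in this README, not from the code), the folder can be prepared as follows; the helper names are hypothetical.

```python
import os
from pathlib import Path


def seeded_checkpoint_dir(root: Path, workdir: str, seed: int = 0) -> Path:
    """Assumed layout: <root>/<workdir>_seed_<seed>/checkpoints."""
    return root / f"{workdir}_seed_{seed}" / "checkpoints"


def prepare_pretrained_dir(root: Path, workdir: str = "work_dir/cifar10_pretrained") -> Path:
    """Create the folder that should hold the downloaded checkpoint_800000.pth."""
    ckpt_dir = seeded_checkpoint_dir(root, workdir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    return ckpt_dir


if __name__ == "__main__":
    root = Path(os.environ.get("ROOT", "."))
    print("Place checkpoint_800000.pth in", seeded_checkpoint_dir(root, "work_dir/cifar10_pretrained"))
```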

CelebA-HQ-256
  • Training the CelebA-HQ-256 model from our paper (two nodes of 8 GPUs and batch size 64):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c 'python main.py -cc configs/default_celeba_paper.txt -sc configs/specific_celeba_paper.txt --root $ROOT --mode train --workdir work_dir/celeba_paper --n_gpus_per_node 8 --training_batch_size 4 --testing_batch_size 4 --sampling_batch_size 64 --data_location data/celeba/celeba-lmdb/ --node_rank $NODE_RANK --n_nodes 2 --master_address $IP_ADDR'

We found that training the above model can be unstable. Some modifications that we found after publication lead to better numerical stability as well as improved performance:

mpirun --allow-run-as-root -np 2 -npernode 1 bash -c 'python main.py -cc configs/default_celeba_post_paper.txt -sc configs/specific_celeba_post_paper.txt --root $ROOT --mode train --workdir work_dir/celeba_post_paper --n_gpus_per_node 8 --training_batch_size 4 --testing_batch_size 4 --sampling_batch_size 64 --data_location data/celeba/celeba-lmdb/ --node_rank $NODE_RANK --n_nodes 2 --master_address $IP_ADDR'

In contrast to the model reported in our paper, we make use of a non-constant time reparameterization function β(t). For more details, please check the config files.
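For context, the CLD forward SDE from the paper can be written with a time-dependent β(t) in place of its constant β, where x_t is the data, v_t the velocity, M the mass, and Γ the friction, with critical damping Γ² = 4M. Substituting τ(t) = ∫₀ᵗ β(s) ds turns it into the same process with β = 1, which is why β(t) acts as a reparameterization of time. This is a summary of the formulation as we understand it from the paper, not a description of the post-paper config values.

```latex
\begin{aligned}
\mathrm{d}x_t &= M^{-1} v_t\,\beta(t)\,\mathrm{d}t, \\
\mathrm{d}v_t &= -x_t\,\beta(t)\,\mathrm{d}t \;-\; \Gamma M^{-1} v_t\,\beta(t)\,\mathrm{d}t \;+\; \sqrt{2\Gamma\beta(t)}\,\mathrm{d}w_t,
\qquad \Gamma^2 = 4M, \\[4pt]
\tau(t) &= \int_0^t \beta(s)\,\mathrm{d}s
\;\;\Longrightarrow\;\;
\mathrm{d}x = M^{-1} v\,\mathrm{d}\tau, \quad
\mathrm{d}v = -x\,\mathrm{d}\tau - \Gamma M^{-1} v\,\mathrm{d}\tau + \sqrt{2\Gamma}\,\mathrm{d}w_\tau .
\end{aligned}
```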

Toy data
  • Training on the multimodal Swiss Roll dataset using a single node with one GPU and batch size 512:
python main.py -cc configs/default_toy_data.txt --root $ROOT --mode train --workdir work_dir/multi_swiss_roll --n_gpus_per_node 1 --training_batch_size 512 --testing_batch_size 512 --sampling_batch_size 512 --dataset multimodal_swissroll

Additional toy datasets can be implemented in util/toy_data.py.
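The interface expected by util/toy_data.py is not described here, so the following is a hypothetical, standalone sketch of what a 2-D toy sampler might look like: a noisy Swiss roll using the common parametrization x = t·cos t, y = t·sin t. It is not the repository's multimodal_swissroll implementation; the function name and the noise default are illustrative assumptions.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_swiss_roll(n: int, noise: float = 0.5,
                      rng: Optional[random.Random] = None) -> List[Tuple[float, float]]:
    """Draw n points from a 2-D Swiss roll with isotropic Gaussian noise."""
    if rng is None:
        rng = random.Random()
    points = []
    for _ in range(n):
        t = 1.5 * math.pi * (1.0 + 2.0 * rng.random())  # angle in [1.5*pi, 4.5*pi)
        x = t * math.cos(t) + rng.gauss(0.0, noise)
        y = t * math.sin(t) + rng.gauss(0.0, noise)
        points.append((x, y))
    return points


if __name__ == "__main__":
    batch = sample_swiss_roll(512, rng=random.Random(0))  # same size as --training_batch_size 512
    print(len(batch), batch[0])
```

Passing an explicit random.Random makes batches reproducible, which is convenient when comparing runs on toy data.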

Monitoring the training process

We use TensorBoard to monitor training progress. For example, the CIFAR-10 training run can be monitored as follows:

tensorboard --logdir work_dir/cifar10_seed_0/tensorboard

Demonstration

Load our pretrained checkpoints and play with sampling and likelihood computation:

  • CIFAR-10 (Open In Colab)
  • CelebA-HQ-256 (Open In Colab)

Citation

If you find the code useful for your research, please consider citing our ICLR paper:

@inproceedings{dockhorn2022score,
  title={Score-Based Generative Modeling with Critically-Damped Langevin Diffusion},
  author={Tim Dockhorn and Arash Vahdat and Karsten Kreis},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2022}
}

License

Copyright © 2022, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License. Please see our main LICENSE file.

License Dependencies

For any code dependencies related to StyleGAN2, the license is the Nvidia Source Code License-NC by NVIDIA Corporation, see StyleGAN2 LICENSE.

This code is built on the excellent ScoreSDE codebase by Song et al., which can be found here. For any code dependencies related to ScoreSDE, the license is the Apache License 2.0, see ScoreSDE LICENSE.
