Repository Details

Official repo for Discriminator Guidance.

Refining Generative Process with Discriminator Guidance in Score-based Diffusion Models (DG) (ICML 2023 Oral)
Official PyTorch implementation of Discriminator Guidance.

Dongjun Kim *, Yeongmin Kim *, Se Jung Kwon, Wanmo Kang, and Il-Chul Moon
* Equal contribution

Paper: arXiv:2211.17091
See alsdudrla10/DG_imagenet for the ImageNet256 code release.

Overview

Teaser image

Step-by-Step Running of Discriminator Guidance

1) Prepare a pre-trained score network

  • Download edm-cifar10-32x32-uncond-vp.pkl from the EDM repository for the unconditional model.
  • Download edm-cifar10-32x32-cond-vp.pkl from the EDM repository for the conditional model.
  • Place the EDM checkpoints in the directory shown below.
${project_page}/DG/
├── checkpoints
│   ├── pretrained_score/edm-cifar10-32x32-uncond-vp.pkl
│   ├── pretrained_score/edm-cifar10-32x32-cond-vp.pkl
├── ...
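Before moving on, it can help to confirm that the checkpoints landed where the later commands expect them. The following is a minimal sketch, not part of the repository: the helper name `missing_files` and the constant `EXPECTED_CHECKPOINTS` are hypothetical, and the paths are simply the ones shown in the tree above.

```python
from pathlib import Path

# Relative paths copied from the directory tree above.
EXPECTED_CHECKPOINTS = [
    "checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl",
    "checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl",
]


def missing_files(root, relative_paths):
    """Return the entries of relative_paths that are not regular files under root.

    Hypothetical helper for illustration; the repository does not ship it.
    """
    root = Path(root)
    return [p for p in relative_paths if not (root / p).is_file()]
```

Calling `missing_files(".", EXPECTED_CHECKPOINTS)` from the DG directory should return an empty list once both files are in place. The same helper can be reused for the data, classifier, discriminator, and stats layouts in the later steps by passing a different list of paths.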

2) Generate fake samples

  • To draw 50k unconditional samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl --outdir=samples/cifar_uncond_vanilla --dg_weight_1st_order=0
  • To draw 50k conditional samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond_vanilla --dg_weight_1st_order=0

3) Prepare real data

${project_page}/DG/
├── data
│   ├── true_data.npz
│   ├── true_data_label.npz
├── ...
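The array names stored inside these .npz files are not documented here. Since an .npz file is a zip archive of .npy members, one way to list the names without loading the arrays is the standard-library sketch below. The function name `npz_array_names` is hypothetical, and nothing is assumed about which keys the repository's files actually contain.

```python
import zipfile


def npz_array_names(path):
    """List the array names stored in an .npz archive without loading any data.

    Each array in an .npz file is stored as a zip member named "<name>.npy".
    Hypothetical helper for illustration; not part of the repository.
    """
    suffix = ".npy"
    with zipfile.ZipFile(path) as archive:
        return [
            member[: -len(suffix)]
            for member in archive.namelist()
            if member.endswith(suffix)
        ]
```

For example, `npz_array_names("data/true_data.npz")` returns the same names that `numpy.load("data/true_data.npz").files` would report when NumPy is available.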

4) Prepare a pre-trained classifier

${project_page}/DG/
├── checkpoints
│   ├── ADM_classifier/32x32_classifier.pt
├── ...

5) Train a discriminator

${project_page}/DG/
├── checkpoints/discriminator
│   ├── cifar_uncond/discriminator_60.pt
│   ├── cifar_cond/discriminator_250.pt
├── ...
  • To train the unconditional discriminator from scratch, run:
python3 train.py
  • To train the conditional discriminator from scratch, run:
python3 train.py --savedir=/checkpoints/discriminator/cifar_cond --gendir=/samples/cifar_cond_vanilla --datadir=/data/true_data_label.npz --cond=1 
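The training command takes real data (`--datadir`) and generated samples (`--gendir`), so the discriminator is learning to tell the two apart. The sketch below illustrates that idea under an assumption: it uses plain binary cross-entropy with real samples labeled 1 and generated samples labeled 0. The function name `bce_loss` is hypothetical, and the repository's train.py may differ in details such as time-dependent weighting or feeding noised inputs through the pre-trained classifier.

```python
import math


def bce_loss(real_probs, fake_probs, eps=1e-12):
    """Mean binary cross-entropy over discriminator outputs in (0, 1).

    real_probs: outputs on real samples (target label 1).
    fake_probs: outputs on generated samples (target label 0).
    Illustrative only; this is an assumed objective, not the repository's code.
    """
    def clip(p):
        # Keep probabilities away from 0 and 1 so the logarithms stay finite.
        return min(max(p, eps), 1.0 - eps)

    real_terms = [-math.log(clip(p)) for p in real_probs]
    fake_terms = [-math.log(1.0 - clip(p)) for p in fake_probs]
    return (sum(real_terms) + sum(fake_terms)) / (len(real_terms) + len(fake_terms))
```

A discriminator that outputs 0.5 everywhere scores log 2 per sample, which is the chance-level baseline; lower values mean real and generated samples are being separated.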

6) Generate discriminator-guided samples

  • To generate 50k unconditional discriminator-guided samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl --outdir=samples/cifar_uncond
  • To generate 50k conditional discriminator-guided samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond --dg_weight_1st_order=1 --cond=1 --discriminator_ckpt=/checkpoints/discriminator/cifar_cond/discriminator_250.pt --boosting=1
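As described in the paper, discriminator guidance adds a correction to the pre-trained score: the gradient of the log density ratio log(d / (1 - d)), scaled by a weight. The one-dimensional toy below sketches that idea in pure Python. Everything in it is an assumption made for illustration: the logistic `toy_discriminator`, its parameters `a` and `b`, and the finite-difference gradient are hypothetical stand-ins, not the repository's network or sampler.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def log_density_ratio(d, eps=1e-12):
    """log(d / (1 - d)) for a discriminator output d, clipped to stay finite."""
    d = min(max(d, eps), 1.0 - eps)
    return math.log(d) - math.log(1.0 - d)


def toy_discriminator(x, a=2.0, b=-0.5):
    # Hypothetical 1-D logistic discriminator; the real model is a neural network.
    return sigmoid(a * x + b)


def guidance_term(x, h=1e-5):
    """Central finite-difference gradient of the log density ratio at x."""
    upper = log_density_ratio(toy_discriminator(x + h))
    lower = log_density_ratio(toy_discriminator(x - h))
    return (upper - lower) / (2.0 * h)


def guided_score(score, x, weight):
    """Pre-trained score plus the weighted discriminator correction."""
    return score + weight * guidance_term(x)
```

With `weight=0` the function returns the pre-trained score unchanged, which mirrors the vanilla runs in step 2 that pass `--dg_weight_1st_order=0`. For this logistic toy the log ratio equals `a * x + b`, so the correction is simply `a`.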

7) Evaluate FID

${project_page}/DG/
├── stats
│   ├── cifar10-32x32.npz
├── ...
  • Run:
python3 fid_npzs.py --ref=/stats/cifar10-32x32.npz --num_samples=50000 --images=/samples/cifar_uncond/
python3 fid_npzs.py --ref=/stats/cifar10-32x32.npz --num_samples=50000 --images=/samples/cifar_cond/
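For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images. The symbols below are notation chosen here rather than taken from the repository: \mu_r, \Sigma_r denote the mean and covariance of the reference statistics, and \mu_g, \Sigma_g those of the generated samples. The internal layout of cifar10-32x32.npz is not specified in this document.

```latex
\mathrm{FID} = \lVert \mu_g - \mu_r \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_g + \Sigma_r - 2 \left( \Sigma_g \Sigma_r \right)^{1/2} \right)
```

Lower values are better, and the score is zero when the two sets of statistics coincide.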

Experimental Results

EDM-G++

| FID-50k | CIFAR-10 | CIFAR-10 (conditional) | FFHQ64 |
|---------|----------|------------------------|--------|
| EDM     | 2.03     | 1.82                   | 2.39   |
| EDM-G++ | 1.77     | 1.64                   | 1.98   |

Other backbones

| FID-50k      | CIFAR-10 | CelebA64 |
|--------------|----------|----------|
| Backbone     | 2.10     | 1.90     |
| Backbone-G++ | 1.94     | 1.34     |

We use LSGM as the CIFAR-10 backbone and Soft-Truncation as the CelebA64 backbone.
See alsdudrla10/DG_imagenet for the results and released code on ImageNet256.

Samples from unconditional CIFAR-10

Teaser image

Samples from conditional CIFAR-10

Teaser image

Reference

If you find the code useful for your research, please consider citing:

@article{kim2022refining,
  title={Refining Generative Process with Discriminator Guidance in Score-based Diffusion Models},
  author={Kim, Dongjun and Kim, Yeongmin and Kang, Wanmo and Moon, Il-Chul},
  journal={arXiv preprint arXiv:2211.17091},
  year={2022}
}

This work builds heavily on the code from:

  • Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of diffusion-based generative models. arXiv preprint arXiv:2206.00364.
  • Dhariwal, P., & Nichol, A. (2021). Diffusion models beat GANs on image synthesis. Advances in Neural Information Processing Systems, 34, 8780-8794.
  • Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.