Conditional Diffusion MNIST

script.py is a minimal, self-contained implementation of a conditional diffusion model. It learns to generate MNIST digits conditioned on a class label. The neural network architecture is a small U-Net. The code is modified from an excellent repo that does unconditional generation. The diffusion model is a Denoising Diffusion Probabilistic Model (DDPM).
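For reference, a DDPM trains the network to predict the noise that was added to a clean image at a random timestep. Below is a minimal sketch of one training step (class conditioning omitted for now); `eps_model` and `alphabar` are illustrative names, not the identifiers used in script.py.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of one DDPM training step (names are illustrative).
def ddpm_train_step(eps_model, x0, alphabar, T):
    # x0: clean images, shape (B, 1, 28, 28); alphabar: cumulative
    # product of (1 - beta_t), precomputed for t = 0..T.
    B = x0.shape[0]
    t = torch.randint(1, T + 1, (B,), device=x0.device)  # random timesteps
    eps = torch.randn_like(x0)                           # noise to predict
    ab = alphabar[t].view(B, 1, 1, 1)
    z_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps         # noised images
    eps_hat = eps_model(z_t, t / T)                      # predicted noise
    return F.mse_loss(eps_hat, eps)
```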

Samples generated from the model.

The conditioning roughly follows the method described in Classifier-Free Diffusion Guidance (also used in Imagen). The model fuses timestep embeddings $t_e$ and context embeddings $c_e$ into the U-Net activations $a_L$ at a given layer via

$a_{L+1} = c_e a_L + t_e.$

(Though in our experiments, we found that variants of this also work, e.g. concatenating the embeddings.)
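One way to realise this fusion is a small module that projects the normalised timestep and the one-hot class to per-channel embeddings, then scales the activations by $c_e$ and shifts them by $t_e$. This is a hypothetical sketch, assuming the embeddings broadcast over the spatial dimensions; it is not the exact module in script.py.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the fusion a_{L+1} = c_e * a_L + t_e.
class EmbedFuse(nn.Module):
    def __init__(self, emb_dim, n_channels, n_classes):
        super().__init__()
        # project timestep and one-hot class to per-channel embeddings
        self.t_embed = nn.Sequential(
            nn.Linear(1, emb_dim), nn.GELU(), nn.Linear(emb_dim, n_channels))
        self.c_embed = nn.Sequential(
            nn.Linear(n_classes, emb_dim), nn.GELU(), nn.Linear(emb_dim, n_channels))

    def forward(self, a_L, t, c_onehot):
        # a_L: activations (B, C, H, W); t: (B,) in [0, 1]; c_onehot: (B, n_classes)
        t_e = self.t_embed(t.view(-1, 1)).view(-1, a_L.shape[1], 1, 1)
        c_e = self.c_embed(c_onehot).view(-1, a_L.shape[1], 1, 1)
        return c_e * a_L + t_e  # scale by context, shift by timestep
```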

At training time, $c_e$ is randomly set to zero with probability $0.1$, so the model learns both unconditional generation (say $\psi(z_t)$ for noise $z_t$ at timestep $t$) and conditional generation (say $\psi(z_t, c)$ for context $c$). This matters because, at generation time, we choose a weight $w \geq 0$ that guides the model to generate examples via

$\hat{\epsilon}_{t} = (1+w)\psi(z_t, c) - w \psi(z_t).$
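In code, both pieces are short. A hedged sketch follows, with illustrative names (`drop_context`, `guided_eps`, `psi`) rather than script.py's own; a zeroed context stands in for the unconditional branch $\psi(z_t)$.

```python
import torch

# Training-time context dropout: zero the one-hot context with
# probability 0.1 so the model also learns unconditional generation.
def drop_context(c_onehot, p_drop=0.1):
    keep = (torch.rand(c_onehot.shape[0], 1, device=c_onehot.device) > p_drop)
    return c_onehot * keep.float()

# Sampling-time classifier-free guidance: psi is the noise predictor.
@torch.no_grad()
def guided_eps(psi, z_t, t, c_onehot, w):
    eps_cond = psi(z_t, t, c_onehot)                      # psi(z_t, c)
    eps_uncond = psi(z_t, t, torch.zeros_like(c_onehot))  # psi(z_t)
    return (1 + w) * eps_cond - w * eps_uncond
```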

Increasing $w$ produces images that are more typical but less diverse.

Samples produced with varying guidance strength, $w$.

Training the above models took around 20 epochs (~20 minutes).