
Fast Training of Diffusion Models with Masked Transformers

Official implementation of the paper Fast Training of Diffusion Models with Masked Transformers

Abstract: We propose an efficient approach to train large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches. Experiments on ImageNet-256x256 show that our approach achieves the same performance as the state-of-the-art Diffusion Transformer (DiT) model, using only 31% of its original training time. Thus, our method allows for efficient training of diffusion models without sacrificing the generative performance.
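
As an illustration of the masking step described above, the following minimal PyTorch sketch shows MAE-style random patch dropping at a 50% ratio. It is not the repository's implementation; the tensor shapes and function name are assumptions for illustration only.

import torch

def random_mask_patches(x, mask_ratio=0.5):
    # x: (B, N, D) patchified latent tokens (illustrative shapes).
    B, N, D = x.shape
    num_keep = int(N * (1 - mask_ratio))
    # Per-sample random permutation of patch indices.
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_keep = ids_shuffle[:, :num_keep]
    # The transformer encoder only sees the kept (unmasked) patches;
    # the lightweight decoder later operates on the full set of patches.
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return x_kept, ids_keep, ids_shuffle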

Architecture

Training efficiency

MaskDiT uses Automatic Mixed Precision (AMP) by default. We also include a MaskDiT variant trained without AMP (Ours_ft32) for reference.

Requirements

  • We recommend training MaskDiT on 8 A100 GPUs, which takes around 260 hours to perform 2M updates with a batch size of 1024.
  • At least one high-end GPU is needed for sampling.
  • A Dockerfile is provided to reproduce the exact software environment.

Prepare dataset

We first use a pre-trained VAE to encode the ImageNet dataset into latent space. You can download the pre-trained VAE with download_assets.py.

python3 download_assets.py --name vae --dest assets
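
For reference, encoding images into latents with a Stable Diffusion VAE looks roughly like the sketch below. The diffusers model name and the 0.18215 scaling factor are assumptions based on common DiT-style pipelines and may differ from the VAE fetched by download_assets.py.

import torch
from diffusers.models import AutoencoderKL

# Assumed VAE checkpoint; the one downloaded to assets/ may differ.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval().cuda()

@torch.no_grad()
def encode_to_latent(images):
    # images: (B, 3, 256, 256) tensors scaled to [-1, 1].
    posterior = vae.encode(images.cuda()).latent_dist
    return posterior.sample() * 0.18215  # (B, 4, 32, 32) latents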

You can also directly download the dataset we have prepared by running

python3 download_assets.py --name imagenet-latent-data --dest [destination directory]

Train

We first train MaskDiT with a 50% mask ratio and AMP enabled.

python3 train_latent.py --config configs/train/maskdit-latent-imagenet.yaml --num_process_per_node 8
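
For intuition, AMP training follows the standard torch.cuda.amp pattern. The snippet below is a self-contained toy sketch; the linear model and MSE loss are stand-ins, not MaskDiT's actual training loop.

import torch

model = torch.nn.Linear(4 * 32 * 32, 4 * 32 * 32).cuda()  # stand-in for MaskDiT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    latents = torch.randn(8, 4 * 32 * 32, device="cuda")   # dummy latent batch
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():                         # forward in mixed precision
        loss = torch.nn.functional.mse_loss(model(latents), latents)
    scaler.scale(loss).backward()                           # scaled backward pass
    scaler.step(optimizer)                                  # unscale, then optimizer step
    scaler.update()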

We then finetune the model with unmasking (mask ratio 0). For example,

python3 train_latent.py --config configs/finetune/maskdit-latent-imagenet-const.yaml --ckpt_path [path to checkpoint] --use_ckpt_path False --use_strict_load False --no_amp

Train on the original ImageNet

We also provide code in train.py for training MaskDiT without a pre-encoded dataset. This is provided for reference only and has not been fully tested. After preparing the original ImageNet dataset, run

python3 train.py --config configs/train/maskdit-imagenet.yaml --num_process_per_node 8

Generate samples

To generate samples from a provided checkpoint, run, for example,

python3 generate.py --config configs/train/maskdit-latent-imagenet.yaml --ckpt_path results/2075000.pt --class_idx 388 --cfg_scale 2.5
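
The --cfg_scale flag controls classifier-free guidance. Conceptually, the guided prediction combines conditional and unconditional model outputs as in the sketch below; the function and argument names are placeholders, not the repository's sampler API.

def guided_prediction(model, x_t, t, y, null_y, cfg_scale=2.5):
    # Classifier-free guidance: push the conditional prediction away
    # from the unconditional one by the guidance scale.
    pred_cond = model(x_t, t, y)         # class-conditional output
    pred_uncond = model(x_t, t, null_y)  # null-class (unconditional) output
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)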

Checkpoints of MaskDiT can be downloaded by running download_assets.py. For example,

python3 download_assets.py --name maskdit-finetune0 --dest results

We provide the following checkpoints.

Generated samples from MaskDiT. Upper panel: without CFG. Lower panel: with CFG (scale=1.5).

Evaluation

First, download the reference batch from the ADM repo directly. Alternatively, use download_assets.py by running

python3 download_assets.py --name imagenet256 --dest [destination directory]

Then use evaluator.py from the ADM repo, or fid.py from the EDM repo, to evaluate the generated samples.
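
Both evaluators report FID, which reduces to the Fréchet distance between Gaussians fitted to Inception features; a minimal sketch of that formula:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))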

Generative performance on ImageNet-256x256. The area of each bubble indicates the FLOPs for a single forward pass during training.

Acknowledgements

Thanks to the open-source codebases DiT, MAE, U-ViT, ADM, and EDM, on which our codebase is built.