P2 weighting (CVPR 2022)
This is the codebase for Perception Prioritized Training of Diffusion Models.
This repository is heavily based on openai/guided-diffusion.
P2 modifies the weighting scheme of the training objective to improve sample quality. It encourages the diffusion model to focus on recovering signals from highly corrupted data, where the model learns global and perceptually rich concepts. The figure below shows the weighting schemes as a function of the signal-to-noise ratio (SNR).
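A minimal sketch of the weighting, assuming the form described in the paper: the per-timestep loss is scaled by 1 / (k + SNR(t))^gamma, where SNR(t) = alpha_bar_t / (1 - alpha_bar_t) and alpha_bar_t is the cumulative product of (1 - beta_t). The linear beta schedule (1e-4 to 0.02 over 1000 steps) and the helper names below are assumptions for illustration, not code taken from this repository.

```python
import numpy as np


def linear_alphas_cumprod(num_steps: int = 1000) -> np.ndarray:
    """Cumulative product of (1 - beta_t) for a linear beta schedule.

    Assumption: betas rise linearly from 1e-4 to 0.02 (rescaled for
    num_steps), mirroring guided-diffusion's "linear" noise schedule.
    """
    scale = 1000 / num_steps
    betas = np.linspace(scale * 1e-4, scale * 0.02, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def snr(alphas_cumprod: np.ndarray) -> np.ndarray:
    """Signal-to-noise ratio SNR(t) = alpha_bar_t / (1 - alpha_bar_t)."""
    return alphas_cumprod / (1.0 - alphas_cumprod)


def p2_weight(snr_t: np.ndarray, gamma: float, k: float) -> np.ndarray:
    """P2 factor 1 / (k + SNR(t))^gamma; gamma = 0 gives uniform weights."""
    return 1.0 / (k + snr_t) ** gamma


alphas_cumprod = linear_alphas_cumprod(1000)
snr_t = snr(alphas_cumprod)
for gamma in (0.0, 0.5, 1.0):
    w = p2_weight(snr_t, gamma=gamma, k=1.0)
    # For gamma > 0, early steps (high SNR, little noise) get small weights,
    # while late steps (low SNR, heavy corruption) get weights close to 1.
    print(f"gamma={gamma}: w[0]={w[0]:.2e}, w[499]={w[499]:.3f}, w[999]={w[999]:.3f}")
```

With k = 1 and gamma = 1 the factor simplifies to 1 - alpha_bar_t, which makes it easy to see that the weight grows as the data becomes more corrupted.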
Pre-trained models
All models are trained at 256x256 resolution.
Models trained on FFHQ, CelebA-HQ, CUB, AFHQ-Dogs, Flowers, and MetFaces are available for download: drive
Requirements
We tested the code with PyTorch 1.7.1 on a single RTX 8000 GPU.
Sampling from pre-trained models
First, set the PYTHONPATH variable to point to the root of the repository. Do the same when training new models.
export PYTHONPATH=$PYTHONPATH:$(pwd)
Put model checkpoints into a folder named models/. Samples will be saved in samples/.
python scripts/image_sample.py --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 250 --model_path models/ffhq_p2.pt --sample_dir samples
The command above samples for 250 timesteps without DDIM. To sample with DDIM for 25 timesteps instead, replace --timestep_respacing 250 with --timestep_respacing ddim25 and add --use_ddim True.
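--timestep_respacing chooses which of the 1000 training steps are visited at sampling time, while --use_ddim chooses the update rule. The sketch below is a simplified, single-section re-implementation modeled on guided-diffusion's space_timesteps (guided_diffusion/respace.py); it is not this repository's code and omits comma-separated section lists.

```python
def space_timesteps(num_timesteps: int, section_counts: str) -> set:
    """Pick which of the original diffusion steps to keep when respacing.

    Simplified sketch: "ddimN" keeps steps at a fixed integer stride, while a
    plain number such as "250" keeps that many steps spread as evenly as
    possible over [0, num_timesteps - 1].
    """
    if section_counts.startswith("ddim"):
        desired = int(section_counts[len("ddim"):])
        # Find the smallest integer stride that yields exactly `desired` steps.
        for stride in range(1, num_timesteps):
            if len(range(0, num_timesteps, stride)) == desired:
                return set(range(0, num_timesteps, stride))
        raise ValueError(f"cannot create exactly {desired} steps with an integer stride")

    count = int(section_counts)
    if count > num_timesteps:
        raise ValueError(f"cannot select {count} steps out of {num_timesteps}")
    # Fractional stride so that both the first (0) and the last step are kept.
    frac_stride = 1 if count <= 1 else (num_timesteps - 1) / (count - 1)
    steps, cur = [], 0.0
    for _ in range(count):
        steps.append(round(cur))
        cur += frac_stride
    return set(steps)


ddim_steps = sorted(space_timesteps(1000, "ddim25"))
plain_steps = sorted(space_timesteps(1000, "250"))
print(len(ddim_steps), ddim_steps[:3], ddim_steps[-1])
print(len(plain_steps), plain_steps[:3], plain_steps[-1])
```

Note that "ddim25" uses an integer stride of 40, so it stops at step 960 rather than 999; the plain "250" setting spreads its steps across the full range.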
Training your own models
--p2_gamma and --p2_k are the two hyperparameters of P2 weighting. In the paper, we used --p2_gamma 0.5 --p2_k 1 and --p2_gamma 1 --p2_k 1.
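To illustrate how the two hyperparameters enter training, here is a sketch of a per-sample weighted noise-prediction loss in NumPy. It assumes that the P2 factor 1 / (k + SNR(t))^gamma multiplies only the simple epsilon-MSE term (any learned-variance term is ignored), a linear beta schedule, and a random stand-in for the model output; p2_weighted_mse is a hypothetical helper, not a function from this repository.

```python
import numpy as np


def p2_weighted_mse(eps, eps_pred, t, alphas_cumprod, gamma, k):
    """Per-sample epsilon-prediction MSE scaled by the P2 factor.

    Sketch only: assumes 1 / (k + SNR(t))^gamma multiplies the simple
    noise-prediction loss; gamma = 0 recovers the unweighted loss.
    """
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    weight = 1.0 / (k + snr[t]) ** gamma  # shape: (batch,)
    # Mean squared error over all non-batch axes, one value per sample.
    mse = ((eps - eps_pred) ** 2).reshape(eps.shape[0], -1).mean(axis=1)
    return weight * mse


rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)  # assumed linear schedule, 1000 steps
alphas_cumprod = np.cumprod(1.0 - betas)

batch = 4
eps = rng.standard_normal((batch, 3, 8, 8))
eps_pred = eps + 0.1 * rng.standard_normal(eps.shape)  # stand-in for a model
t = np.array([0, 250, 500, 999])

for gamma, k in [(0.5, 1.0), (1.0, 1.0)]:  # the two settings used in the paper
    losses = p2_weighted_mse(eps, eps_pred, t, alphas_cumprod, gamma, k)
    print(f"gamma={gamma}, k={k}:", np.round(losses, 5))
```

In a real training loop the per-sample losses would be averaged over the batch before backpropagation; a larger gamma shifts more of that average toward the noisy, low-SNR timesteps.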
Logs and model checkpoints will be saved in logs/. Set --data_dir to the directory containing your training images.
By default, we use a lightweight version of ADM (93M parameters, compared with over 500M for the original ADM). You can change the model configuration through the command-line flags.
python scripts/image_train.py --data_dir data/DATASET_NAME --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --lr 2e-5 --batch_size 8 --rescale_learned_sigmas True --p2_gamma 1 --p2_k 1 --log_dir logs
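To check which files --data_dir will pick up, the sketch below lists image files recursively. It assumes this repository keeps guided-diffusion's image loader, which searches subfolders and accepts .jpg, .jpeg, .png, and .gif files; list_image_files_recursively here is a hypothetical stand-alone re-implementation, and the temporary DATASET_NAME folder is only for demonstration.

```python
import os
import tempfile


def list_image_files_recursively(data_dir):
    """Return sorted image paths under data_dir, searching subfolders too.

    Assumption: the accepted extensions (.jpg, .jpeg, .png, .gif, any case)
    mirror guided-diffusion's image loader.
    """
    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        ext = entry.rsplit(".", 1)[-1].lower() if "." in entry else ""
        if os.path.isdir(full_path):
            results.extend(list_image_files_recursively(full_path))
        elif ext in ("jpg", "jpeg", "png", "gif"):
            results.append(full_path)
    return results


with tempfile.TemporaryDirectory() as root:
    data_dir = os.path.join(root, "DATASET_NAME")
    os.makedirs(os.path.join(data_dir, "part1"))
    # A few empty placeholder files; notes.txt should be skipped.
    for name in ("a.png", "b.JPG", os.path.join("part1", "c.jpeg"), "notes.txt"):
        open(os.path.join(data_dir, name), "w").close()
    found = [os.path.relpath(p, data_dir) for p in list_image_files_recursively(data_dir)]
    print(found)
```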