Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models (PDAE)
This repository is the official PyTorch implementation of PDAE (NeurIPS 2022).
@inproceedings{zhang2022unsupervised,
title={Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models},
author={Zhang, Zijian and Zhao, Zhou and Lin, Zhijie},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
Dataset
We use the ready-to-use LMDB datasets provided by Diff-AE (https://github.com/phizaz/diffae#lmdb-datasets).
The directory structure should be:
data
├─horse
| ├─data.mdb
| └─lock.mdb
├─ffhq
| ├─data.mdb
| └─lock.mdb
├─celebahq
| ├─CelebAMask-HQ-attribute-anno.txt
| ├─data.mdb
| └─lock.mdb
├─celeba64
| ├─data.mdb
| └─lock.mdb
└─bedroom
  ├─data.mdb
  └─lock.mdb
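To catch a misplaced dataset before launching a long run, a small layout check can help. This is a minimal sketch that is not part of the repository: the `EXPECTED` mapping only mirrors the directory tree above, and the function name `missing_dataset_files` is hypothetical.

```python
from pathlib import Path

# Hypothetical mapping that mirrors the expected tree under ./data.
# celebahq additionally ships the attribute annotation file.
EXPECTED = {
    "horse": ["data.mdb", "lock.mdb"],
    "ffhq": ["data.mdb", "lock.mdb"],
    "celebahq": ["CelebAMask-HQ-attribute-anno.txt", "data.mdb", "lock.mdb"],
    "celeba64": ["data.mdb", "lock.mdb"],
    "bedroom": ["data.mdb", "lock.mdb"],
}


def missing_dataset_files(data_root, datasets=None):
    """Return the expected files under data_root that do not exist.

    `datasets` restricts the check to a subset (for example only the
    dataset you plan to train on); by default every entry is checked.
    Paths are returned relative to data_root, using forward slashes.
    """
    root = Path(data_root)
    names = datasets if datasets is not None else list(EXPECTED)
    missing = []
    for name in names:
        for filename in EXPECTED[name]:
            if not (root / name / filename).is_file():
                missing.append((Path(name) / filename).as_posix())
    return missing
```

Calling `missing_dataset_files("data")` from the project root returns an empty list when every dataset is in place; passing e.g. `["ffhq"]` checks only the datasets you actually need.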
Download
pre-trained-dpms (required)
trained-models (optional)
Put the downloaded folders in the root directory of this project and keep their directory structure exactly as shown in Google Drive.
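A quick sanity check for the downloads can be sketched as follows. It assumes the two folders keep the names listed above (`pre-trained-dpms`, `trained-models`) and sit directly in the project root; the helpers `check_downloads` and `require_downloads` are hypothetical and not part of the repository.

```python
from pathlib import Path

# Assumed folder names, taken from the Download list above.
REQUIRED_DIRS = ["pre-trained-dpms"]
OPTIONAL_DIRS = ["trained-models"]


def check_downloads(project_root="."):
    """Return (missing_required, missing_optional) folder names without raising."""
    root = Path(project_root)
    missing_required = [d for d in REQUIRED_DIRS if not (root / d).is_dir()]
    missing_optional = [d for d in OPTIONAL_DIRS if not (root / d).is_dir()]
    return missing_required, missing_optional


def require_downloads(project_root="."):
    """Raise FileNotFoundError if a required folder is absent.

    Returns the list of missing optional folders, which only matter if
    you want to use the released trained models.
    """
    missing_required, missing_optional = check_downloads(project_root)
    if missing_required:
        raise FileNotFoundError(
            "missing required download(s): " + ", ".join(missing_required)
        )
    return missing_optional
```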
Training
To train a DDPM, run this command:
cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_regular.py --world_size 4
To train PDAE, run this command:
cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_representation_learning.py --world_size 4
To train a classifier for manipulation, run this command:
cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_manipulation_diffusion.py --world_size 4
To train a latent DPM, run this command:
cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_latent_diffusion.py --world_size 4
The config file and run path can be changed inside the corresponding Python script.
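Each multi-GPU command pairs a device list in `CUDA_VISIBLE_DEVICES` with a matching `--world_size`. Assuming one process per visible GPU (which the commands suggest but this README does not state outright), a hypothetical helper `build_launch` can keep the two values in sync when launching from Python instead of the shell:

```python
import os
import sys


def build_launch(script, devices, extra_args=()):
    """Return (argv, env) equivalent to
    CUDA_VISIBLE_DEVICES=<devices> python3 <script> --world_size <len(devices)>.

    Assumption: one process per visible GPU, so --world_size is set to
    the number of entries in `devices`.
    """
    devices = [str(d) for d in devices]
    if not devices:
        raise ValueError("at least one GPU index is required")
    env = dict(os.environ)  # copy so the current process is not modified
    env["CUDA_VISIBLE_DEVICES"] = ",".join(devices)
    argv = [sys.executable, script, "--world_size", str(len(devices)), *extra_args]
    return argv, env
```

For example, `argv, env = build_launch("train_representation_learning.py", [0, 1, 2, 3])` followed by `subprocess.run(argv, env=env, cwd="trainer", check=True)` corresponds to `cd ./trainer` plus the PDAE training command. Single-GPU scripts such as `autoencoding_example.py` are run without `--world_size`, so this helper only fits the multi-GPU commands.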
Evaluation
For the autoencoding example, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 autoencoding_example.py
To evaluate autoencoding reconstruction, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 autoencoding_eval.py --world_size 4
PDAE achieves state-of-the-art autoencoding reconstruction performance, with an SSIM of 0.994 and an MSE of 3.84e-5, when using inferred
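The reported numbers depend on how the two metrics are computed. As a rough illustration only, the sketch below assumes pixel values scaled to [0, 1] (the README does not state the scale) and uses a simplified SSIM computed over the whole image as a single window; `autoencoding_eval.py` may compute both metrics differently, and the function names here are hypothetical.

```python
def mse(x, y):
    """Mean squared error between two equally sized flat pixel lists."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and the same length")
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def global_ssim(x, y, data_range=1.0):
    """Simplified SSIM treating the whole image as one window.

    Standard SSIM averages this expression over local (typically
    Gaussian-weighted 11x11) windows; this global version only shows the
    formula. data_range=1.0 assumes pixels in [0, 1].
    """
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("inputs must have the same length (at least 2)")
    c1 = (0.01 * data_range) ** 2  # stabilizes the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / (n - 1)
    var_y = sum((b - mu_y) ** 2 for b in y) / (n - 1)
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / (n - 1)
    return ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
```

Identical images give an MSE of 0 and an SSIM of 1; lower MSE and SSIM closer to 1 both mean a more faithful reconstruction.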
For one-step denoising, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 denoise_one_step.py
For the gap measurement, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 gap_measure.py --world_size 4
For interpolation, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 interpolation.py
For manipulation, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 manipulation.py
For unconditional sampling, run this command:
cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 unconditional_sample.py --world_size 4