  • Language: Python
  • License: Apache License 2.0
  • Created about 1 year ago
  • Updated 7 months ago



The official implementation of "Relay Diffusion: Unifying diffusion process across resolutions for image synthesis" [ICLR 2024 Spotlight]

Relay Diffusion: Unifying diffusion process across resolutions for image synthesis
Official PyTorch Implementation

🎉 News! The RelayDiffusion paper has been accepted at ICLR 2024 (Spotlight)!

We propose the Relay Diffusion Model (RDM), an improved framework for diffusion-based image generation. RDM transforms a low-resolution image or noise into an equivalent high-resolution one via blurring diffusion and block noise. As a result, the diffusion process can continue seamlessly at a new resolution or in a new model, without restarting from pure noise or relying on low-resolution conditioning.
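The README does not specify the exact form of block noise. The sketch below makes an explicit assumption: block noise is a variance-preserving mix of i.i.d. Gaussian noise and Gaussian noise that is shared within each block × block tile. The function `block_noise` and its `mix` parameter are hypothetical. They are not the repository's implementation, which exposes a related strength through `--block-scale` and `s_block_second` whose exact meaning may differ.

```python
import math
import random


def block_noise(height, width, block, mix, rng=random):
    """Hypothetical block noise on a single-channel height x width grid.

    Mixes i.i.d. Gaussian noise with Gaussian noise that is constant inside
    each block x block tile. The weights sqrt(1 - mix**2) and mix keep the
    per-pixel variance at 1 because the two components are independent.
    """
    if height % block or width % block:
        raise ValueError("height and width must be divisible by block")
    if not 0.0 <= mix <= 1.0:
        raise ValueError("mix must lie in [0, 1]")
    # One shared Gaussian value per block x block tile.
    coarse = [[rng.gauss(0.0, 1.0) for _ in range(width // block)]
              for _ in range(height // block)]
    iid_weight = math.sqrt(1.0 - mix ** 2)
    return [[iid_weight * rng.gauss(0.0, 1.0) + mix * coarse[i // block][j // block]
             for j in range(width)]
            for i in range(height)]
```

The sketch rests on a simple observation. Averaging a k × k patch of i.i.d. noise shrinks its standard deviation by a factor of k, but the block component survives that averaging. Block noise therefore still corrupts coarse, low-frequency structure at high resolution.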

RDM achieves state-of-the-art FID on CelebA-HQ and state-of-the-art sFID on ImageNet-256 (FID=1.87)!

For a formal introduction, read our paper: Relay Diffusion: Unifying diffusion process across resolutions for image synthesis.

Setup

Environment

Clone the repository and set up the environment with:

git clone https://github.com/THUDM/RelayDiffusion.git
cd RelayDiffusion
conda env create -f environment.yml
conda activate rdm

We enable xformers.ops.memory_efficient_attention, which reduces training cost by about 15%. If you do not need it, you can remove xformers from environment.yml.

Linux servers with NVIDIA A100 GPUs are recommended. You can still run the inference and training scripts on less powerful GPUs by setting a smaller --batch-gpu (the per-GPU batch size).

Dataset

We preprocess datasets into the same format as EDM. For CelebA-HQ, follow Progressive Growing of GANs for Improved Quality, Stability, and Variation to construct the high-quality subset of CelebA. For ImageNet, download the data from the official site.

To convert the original data into a dataset archive ready for training at $64\times 64$ or $256\times 256$ resolution, run:

python dataset_tool.py \
    --source=/path/to/original/data \
    --dest=/path/to/output/data.zip \
    --transform=center-crop \
    --resolution=64x64  # or --resolution=256x256
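The sketch below shows the crop geometry behind `--transform=center-crop`. It assumes the transform follows EDM's `dataset_tool.py`, which takes the largest centered square and then resizes it to the target resolution. `center_crop_box` is a hypothetical helper.

```python
def center_crop_box(width, height):
    """Return the (left, top, right, bottom) box of the largest centered square.

    Hypothetical helper; the box order matches Pillow's Image.crop convention.
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side
```

For example, a 500 × 375 image is cropped to the box (62, 0, 437, 375) and then resized. With Pillow, you can pass the box to `Image.crop` and then resize the result to 64 × 64 or 256 × 256.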

Inference & Evaluation

Sample Generation

To generate samples from RDM models, run:

torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=both --outdir=/path/to/output/dir/ \
    --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt

To generate $N$ images, set --seed=[K]-[K+N-1] for a randomly picked $K$. To generate in parallel on multiple GPUs, set --nproc_per_node to the number of GPUs.
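RDM's `generate.py` builds on EDM's. The sketch below assumes it keeps EDM's behavior: a seed spec expands to one image per seed, and the seeds are cut into near-equal batches that are dealt round-robin to the GPUs. Both helper names are hypothetical.

```python
def parse_seed_range(spec):
    """Expand a seed spec such as '0-63' or '1,5,10-12' into a list of ints."""
    seeds = []
    for part in spec.split(","):
        if "-" in part:
            lo, hi = part.split("-")
            seeds.extend(range(int(lo), int(hi) + 1))
        else:
            seeds.append(int(part))
    return seeds


def split_seeds(seeds, world_size, max_batch_size):
    """Split seeds into per-rank batches, dealt round-robin over ranks (EDM-style)."""
    # Number of batches is a multiple of world_size, so every rank gets the same count.
    num_batches = ((len(seeds) - 1) // (max_batch_size * world_size) + 1) * world_size
    # Near-equal contiguous chunks: the first `extra` chunks get one more seed.
    base, extra = divmod(len(seeds), num_batches)
    batches, start = [], 0
    for i in range(num_batches):
        size = base + (1 if i < extra else 0)
        batches.append(seeds[start:start + size])
        start += size
    return [batches[rank::world_size] for rank in range(world_size)]
```

For instance, `split_seeds(parse_seed_range("0-9"), world_size=2, max_batch_size=4)` gives rank 0 the batches [0, 1, 2] and [6, 7], and rank 1 the batches [3, 4, 5] and [8, 9].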

To generate final samples from existing first-stage results using only the second-stage model, set --sampler_stages=second and pass the directory of first-stage results via --indir.

The sampler arguments for the first stage are:

  • num_steps_first: number of sampling steps.
  • sigma_min_first & sigma_max_first: lowest & highest noise level.
  • rho_first: time step exponent.
  • cfg_scale_first: scale of classifier-free guidance.
  • S_churn: stochasticity strength.
  • S_min & S_max: range of noise levels in which stochasticity (S_churn) is applied.
  • S_noise: noise inflation.
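The first-stage arguments mirror EDM's stochastic sampler. The sketch below assumes RDM keeps EDM's formulas, and the helper names are hypothetical. Under that assumption, `rho` controls how the `num_steps` noise levels are spaced between `sigma_max_first` and `sigma_min_first`. `S_churn`, `S_min`, `S_max` and `S_noise` decide how much fresh noise is injected before each step.

```python
import math


def edm_sigma_steps(num_steps, sigma_min, sigma_max, rho):
    """Noise levels from sigma_max to sigma_min, uniform in sigma**(1/rho), plus a final 0."""
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = [(inv_max + i / (num_steps - 1) * (inv_min - inv_max)) ** rho
              for i in range(num_steps)]
    return sigmas + [0.0]


def churn_gamma(sigma, num_steps, S_churn, S_min, S_max):
    """Relative increase of sigma before a step; zero outside [S_min, S_max]."""
    if S_min <= sigma <= S_max:
        return min(S_churn / num_steps, math.sqrt(2) - 1)
    return 0.0


def churn_noise_std(sigma, gamma, S_noise):
    """Std of the fresh noise added when raising sigma to sigma * (1 + gamma)."""
    sigma_hat = sigma * (1 + gamma)
    return math.sqrt(sigma_hat ** 2 - sigma ** 2) * S_noise
```

With EDM's defaults (sigma_min=0.002, sigma_max=80, rho=7), a larger `rho` concentrates more steps near low noise levels. Setting `S_churn=0` turns the sampler deterministic.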

The sampler arguments for the second stage are:

  • num_steps_second: number of sampling steps.
  • sigma_min_second & sigma_max_second: lowest & highest noise level.
  • blur_sigma_max_second: maximum sigma of blurring schedule.
  • rho_second: time step exponent.
  • cfg_scale_second: scale of classifier-free guidance.
  • up_scale_second: scale of upsampling.
  • truncation_sigma_second & truncation_t_second: truncation point of noise & time schedule.
  • s_block_second: strength of block noise addition.
  • s_noise_second: strength of stochasticity.
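Both stages expose a classifier-free guidance scale (`cfg_scale_first`, `cfg_scale_second`). The sketch below assumes the common convention in which scale 1 returns the conditional prediction unchanged and larger values extrapolate away from the unconditional one. Some codebases use `scale + 1` instead, so check `generate.py` before relying on this. `apply_cfg` is a hypothetical name.

```python
def apply_cfg(denoised_cond, denoised_uncond, scale):
    """Classifier-free guidance: move from the unconditional toward the conditional output.

    Works for floats as well as NumPy arrays or PyTorch tensors of matching shape.
    """
    return denoised_uncond + scale * (denoised_cond - denoised_uncond)
```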

Evaluation Metrics

We measure sample quality with Fréchet Inception Distance (FID), spatial FID (sFID), Inception Score (IS), precision, and recall. For sFID, IS, precision, and recall, we reimplement the calculation pipeline following ADM's TensorFlow formulation.
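For reference, FID fits a Gaussian (mean μ, covariance Σ) to the Inception activations of the reference images (r) and of the generated samples (g), then measures the distance between the two fits. Lower is better. Following ADM, sFID applies the same distance to spatial features from an intermediate Inception layer. This is why both metrics need only the two activation files:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```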

First, generate activation files from the samples and from the dataset:

torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/sample/dir/ --dest=eval-refs/activations_sample.npz --batch=64 # build sample activations
torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/path/to/dataset.zip --dest=eval-refs/activations_ref.npz --batch=64 # build reference activations

Then calculate the metrics from the pre-built activations:

torchrun --standalone --nproc_per_node=1 evaluate.py calc --batch=64 \
    --activations_sample=eval-refs/activations_sample.npz \
    --activations_ref=eval-refs/activations_ref.npz \
    -m fid -m sfid -m is -m pr  # keep only the metrics you need

Performance Reproduction

RDM achieves results competitive with previous state-of-the-art models:

Dataset    Resolution  Training Samples  FID   sFID  IS      Precision  Recall
CelebA-HQ  256x256     47M               3.15  -     -       0.77       0.55
ImageNet   256x256     1250M             1.87  3.97  278.75  0.81       0.59

We provide our best pre-trained RDM checkpoints, together with the sampler settings needed to reproduce these results:

  • CelebA-HQ $256\times 256$:

    Download the first-stage and second-stage checkpoints, place them in ckpts/, and generate samples and their activations with:

    torchrun --standalone --nproc_per_node=8 generate_celebahq.py --outdir=generations/celebahq_samples/ \
        --network_first=ckpts/celebahq_first_stage.pt \
        --network_second=ckpts/celebahq_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/celebahq_samples/ --dest=eval-refs/celebahq_act_sample.npz 
    

    Generate the reference activations from the CelebA-HQ zip, or download our version from here:

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/celebahq-256x256.zip --dest=eval-refs/celebahq_act_ref.npz 
    

    Calculate the metrics with:

    python evaluate.py calc -m fid -m pr \
        --activations_sample=eval-refs/celebahq_act_sample.npz \
        --activations_ref=eval-refs/celebahq_act_ref.npz
    
  • ImageNet $256\times 256$:

    Download the first-stage and second-stage checkpoints, place them in ckpts/, and generate samples and their activations with:

    torchrun --standalone --nproc_per_node=8 generate_imagenet.py --outdir=generations/imagenet_samples/ \
        --network_first=ckpts/imagenet_first_stage.pkl \
        --network_second=ckpts/imagenet_second_stage.pt
    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=generations/imagenet_samples/ --dest=eval-refs/imagenet_act_sample.npz 
    

    Generate the reference activations from the ImageNet zip:

    torchrun --standalone --nproc_per_node=1 evaluate.py activations \
        --data=datasets/imagenet-256x256.zip --dest=eval-refs/imagenet_act_ref.npz 
    

    Calculate FID, sFID, and IS with:

    python evaluate.py calc -m fid -m sfid -m is \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_ref.npz
    

    For precision and recall on ImageNet, we follow ADM in using 10,000 (10k) reference samples. You can download the activation data we produced from here, then run:

    python evaluate.py calc -m pr \
        --activations_sample=eval-refs/imagenet_act_sample.npz \
        --activations_ref=eval-refs/imagenet_act_1w_ref.npz
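Precision and recall here are k-nearest-neighbour manifold metrics, as in ADM's evaluator. The sketch below is a brute-force, pure-Python illustration that assumes k = 3, which is our understanding of ADM's default. It is not the repository's `evaluate.py`, and the function names are hypothetical. Precision is the fraction of generated points that fall inside the k-NN ball of at least one real point. Recall swaps the roles of the two sets.

```python
import math


def knn_radii(points, k):
    """Distance from each point to its k-th nearest neighbour, excluding itself."""
    radii = []
    for i, p in enumerate(points):
        dists = sorted(math.dist(p, q) for j, q in enumerate(points) if j != i)
        radii.append(dists[k - 1])
    return radii


def manifold_coverage(manifold, radii, queries):
    """Fraction of query points inside at least one k-NN ball of the manifold."""
    inside = sum(
        any(math.dist(q, m) <= r for m, r in zip(manifold, radii)) for q in queries
    )
    return inside / len(queries)


def precision_recall(real, fake, k=3):
    """k-NN precision (fake inside real manifold) and recall (real inside fake manifold)."""
    precision = manifold_coverage(real, knn_radii(real, k), fake)
    recall = manifold_coverage(fake, knn_radii(fake, k), real)
    return precision, recall
```

A real evaluator computes these distances in batches on feature vectors (for example on a GPU), because the brute-force loops above are quadratic in the number of samples.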
    

Training

You can follow EDM's instructions to train a new first-stage model (standard diffusion). For example, on ImageNet run:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-64x64.zip --eff-attn=True \
    --cond=1 --batch=4096 --batch-gpu=32 --lr=1e-4 --ema=50 --dropout=0.1 --fp16=1 --ls=25 \
    --arch=adm --precond=edm
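`train.py` follows EDM, where `--batch` is the total batch per optimizer step and `--batch-gpu` only limits how many samples each GPU processes at once. The sketch below assumes RDM keeps EDM's gradient-accumulation logic. Under that assumption, lowering `--batch-gpu` increases the number of accumulation rounds rather than changing the effective batch. `accumulation_rounds` is a hypothetical helper that reproduces the arithmetic.

```python
def accumulation_rounds(batch, batch_gpu, num_gpus):
    """Gradient-accumulation rounds per optimizer step (hypothetical, EDM-style).

    Each GPU handles batch // num_gpus samples per step, in chunks of batch_gpu.
    """
    if batch % num_gpus:
        raise ValueError("--batch must be divisible by the number of GPUs")
    per_gpu = batch // num_gpus
    if per_gpu % batch_gpu:
        raise ValueError("per-GPU batch must be divisible by --batch-gpu")
    return per_gpu // batch_gpu
```

For the first-stage ImageNet command above (8 GPUs, `--batch=4096`, `--batch-gpu=32`), this gives 16 rounds per optimizer step. With `--batch-gpu=8`, as in the second-stage command, it gives 64.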

To train a second-stage model (blurring diffusion), set --precond=blur along with the blurring-diffusion arguments:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/imagenet-256x256.zip --eff-attn=True \
    --cond=1 --batch=4096 --batch-gpu=8 --lr=1e-4 --dropout=0.1 --fp16=1 --ls=1 \
    --arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.93 --blur-sigma-max=3.0
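Blurring diffusion applies a blur in frequency space. The sketch below assumes the heat-dissipation formulation from the blurring-diffusion literature: DCT coefficient k is damped by exp(-λ_k σ²/2) with λ_k = (πk/n)², which amounts to a Gaussian blur of standard deviation σ with reflective boundaries. In that reading, `--blur-sigma-max` would bound σ, but RDM's exact schedule may differ. All names below are hypothetical, and the code is an unoptimized pure-Python illustration.

```python
import math


def dct(x):
    """Orthonormal DCT-II of a 1-D sequence."""
    n = len(x)
    return [math.sqrt((1 if k == 0 else 2) / n)
            * sum(x[i] * math.cos(math.pi * (i + 0.5) * k / n) for i in range(n))
            for k in range(n)]


def idct(c):
    """Inverse of dct() (orthonormal DCT-III)."""
    n = len(c)
    return [sum(math.sqrt((1 if k == 0 else 2) / n) * c[k]
                * math.cos(math.pi * (i + 0.5) * k / n) for k in range(n))
            for i in range(n)]


def dct_blur_1d(x, sigma):
    """Damp DCT frequency k by exp(-lambda_k * sigma**2 / 2), lambda_k = (pi * k / n)**2."""
    n = len(x)
    damped = [c * math.exp(-((math.pi * k / n) ** 2) * sigma ** 2 / 2)
              for k, c in enumerate(dct(x))]
    return idct(damped)


def dct_blur_2d(image, sigma):
    """Separable 2-D blur: blur every row, then every column."""
    rows = [dct_blur_1d(row, sigma) for row in image]
    cols = [dct_blur_1d(list(col), sigma) for col in zip(*rows)]
    return [list(r) for r in zip(*cols)]
```

The DC coefficient (k = 0) is never damped, so the blur preserves the mean of the signal while smoothing out detail. This is why a blurred high-resolution image can carry the same coarse content as its low-resolution counterpart.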

For CelebA-HQ, train a first-stage model with:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-64x64.zip --eff-attn=True \
    --cond=0 --batch=1024 --batch-gpu=32 --lr=1e-4 --dropout=0.15 --augment=0.2 --ls=1 \
    --arch=adm --precond=edm

To train the second-stage model:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --data=datasets/CelebA-HQ-256x256.zip --eff-attn=True \
    --cond=0 --batch=1024 --batch-gpu=8 --lr=1e-4 --dropout=0.2 --augment=0.2 --fp16=1 --ls=1 \
    --arch=adm --precond=blur --up-scale=4 --block-scale=0.15 --prob-length=0.89 --blur-sigma-max=2.0

Citation

@article{teng2023relay,
  title={Relay Diffusion: Unifying diffusion process across resolutions for image synthesis},
  author={Teng, Jiayan and Zheng, Wendi and Ding, Ming and Hong, Wenyi and Wangni, Jianqiao and Yang, Zhuoyi and Tang, Jie},
  journal={arXiv preprint arXiv:2309.03350},
  year={2023}
}

Acknowledgements

This implementation is based on https://github.com/NVlabs/edm (the EDM codebase). Thanks a lot!
