DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models

[NEW!] DistriFusion is accepted by CVPR 2024 as a Highlight! Our code is publicly available!

We introduce DistriFusion, a training-free algorithm to harness multiple GPUs to accelerate diffusion model inference without sacrificing image quality. Naïve Patch (Overview (b)) suffers from the fragmentation issue due to the lack of patch interaction. The presented examples are generated with SDXL using a 50-step Euler sampler at 1280×1920 resolution, and latency is measured on A100 GPUs.

DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models
Muyang Li*, Tianle Cai*, Jiaxin Cao, Qinsheng Zhang, Han Cai, Junjie Bai, Yangqing Jia, Ming-Yu Liu, Kai Li, and Song Han
MIT, Princeton, Lepton AI, and NVIDIA
In CVPR 2024.

Overview

(a) Original diffusion model running on a single device. (b) Naïvely splitting the image into 2 patches across 2 GPUs has an evident seam at the boundary due to the absence of interaction across patches. (c) Our DistriFusion employs synchronous communication for patch interaction at the first step. After that, we reuse the activations from the previous step via asynchronous communication. In this way, the communication overhead can be hidden within the computation pipeline.
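To make the asynchronous reuse concrete, below is a minimal, self-contained sketch of the idea, not the actual distrifuser implementation: each rank keeps the remote patches' activations gathered during the previous step, stitches them with its fresh local activations, and launches an asynchronous all-gather whose result is consumed one step later. The helper name StaleActivationExchange and the patches-along-height layout are illustrative assumptions.

import torch
import torch.distributed as dist

class StaleActivationExchange:
    # Illustrative helper (not distrifuser code): overlap activation
    # exchange with computation by reusing one-step-stale activations.
    def __init__(self):
        self.buffer = None  # per-rank tensors from the in-flight gather
        self.handle = None  # pending asynchronous all-gather handle

    def exchange(self, local_act: torch.Tensor) -> torch.Tensor:
        world, rank = dist.get_world_size(), dist.get_rank()
        if self.handle is None:
            # First diffusion step: synchronous gather, no stale cache yet.
            self.buffer = [torch.empty_like(local_act) for _ in range(world)]
            dist.all_gather(self.buffer, local_act)
        else:
            # Wait for the gather launched during the previous step; it has
            # typically already finished underneath that step's computation.
            self.handle.wait()
        full = torch.cat(self.buffer, dim=-2)  # stitch patches along height
        # Overwrite the stale copy of this rank's own patch with fresh values.
        h = local_act.shape[-2]
        full[..., rank * h:(rank + 1) * h, :] = local_act
        # Launch the asynchronous gather consumed at the next step.
        self.buffer = [torch.empty_like(local_act) for _ in range(world)]
        self.handle = dist.all_gather(self.buffer, local_act, async_op=True)
        return full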

Performance

Speedups

Measured total latency of DistriFusion with SDXL using a 50-step DDIM sampler for generating a single image across multiple NVIDIA A100 GPUs. As the resolution scales up, the GPUs are better utilized. Remarkably, when generating 3840×3840 images, DistriFusion achieves 1.8×, 3.4×, and 6.1× speedups with 2, 4, and 8 A100s, respectively.

Quality

Qualitative results of SDXL. FID is computed against the ground-truth images. Our DistriFusion can reduce the latency according to the number of devices used while preserving visual fidelity.

References:

  • Denoising Diffusion Implicit Model (DDIM), Song et al., ICLR 2021
  • Elucidating the Design Space of Diffusion-Based Generative Models, Karras et al., NeurIPS 2022
  • Parallel Sampling of Diffusion Models, Shih et al., NeurIPS 2023
  • SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis, Podell et al., ICLR 2024

Prerequisites

  • Python 3
  • NVIDIA GPU + CUDA >= 12.0 and the corresponding cuDNN
  • PyTorch >= 2.2
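As a quick sanity check, the snippet below (a generic check, not part of this repository) verifies the PyTorch version and GPU visibility:

import torch

# Generic environment check; not part of the distrifuser codebase.
assert torch.cuda.is_available(), "A CUDA-capable NVIDIA GPU is required."
major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (2, 2), f"PyTorch >= 2.2 is required, found {torch.__version__}."
print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}, "
      f"{torch.cuda.device_count()} GPU(s) visible")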

Getting Started

Installation

After installing PyTorch, you should be able to install distrifuser from PyPI:

pip install distrifuser

or via GitHub:

pip install git+https://github.com/mit-han-lab/distrifuser.git

or locally for development:

git clone [email protected]:mit-han-lab/distrifuser.git
cd distrifuser
pip install -e .

Usage Example

In scripts/sdxl_example.py, we provide a minimal script for running SDXL with DistriFusion.

import torch

from distrifuser.pipelines import DistriSDXLPipeline
from distrifuser.utils import DistriConfig

# DistriFusion configuration: the output resolution and the number of additional warmup steps.
distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
pipeline = DistriSDXLPipeline.from_pretrained(
    distri_config=distri_config,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
)

# Only show the progress bar on rank 0 to avoid duplicated output.
pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    generator=torch.Generator(device="cuda").manual_seed(233),
).images[0]
if distri_config.rank == 0:
    image.save("astronaut.png")

Specifically, our distrifuser shares the same APIs as diffusers and can be used in a similar way. You only need to define a DistriConfig and use our wrapped DistriSDXLPipeline to load the pretrained SDXL model. Then, you can generate the image just like using the StableDiffusionXLPipeline in diffusers. The running command is

torchrun --nproc_per_node=$N_GPUS scripts/sdxl_example.py

where $N_GPUS is the number of GPUs you want to use.

Benchmark

Our benchmark results were measured with PyTorch 2.2 and diffusers 0.24.0. First, you may need to install some additional dependencies:

pip install git+https://github.com/zhijian-liu/torchprofile datasets torchmetrics dominate clean-fid

COCO Quality

You can use scripts/generate_coco.py to generate images with COCO captions. The command is

torchrun --nproc_per_node=$N_GPUS scripts/generate_coco.py --no_split_batch

where $N_GPUS is the number of GPUs you want to use. By default, the generated results will be stored in results/coco. You can also customize the location with --output_root. Some additional arguments that you may want to tune (a combined example follows the list):

  • --num_inference_steps: The number of inference steps (50 by default).
  • --guidance_scale: The classifier-free guidance scale (5 by default).
  • --scheduler: The diffusion sampler. We use the DDIM sampler by default; you can also use euler for the Euler sampler and dpm-solver for the DPM solver.
  • --warmup_steps: The number of additional warmup steps (4 by default).
  • --sync_mode: The GroupNorm synchronization mode. By default, we use our corrected asynchronous GroupNorm.
  • --parallelism: The parallelism paradigm. By default, it is patch parallelism; you can also use tensor for tensor parallelism and naive_patch for naïve patch.
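For instance, the following invocation combines several of these flags to generate images on 8 GPUs with the Euler sampler and a custom output folder (the folder name is only illustrative):

torchrun --nproc_per_node=8 scripts/generate_coco.py --no_split_batch --scheduler euler --output_root results/coco_euler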

After you generate all the images, you can use our script scripts/compute_metrics.py to calculate the PSNR, LPIPS, and FID. The usage is

python scripts/compute_metrics.py --input_root0 $IMAGE_ROOT0 --input_root1 $IMAGE_ROOT1

where $IMAGE_ROOT0 and $IMAGE_ROOT1 are the paths to the image folders you are trying to compare. If $IMAGE_ROOT0 is the ground-truth folder, please add the --is_gt flag for resizing. We also provide a script scripts/dump_coco.py to dump the ground-truth images.
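For example, to compare generated images against a ground-truth dump (both folder names here are only illustrative):

python scripts/compute_metrics.py --input_root0 coco_gt --input_root1 results/coco --is_gt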

Latency

You can use scripts/run_sdxl.py to benchmark the latency of our different methods. The command is

torchrun --nproc_per_node=$N_GPUS scripts/run_sdxl.py --mode benchmark --output_type latent

where $N_GPUS is the number of GPUs you want to use. Similar to scripts/generate_coco.py, you can also change some arguments (a combined example follows the list):

  • --num_inference_steps: The number of inference steps (50 by default).
  • --image_size: The generated image size (1024×1024 by default).
  • --no_split_batch: Disable batch splitting for classifier-free guidance.
  • --warmup_steps: The number of additional warmup steps (4 by default).
  • --sync_mode: The GroupNorm synchronization mode. By default, we use our corrected asynchronous GroupNorm.
  • --parallelism: The parallelism paradigm. By default, it is patch parallelism; you can also use tensor for tensor parallelism and naive_patch for naïve patch.
  • --warmup_times/--test_times: The number of warmup/test runs (5 and 20 by default, respectively).
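For instance, to benchmark the naïve patch baseline on 8 GPUs with the default settings:

torchrun --nproc_per_node=8 scripts/run_sdxl.py --mode benchmark --output_type latent --parallelism naive_patch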

Citation

If you use this code for your research, please cite our paper.

@inproceedings{li2023distrifusion,
  title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
  author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Liu, Ming-Yu and Li, Kai and Han, Song},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2024}
}

Acknowledgments

Our code is developed based on huggingface/diffusers and lmxyy/sige. We thank torchprofile for MACs measurement, clean-fid for FID computation, and Lightning-AI/torchmetrics for PSNR and LPIPS.

We thank Jun-Yan Zhu and Ligeng Zhu for their helpful discussion and valuable feedback. The project is supported by MIT-IBM Watson AI Lab, Amazon, MIT Science Hub, and National Science Foundation.
