
Label-Efficient Semantic Segmentation with Diffusion Models

ICLR'2022 [Project page]

Official implementation of the paper "Label-Efficient Semantic Segmentation with Diffusion Models".

This code is based on datasetGAN and guided-diffusion.

Note: use the --recurse-submodules flag when cloning the repository (git clone --recurse-submodules).

 

Overview

The paper investigates the representations learned by state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple semantic segmentation approach that exploits these representations and outperforms the alternatives in the few-shot regime.
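In a nutshell, the method noises an image to a given diffusion step, runs it through the DDPM UNet, collects intermediate decoder activations at a few steps and blocks, upsamples them to the image resolution, and trains a per-pixel classifier on the concatenated features. Below is a minimal sketch, assuming a guided-diffusion-style unet/diffusion pair; the unet.decoder indexing is a hypothetical hook point, not the repo's exact API:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ddpm_pixel_features(unet, diffusion, x0,
                        blocks=(5, 6, 7, 8, 12), steps=(50, 150, 250)):
    """Per-pixel DDPM features for images x0 of shape (B, 3, H, W)."""
    grabbed = []
    hooks = [unet.decoder[b].register_forward_hook(      # hypothetical layout
                 lambda m, i, o: grabbed.append(o)) for b in blocks]
    feats = []
    for t in steps:
        grabbed.clear()
        t_b = torch.full((x0.size(0),), t, device=x0.device, dtype=torch.long)
        x_t = diffusion.q_sample(x0, t_b)   # noise the image up to step t
        unet(x_t, t_b)                      # forward pass fills `grabbed`
        feats += [F.interpolate(f.float(), x0.shape[-2:], mode="bilinear",
                                align_corners=False) for f in grabbed]
    for h in hooks:
        h.remove()
    return torch.cat(feats, dim=1)          # (B, C_total, H, W)
```

A pixel classifier (an MLP applied independently at every spatial location) is then trained on these features using the few annotated images.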

DDPM-based Segmentation

 

Updates

3/9/2022:

  1. Improved the performance of DDPM-based segmentation by changing:
      Diffusion steps: [50, 150, 250, 350] --> [50, 150, 250];
      UNet blocks: [6, 7, 8, 9] --> [5, 6, 7, 8, 12].
  2. Trained a slightly better DDPM on FFHQ-256.
  3. Added MAE for comparison.

 

Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-28, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The suffix denotes the number of semantic classes.

datasets.tar.gz (~47 MB)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. The FFHQ-256 model was trained by us using the same model parameters as the LSUN models.

LSUN-Bedroom: lsun_bedroom.pt
FFHQ-256: ffhq.pt (Updated 3/8/2022)
LSUN-Cat: lsun_cat.pt
LSUN-Horse: lsun_horse.pt

Run

  1. Download the datasets:
      bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

Note: train_interpreter.sh is memory-intensive, since it keeps all training pixel representations in RAM. For example, it requires ~210 GB for 50 training images of 256x256; a back-of-the-envelope estimate follows below. (See issue)
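The footprint is straightforward to estimate: one feature vector is stored per pixel of every training image. A quick check, assuming float32 and an illustrative total feature dimension of 8448 channels (the exact value depends on the chosen steps and blocks):

```python
n_images, h, w = 50, 256, 256
feat_dim = 8448                 # illustrative; depends on chosen steps/blocks
bytes_per_float = 4             # float32
gb = n_images * h * w * feat_dim * bytes_per_float / 2**30
print(f"~{gb:.0f} GB")          # ~103 GB; temporary copies during training
                                # can roughly double this, in line with ~210 GB
```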

Pretrained pixel classifiers and test predictions are here.

How to improve the performance

  • Tune which diffusion steps and UNet blocks to use for your particular task — see the illustrative snippet below.
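In practice, this choice boils down to two lists. The key names below are hypothetical; the authoritative ones live in experiments/<dataset_name>/ddpm.json:

```python
# Hypothetical config fragment mirroring the values from the Updates section.
cfg = {
    "steps":  [50, 150, 250],     # diffusion steps at which features are extracted
    "blocks": [5, 6, 7, 8, 12],   # UNet decoder blocks to read activations from
}
# Fewer steps/blocks -> smaller feature dimension and less RAM;
# more -> richer features at higher memory cost.
```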

 

DatasetDDPM

Synthetic datasets

To download the DDPM-produced synthetic datasets (50000 samples, ~7 GB; updated 3/8/2022):
bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>
  2. Check paths in experiments/<dataset_name>/datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>

  3. Check paths in experiments/<dataset_name>/datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh <dataset_name>

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh <dataset_name>
        Specify the hyperparameters in this script according to the available resources.
        On 8xA100 80 GB, generating 10000 samples takes about 12 hours.
        (A conceptual sketch of this step is given below, after the list of available names.)

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21
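Conceptually, step 5 samples images from the DDPM and pseudo-labels them with the trained interpreter. A minimal sketch, reusing the hypothetical ddpm_pixel_features helper from the Overview; p_sample_loop is the guided-diffusion sampling entry point:

```python
import torch

@torch.no_grad()
def generate_annotated_samples(unet, diffusion, classifier, n, shape=(3, 256, 256)):
    """Yield (image, mask) pairs for a synthetic segmentation dataset (sketch)."""
    for _ in range(n):
        img = diffusion.p_sample_loop(unet, (1, *shape))    # sample from the DDPM
        feats = ddpm_pixel_features(unet, diffusion, img)   # per-pixel features
        mask = classifier(feats).argmax(dim=1)              # (1, H, W) class ids
        yield img, mask
```

The resulting pairs form the synthetic dataset on which DeepLab is trained in step 6.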

 

MAE

Pretrained MAEs

We pretrain MAE models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: lsun_bedroom.pth
FFHQ-256: ffhq.pth
LSUN-Cat: lsun_cat.pth
LSUN-Horse: lsun_horse.pth

Training setups:

| Dataset      | Backbone | Epochs | Batch size | Mask ratio |
|--------------|----------|--------|------------|------------|
| LSUN Bedroom | ViT-L-8  | 150    | 1024       | 0.75       |
| LSUN Cat     | ViT-L-8  | 200    | 1024       | 0.75       |
| LSUN Horse   | ViT-L-8  | 200    | 1024       | 0.75       |
| FFHQ-256     | ViT-L-8  | 400    | 1024       | 0.75       |
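For MAE, the interpreter consumes dense ViT patch embeddings instead of UNet activations. A minimal sketch, assuming encoder is the MAE ViT applied with mask ratio 0 and returning (B, 1 + N, C) tokens with a leading CLS token (a hypothetical call signature, not the official MAE API):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def mae_pixel_features(encoder, x0):
    """Dense per-pixel features from an MAE encoder (sketch)."""
    tokens = encoder(x0)                   # hypothetical: (B, 1 + N, C)
    tokens = tokens[:, 1:]                 # drop the CLS token
    b, n, c = tokens.shape
    g = int(n ** 0.5)                      # patch grid side, e.g. 256 / 8 = 32
    grid = tokens.transpose(1, 2).reshape(b, c, g, g)
    return F.interpolate(grid, x0.shape[-2:], mode="bilinear",
                         align_corners=False)
```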

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the MAE checkpoint:
       bash checkpoints/mae/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/mae.json
  4. Run: bash scripts/mae/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

SwAV

Pretrained SwAVs

We pretrain SwAV models using the official implementation on the LSUN and FFHQ-256 datasets:

|        | LSUN-Bedroom | FFHQ-256 | LSUN-Cat | LSUN-Horse |
|--------|--------------|----------|----------|------------|
| SwAV   | SwAV         | SwAV     | SwAV     | SwAV       |
| SwAVw2 | SwAVw2       | SwAVw2   | SwAVw2   | SwAVw2     |

Training setups:

| Dataset  | Backbone | Epochs | Batch size | Multi-crop    | Num prototypes |
|----------|----------|--------|------------|---------------|----------------|
| LSUN     | RN50     | 200    | 1792       | 2x256 + 6x108 | 1000           |
| FFHQ-256 | RN50     | 400    | 2048       | 2x224 + 6x96  | 200            |
| LSUN     | RN50w2   | 200    | 1920       | 2x256 + 4x108 | 1000           |
| FFHQ-256 | RN50w2   | 400    | 2048       | 2x224 + 4x96  | 200            |

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/{swav|swav_w2}/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/{swav|swav_w2}.json
  4. Run: bash scripts/{swav|swav_w2}/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

DatasetGAN

Unlike the official implementation, we use more recent StyleGAN2(-ADA) models.

Synthetic datasets

To download GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>

Run

Since we almost fully adopt the official implementation, we don't provide our reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>
  2. Change paths in experiments/<dataset_name>/datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

Results

  • Performance in terms of mean IoU:
| Method        | Bedroom-28 | FFHQ-34    | Cat-15     | Horse-21   | CelebA-19  | ADE-Bedroom-30 |
|---------------|------------|------------|------------|------------|------------|----------------|
| ALAE          | 20.0 ± 1.0 | 48.1 ± 1.3 | --         | --         | 49.7 ± 0.7 | 15.0 ± 0.5     |
| VDVAE         | --         | 57.3 ± 1.1 | --         | --         | 54.1 ± 1.0 | --             |
| GAN Inversion | 13.9 ± 0.6 | 51.7 ± 0.8 | 21.4 ± 1.7 | 17.7 ± 0.4 | 51.5 ± 2.3 | 11.1 ± 0.2     |
| GAN Encoder   | 22.4 ± 1.6 | 53.9 ± 1.3 | 32.0 ± 1.8 | 26.7 ± 0.7 | 53.9 ± 0.8 | 15.7 ± 0.3     |
| SwAV          | 41.0 ± 2.3 | 54.7 ± 1.4 | 44.1 ± 2.1 | 51.7 ± 0.5 | 53.2 ± 1.0 | 30.3 ± 1.5     |
| SwAVw2        | 42.4 ± 1.7 | 56.9 ± 1.3 | 45.1 ± 2.1 | 54.0 ± 0.9 | 52.4 ± 1.3 | 30.6 ± 1.0     |
| MAE           | 45.0 ± 2.0 | 58.8 ± 1.1 | 52.4 ± 2.3 | 63.4 ± 1.4 | 57.8 ± 0.4 | 31.7 ± 1.8     |
| DatasetGAN    | 31.3 ± 2.7 | 57.0 ± 1.0 | 36.5 ± 2.3 | 45.4 ± 1.4 | --         | --             |
| DatasetDDPM   | 47.9 ± 2.9 | 56.0 ± 0.9 | 47.6 ± 1.5 | 60.8 ± 1.0 | --         | --             |
| DDPM          | 49.4 ± 1.9 | 59.1 ± 1.4 | 53.7 ± 3.3 | 65.0 ± 0.8 | 59.9 ± 1.0 | 34.6 ± 1.7     |
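For reference, mean IoU averages the per-class intersection-over-union between the predicted masks P_k and the ground-truth masks G_k over the K semantic classes:

```latex
\mathrm{mIoU}
  = \frac{1}{K} \sum_{k=1}^{K} \frac{|P_k \cap G_k|}{|P_k \cup G_k|}
  = \frac{1}{K} \sum_{k=1}^{K} \frac{\mathrm{TP}_k}{\mathrm{TP}_k + \mathrm{FP}_k + \mathrm{FN}_k}
```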

 

  • Examples of segmentation masks predicted by the DDPM-based method:
DDPM-based Segmentation

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
