
Label-Efficient Semantic Segmentation with Diffusion Models

ICLR'2022 [Project page]

Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

This code is based on datasetGAN and guided-diffusion.

Note: use --recurse-submodules when cloning the repository.
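For example (a generic git invocation; the repository URL is left as a placeholder):

    git clone --recurse-submodules <repository_url>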

 

Overview

The paper investigates the representations learned by state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple semantic segmentation approach that exploits these representations and outperforms the alternatives in the few-shot regime.

[Figure: overview of the DDPM-based segmentation approach]
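To make the pipeline concrete, below is a minimal, self-contained sketch of the idea, not the repository's actual code: every pixel gets a feature vector (in the real method, intermediate UNet activations of a pretrained DDPM at several diffusion timesteps, upsampled to image resolution and concatenated), and a small MLP classifies each pixel independently. The feature dimensionality, hidden sizes, and image size below are illustrative assumptions.

    import torch
    import torch.nn as nn

    feature_dim = 2816   # assumption: total channels across the chosen UNet blocks
    num_classes = 34     # e.g. FFHQ-34
    h = w = 64           # reduced from 256 to keep the demo light

    # Stand-in for the concatenated DDPM activations of one annotated image:
    # one feature vector per pixel, with a per-pixel class label.
    features = torch.randn(h * w, feature_dim)
    labels = torch.randint(0, num_classes, (h * w,))

    # The pixel classifier ("interpreter"): an MLP applied to each pixel independently.
    pixel_classifier = nn.Sequential(
        nn.Linear(feature_dim, 128), nn.ReLU(),
        nn.Linear(128, 32), nn.ReLU(),
        nn.Linear(32, num_classes),
    )

    optimizer = torch.optim.Adam(pixel_classifier.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(100):  # a few steps, just to show the shape of the training loop
        optimizer.zero_grad()
        loss = loss_fn(pixel_classifier(features), labels)
        loss.backward()
        optimizer.step()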

 

Updates

3/9/2022:

  1. Improved the performance of DDPM-based segmentation by changing:
      Diffusion steps: [50,150,250,350] --> [50,150,250];
      UNet blocks: [6,7,8,9] --> [5,6,7,8,12];
  2. Trained a slightly better DDPM on FFHQ-256;
  3. Added MAE for comparison.

 

Datasets

The evaluation is performed on 6 collected datasets, each with only a few annotated images in the training set: Bedroom-28, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The suffix denotes the number of semantic classes.

datasets.tar.gz (~47Mb)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. The FFHQ-256 model was trained by us, using the same model parameters as for the LSUN models.

LSUN-Bedroom: lsun_bedroom.pt
FFHQ-256: ffhq.pt (Updated 3/8/2022)
LSUN-Cat: lsun_cat.pt
LSUN-Horse: lsun_horse.pt

Run

  1. Download the datasets:
      bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30
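For example, to run the full FFHQ few-shot setup using only the names listed above:

     bash datasets/download_datasets.sh
     bash checkpoints/ddpm/download_checkpoint.sh ffhq
     bash scripts/ddpm/train_interpreter.sh ffhq_34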

Note: train_interpreter.sh is memory-intensive, since it keeps all training pixel representations in RAM. For example, it requires ~210Gb for 50 training images of 256x256. (See issue)
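A quick back-of-the-envelope check of that figure (a sketch assuming float32 features; the per-pixel dimensionality is inferred here, not read from the code):

    images = 50
    pixels = 256 * 256
    total_bytes = 210e9                          # the ~210Gb quoted above
    per_pixel = total_bytes / (images * pixels)  # ~64,000 bytes per pixel
    print(per_pixel / 4)                         # ~16,000 float32 values per pixel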

Pretrained pixel classifiers and test predictions are here.

How to improve the performance

  • Tune which diffusion steps and UNet blocks to use for your particular task; see the sketch below.
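A hedged sketch of where such tuning might happen. It assumes the experiment config exposes the step and block lists under the keys "steps" and "blocks" in experiments/<dataset_name>/ddpm.json; these key names are assumptions, so check the actual file:

    import json

    path = "experiments/ffhq_34/ddpm.json"
    with open(path) as f:
        cfg = json.load(f)

    cfg["steps"] = [50, 150, 250]     # diffusion timesteps to extract features at (assumed key)
    cfg["blocks"] = [5, 6, 7, 8, 12]  # UNet decoder blocks whose activations are used (assumed key)

    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)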

 

DatasetDDPM

Synthetic datasets

To download DDPM-produced synthetic datasets (50000 samples, ~7Gb) (updated 3/8/2022):
bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>
  2. Check paths in experiments/<dataset_name>/datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>

  3. Check paths in experiments/<dataset_name>/datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh <dataset_name>

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh <dataset_name>
        Adjust the hyperparameters in this script to match the available resources.
        On 8xA100 80Gb, it takes about 12 hours to generate 10000 samples.

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21
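Conceptually, steps 4-6 implement the loop below. This is an illustrative, self-contained sketch with random stand-ins for the real DDPM sampler and trained interpreter, not the repository's code:

    import numpy as np

    rng = np.random.default_rng(0)
    num_classes, size = 15, 256           # e.g. cat_15

    def sample_ddpm_image():
        # Stand-in for ancestral sampling from the pretrained DDPM.
        return rng.random((size, size, 3), dtype=np.float32)

    def predict_mask(image):
        # Stand-in for the interpreter, which labels each pixel from DDPM features.
        return rng.integers(0, num_classes, (size, size))

    synthetic_dataset = []
    for _ in range(10):                   # the released datasets contain 50000 samples
        image = sample_ddpm_image()
        synthetic_dataset.append((image, predict_mask(image)))
    # The (image, mask) pairs then serve as training data for DeepLab (step 6).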

 

MAE

Pretrained MAEs

We pretrain MAE models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: lsun_bedroom.pth
FFHQ-256: ffhq.pth
LSUN-Cat: lsun_cat.pth
LSUN-Horse: lsun_horse.pth

Training setups:

Dataset        Backbone   Epochs   Batch size   Mask ratio
LSUN-Bedroom   ViT-L-8    150      1024         0.75
LSUN-Cat       ViT-L-8    200      1024         0.75
LSUN-Horse     ViT-L-8    200      1024         0.75
FFHQ-256       ViT-L-8    400      1024         0.75
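For reference, a mask ratio of 0.75 is the standard MAE pretraining setting: 75% of image patches are hidden and must be reconstructed. A minimal illustration, assuming "ViT-L-8" denotes a patch size of 8:

    import torch

    num_patches = (256 // 8) ** 2          # 256x256 image, 8x8 patches -> 1024 patches
    keep = int(num_patches * (1 - 0.75))   # 256 patches remain visible
    perm = torch.randperm(num_patches)
    visible_idx = perm[:keep]              # fed to the MAE encoder
    masked_idx = perm[keep:]               # reconstructed by the MAE decoder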

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the MAE checkpoint:
       bash checkpoints/mae/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/mae.json
  4. Run: bash scripts/mae/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

SwAV

Pretrained SwAVs

We pretrain SwAV models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: SwAV, SwAVw2
FFHQ-256: SwAV, SwAVw2
LSUN-Cat: SwAV, SwAVw2
LSUN-Horse: SwAV, SwAVw2

Training setups:

Dataset    Backbone   Epochs   Batch size   Multi-crop      Num prototypes
LSUN       RN50       200      1792         2x256 + 6x108   1000
FFHQ-256   RN50       400      2048         2x224 + 6x96    200
LSUN       RN50w2     200      1920         2x256 + 4x108   1000
FFHQ-256   RN50w2     400      2048         2x224 + 4x96    200

Multi-crop is given as (number of crops)x(crop resolution): e.g., 2x256 + 6x108 denotes two global 256x256 crops and six local 108x108 crops per image.

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/{swav|swav_w2}/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/{swav|swav_w2}.json
  4. Run: bash scripts/{swav|swav_w2}/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

DatasetGAN

Unlike the official implementation, we use more recent StyleGAN2(-ADA) models.

Synthetic datasets

To download GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>

Run

Since we almost fully adopt the official implementation, we don't provide our reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>
  2. Change paths in experiments/<dataset_name>/datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21

 

Results

  • Performance in terms of mean IoU:
Method          Bedroom-28   FFHQ-34      Cat-15       Horse-21     CelebA-19    ADE-Bedroom-30
ALAE            20.0 ± 1.0   48.1 ± 1.3   --           --           49.7 ± 0.7   15.0 ± 0.5
VDVAE           --           57.3 ± 1.1   --           --           54.1 ± 1.0   --
GAN Inversion   13.9 ± 0.6   51.7 ± 0.8   21.4 ± 1.7   17.7 ± 0.4   51.5 ± 2.3   11.1 ± 0.2
GAN Encoder     22.4 ± 1.6   53.9 ± 1.3   32.0 ± 1.8   26.7 ± 0.7   53.9 ± 0.8   15.7 ± 0.3
SwAV            41.0 ± 2.3   54.7 ± 1.4   44.1 ± 2.1   51.7 ± 0.5   53.2 ± 1.0   30.3 ± 1.5
SwAVw2          42.4 ± 1.7   56.9 ± 1.3   45.1 ± 2.1   54.0 ± 0.9   52.4 ± 1.3   30.6 ± 1.0
MAE             45.0 ± 2.0   58.8 ± 1.1   52.4 ± 2.3   63.4 ± 1.4   57.8 ± 0.4   31.7 ± 1.8
DatasetGAN      31.3 ± 2.7   57.0 ± 1.0   36.5 ± 2.3   45.4 ± 1.4   --           --
DatasetDDPM     47.9 ± 2.9   56.0 ± 0.9   47.6 ± 1.5   60.8 ± 1.0   --           --
DDPM            49.4 ± 1.9   59.1 ± 1.4   53.7 ± 3.3   65.0 ± 0.8   59.9 ± 1.0   34.6 ± 1.7
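The reported metric is standard mean IoU: per-class intersection-over-union, averaged over classes. A minimal sketch of the computation (not the repository's evaluation code):

    import numpy as np

    def mean_iou(pred, gt, num_classes):
        # pred, gt: integer class maps of the same shape
        ious = []
        for c in range(num_classes):
            inter = np.logical_and(pred == c, gt == c).sum()
            union = np.logical_or(pred == c, gt == c).sum()
            if union > 0:               # skip classes absent from both maps
                ious.append(inter / union)
        return float(np.mean(ious))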

 

  • Examples of segmentation masks predicted by the DDPM-based method:
[Figure: predicted segmentation masks]

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

More Repositories

1. rtdl - Research on Tabular Deep Learning: Papers & Packages (Python, 864 stars)
2. tab-ddpm - [ICML 2023] The official implementation of "TabDDPM: Modelling Tabular Data with Diffusion Models" (Python, 346 stars)
3. navigan - Navigating the GAN Parameter Space for Semantic Image Editing (Jupyter Notebook, 296 stars)
4. rtdl-num-embeddings - (NeurIPS 2022) On Embeddings for Numerical Features in Tabular Deep Learning (Python, 295 stars)
5. tabular-dl-tabr - The implementation of "TabR: Unlocking the Power of Retrieval-Augmented Tabular Deep Learning" (Python, 231 stars)
6. rtdl-revisiting-models - (NeurIPS 2021) Revisiting Deep Learning Models for Tabular Data (Python, 191 stars)
7. swarm - Official code for "SWARM Parallelism: Training Large Models Can Be Surprisingly Communication-Efficient" (Python, 115 stars)
8. DeDLOC - Official code for "Distributed Deep Learning in Open Collaborations" (NeurIPS 2021) (Jupyter Notebook, 114 stars)
9. RuLeanALBERT - A pretrained masked language model for Russian with a memory-efficient architecture (Python, 89 stars)
10. heterophilous-graphs - A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress? (Python, 86 stars)
11. invertible-cd - Invertible Consistency Distillation for Text-Guided Image Editing in Around 7 Steps (Python, 63 stars)
12. GBDT-uncertainty (Jupyter Notebook, 50 stars)
13. graph-glove - PyTorch code for the EMNLP 2020 paper "Embedding Words in Non-Vector Space with Unsupervised Graph Learning" (Python, 40 stars)
14. DVAR - Official implementation of "Is This Loss Informative? Faster Text-to-Image Customization by Tracking Objective Dynamics" (NeurIPS 2023) (Python, 35 stars)
15. sparqling-queries - Implementation of the EMNLP'21 paper "SPARQLing Database Queries from Intermediate Question Decompositions" by Irina Saparina and Anton Osokin (Python, 34 stars)
16. moshpit-sgd - Official implementation of "Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices" (Jupyter Notebook, 28 stars)
17. adaptive-diffusion - [CVPR 2024] Adaptive Teacher-Student Collaboration for Text-Conditional Diffusion Models (Python, 26 stars)
18. gan-transfer - Supplementary code for "When, Why, and Which Pretrained GANs Are Useful?" (ICLR'22) (Jupyter Notebook, 24 stars)
19. specexec (Python, 15 stars)
20. btard - Code for the paper "Secure Distributed Training at Scale" (ICML 2022) (Python, 14 stars)
21. structural-graph-shifts - Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts (NeurIPS'23) (Python, 10 stars)
22. crosslingual_winograd - "It's All in the Heads" (Findings of ACL 2021), official implementation and data (Python, 10 stars)
23. distill-nf - Code for the paper "Distilling the Knowledge from Conditional Normalizing Flows" (Jupyter Notebook, 9 stars)
24. gan_vs_diff_sr - Does Diffusion Beat GAN in Image Super Resolution? (8 stars)
25. classification-measures - Official implementation and data for "Good Classification Measures and How to Find Them" (NeurIPS 2021) (Python, 7 stars)
26. tabred - A benchmark of tabular machine learning in-the-wild with real-world, industry-grade tabular datasets (Python, 7 stars)
27. text-to-img-hypernymy - Official code for "Hypernymy Understanding Evaluation of Text-to-Image Models via WordNet Hierarchy" (Jupyter Notebook, 6 stars)
28. learnable-init - Code for the paper "Discovering Weight Initializers with Meta-Learning" (Jupyter Notebook, 5 stars)
29. mind-your-format - Mind Your Format: Towards Consistent Evaluation of In-Context Learning Improvements (Jupyter Notebook, 5 stars)
30. dnar - The implementation of "Discrete Neural Algorithmic Reasoning" (Python, 5 stars)
31. proxy-dirichlet-distillation - Implementation of "Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets" (NeurIPS 2021) and "Uncertainty Estimation in Autoregressive Structured Prediction" (ICLR 2021) (Python, 3 stars)
32. msr - An official repository of "Multi-Sentence Resampling: A Simple Approach to Alleviate Dataset Length Bias and Beam-Search Degradation" (Python, 2 stars)
33. vqdm - Official repository for the paper "VQDM: Accurate Compression of Text-to-Image Diffusion Models via Vector Quantization" (1 star)