Language: Python
License: MIT License

Effective Data Augmentation With Diffusion Models

DA-Fusion

Existing data augmentations like rotations and re-colorizations provide diversity but preserve semantics. We explore how prompt-based generative models complement existing data augmentations by controlling image semantics via prompts. Our generative data augmentations build on Stable Diffusion and improve visual few-shot learning.

Preprint: https://arxiv.org/abs/2302.07944

Installation

To install the package, first create a conda environment.

conda create -n da-fusion python=3.7 pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.6 -c pytorch
conda activate da-fusion
pip install "diffusers[torch]" transformers pycocotools pandas matplotlib seaborn scipy

Then download and install the source code.

git clone git@github.com:brandontrabucco/da-fusion.git
pip install -e da-fusion
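After installation, a quick sanity check can confirm that the pinned versions from the conda command actually landed in the environment. The sketch below is not part of the repository; the package list simply mirrors the commands above, and it falls back to the `importlib_metadata` backport because the environment uses Python 3.7 (which predates `importlib.metadata`). A package installed without Python metadata may be reported as missing.

```python
# Post-install sanity check (a sketch; the package list mirrors the install commands).
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 has no importlib.metadata; use the backport
    from importlib_metadata import PackageNotFoundError, version

# Versions pinned in the conda command.
PINNED = {"torch": "1.12.1", "torchvision": "0.13.1"}
# Packages installed with pip, without version pins.
UNPINNED = ["diffusers", "transformers", "pycocotools",
            "pandas", "matplotlib", "seaborn", "scipy"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def check_environment():
    """Return a list of human-readable problems (empty if everything looks fine)."""
    report = installed_versions(list(PINNED) + UNPINNED)
    problems = []
    for name, wanted in PINNED.items():
        got = report[name]
        # Pip wheels may carry a local suffix such as "+cu116"; compare the base version.
        if got is None or got.split("+")[0] != wanted:
            problems.append(f"{name}: expected {wanted}, found {got}")
    for name in UNPINNED:
        if report[name] is None:
            problems.append(f"{name}: not installed")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "environment looks OK")
```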

Datasets

We benchmark DA-Fusion on few-shot image classification problems, including a Leafy Spurge weed recognition task and classification tasks derived from COCO and PASCAL VOC. For the latter two, we label each image with the class of the largest object in the image.
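The largest-object labeling rule can be illustrated with COCO-style annotation dicts, which carry `category_id` and `area` fields. This is a minimal sketch, not the repository's code; the function name is hypothetical, and for PASCAL VOC the object areas would instead have to be computed from the segmentation masks.

```python
from typing import Dict, List


def label_by_largest_object(annotations: List[Dict]) -> int:
    """Return the category_id of the annotation with the largest area.

    Each annotation is assumed to be a COCO-style dict with at least
    'category_id' and 'area' keys. Raises ValueError for an image
    without annotated objects.
    """
    if not annotations:
        raise ValueError("image has no annotated objects")
    largest = max(annotations, key=lambda ann: ann["area"])
    return largest["category_id"]


anns = [
    {"category_id": 1, "area": 1200.0},   # a person (COCO id 1)
    {"category_id": 18, "area": 5400.5},  # a dog (COCO id 18)
]
print(label_by_largest_object(anns))  # 18
```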

Custom datasets can be evaluated by subclassing the base few-shot dataset class defined in semantic_aug/few_shot_dataset.py.
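The subclassing pattern might look roughly like the sketch below. Every name here (`FewShotBase`, `label_of`, `few_shot_indices`, `ListDataset`) is hypothetical and only illustrates the idea of a base class that draws a reproducible few-shot subset while subclasses supply the labels; consult semantic_aug/few_shot_dataset.py for the abstract methods the repository actually requires.

```python
import random
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


class FewShotBase(ABC):
    """Hypothetical stand-in for the repository's base class (names are illustrative)."""

    @abstractmethod
    def __len__(self) -> int:
        """Total number of labeled images available."""

    @abstractmethod
    def label_of(self, idx: int) -> int:
        """Integer class label of image idx."""

    def few_shot_indices(self, examples_per_class: int, seed: int = 0) -> List[int]:
        """Reproducibly pick up to examples_per_class image indices per class."""
        by_label: Dict[int, List[int]] = defaultdict(list)
        for idx in range(len(self)):
            by_label[self.label_of(idx)].append(idx)
        rng = random.Random(seed)
        chosen: List[int] = []
        for label in sorted(by_label):
            candidates = by_label[label]
            chosen.extend(rng.sample(candidates, min(examples_per_class, len(candidates))))
        return sorted(chosen)


class ListDataset(FewShotBase):
    """A toy custom dataset backed by (image_path, label) pairs."""

    def __init__(self, items: Sequence[Tuple[str, int]]):
        self.items = list(items)

    def __len__(self) -> int:
        return len(self.items)

    def label_of(self, idx: int) -> int:
        return self.items[idx][1]
```

For example, `ListDataset([(f"img_{i}.jpg", i % 3) for i in range(30)]).few_shot_indices(4, seed=0)` returns 12 indices, four per class, and the same seed always yields the same subset.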

Setting Up PASCAL VOC

Data for the PASCAL VOC task is adapted from the 2012 PASCAL VOC Challenge. After downloading and extracting this dataset, point the PASCAL dataset class in semantic_aug/datasets/pascal.py to it by setting the PASCAL_DIR config variable.

Ensure that PASCAL_DIR points to a folder containing ImageSets, JPEGImages, SegmentationClass, and SegmentationObject subfolders.
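A small helper can verify the folder layout before training starts. This is a sketch and not part of the repository; the script name in the usage note is hypothetical. The official 2012 trainval archive extracts to VOCdevkit/VOC2012, which contains these subfolders.

```python
import sys
from pathlib import Path
from typing import List

# Subfolders that PASCAL_DIR is expected to contain.
REQUIRED_SUBDIRS = ("ImageSets", "JPEGImages", "SegmentationClass", "SegmentationObject")


def missing_pascal_subdirs(pascal_dir: str) -> List[str]:
    """Return the required subfolders that are absent under pascal_dir."""
    root = Path(pascal_dir)
    return [name for name in REQUIRED_SUBDIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_pascal_subdirs(sys.argv[1])
    print("OK" if not missing else "missing: " + ", ".join(missing))
```

Saved as, say, check_pascal.py, it would be run as `python check_pascal.py /path/to/VOCdevkit/VOC2012`.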

Setting Up COCO

To set up COCO, first download the 2017 Training Images, the 2017 Validation Images, and the 2017 Train/Val Annotations. Unzip these files into the following directory structure.

coco2017/
    train2017/
    val2017/
    annotations/

Update the COCO_DIR config variable so that it points to the location of coco2017 on your system.
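The expected layout can also be checked programmatically. The sketch below is not from the repository; it assumes the annotation archive was unzipped as shown, so that the instance annotations live at annotations/instances_train2017.json and annotations/instances_val2017.json (the file names used in the official 2017 Train/Val Annotations archive).

```python
from pathlib import Path
from typing import Dict, List


def coco_paths(coco_dir: str, split: str = "train") -> Dict[str, Path]:
    """Expected image folder and instance-annotation file for a COCO 2017 split."""
    if split not in ("train", "val"):
        raise ValueError("split must be 'train' or 'val'")
    root = Path(coco_dir)
    return {
        "images": root / f"{split}2017",
        "instances": root / "annotations" / f"instances_{split}2017.json",
    }


def missing_coco_files(coco_dir: str) -> List[str]:
    """List expected COCO paths that do not exist under coco_dir."""
    missing = []
    for split in ("train", "val"):
        for path in coco_paths(coco_dir, split).values():
            if not path.exists():
                missing.append(str(path))
    return missing
```

Once the check passes, the annotations can be loaded with pycocotools, for example `COCO(str(coco_paths(COCO_DIR)["instances"]))` after `from pycocotools.coco import COCO`.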

Setting Up The Spurge Dataset

We are planning to release this dataset in the next few months. Check back for updates!

Fine-Tuning Tokens

We perform textual inversion (https://arxiv.org/abs/2208.01618) to adapt Stable Diffusion to the classes in our few-shot datasets. The implementation in fine_tune.py is adapted from the Diffusers textual inversion example.
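The core idea of textual inversion is that a new token is added to the vocabulary and only its embedding row is optimized, while every other weight stays frozen. The NumPy sketch below shows just that mechanic under heavy simplification: a mean of token embeddings stands in for the text encoder, and a squared distance to a `target` vector replaces the diffusion denoising loss. All names are hypothetical and none of this is the code in fine_tune.py.

```python
import numpy as np


def prompt_feature(emb: np.ndarray, token_ids) -> np.ndarray:
    """Toy stand-in for the text encoder: the mean of the prompt's token embeddings."""
    return emb[token_ids].mean(axis=0)


def textual_inversion_toy(emb, token_ids, placeholder_id, target, lr=1.0, steps=200):
    """Fit only the placeholder token's embedding so the prompt feature hits `target`.

    The squared distance to `target` replaces the diffusion denoising loss
    that real textual inversion minimizes through frozen model weights.
    """
    emb = emb.copy()
    n = len(token_ids)
    frozen = np.ones(emb.shape[0], dtype=bool)
    frozen[placeholder_id] = False  # every row except the new token stays fixed
    for _ in range(steps):
        residual = prompt_feature(emb, token_ids) - target
        grad = np.zeros_like(emb)
        # d/d emb[j] of ||mean(emb[ids]) - target||^2, summed over j's occurrences.
        for tid in token_ids:
            grad[tid] += 2.0 * residual / n
        grad[frozen] = 0.0  # discard gradients for the frozen vocabulary
        emb -= lr * grad
    return emb


rng = np.random.default_rng(0)
vocab = rng.normal(size=(6, 4))  # rows 0-4: existing tokens, row 5: new placeholder token
ids = [0, 1, 2, 0, 5]            # toy ids for "a photo of a <placeholder>"
target = rng.normal(size=4)
trained = textual_inversion_toy(vocab, ids, placeholder_id=5, target=target)
```

Even though the loss depends on every token in the prompt, only row 5 of `trained` differs from `vocab`, which mirrors how the learned class token absorbs the concept while the pretrained vocabulary is left unchanged.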

To distribute experiments on a Slurm cluster, we wrap this script in a set of sbatch scripts located in scripts/fine_tuning. These scripts perform multiple runs of textual inversion in parallel, limited by the number of available nodes on your Slurm cluster.

If sbatch is not available on your system, you can run these scripts with bash and manually set SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT for each parallel job. Slurm normally sets these variables automatically to the index of the current job and the total number of jobs, respectively; to run everything as a single job, set them to 0 and 1.
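To see why 0 and 1 amount to "run everything in one job", consider a common way array tasks split a list of runs: task i of n takes every n-th run starting at i. The Python sketch below assumes that round-robin scheme and zero-based task IDs; the actual bash scripts in scripts/fine_tuning may partition work differently, and `jobs_for_this_task` is a hypothetical helper.

```python
import os
from typing import List, Mapping, Optional, Sequence, TypeVar

T = TypeVar("T")


def jobs_for_this_task(all_jobs: Sequence[T], env: Optional[Mapping[str, str]] = None) -> List[T]:
    """Round-robin share of all_jobs for the current Slurm array task.

    Defaults (ID 0, COUNT 1) mean "run everything in this one process".
    """
    env = os.environ if env is None else env
    task_id = int(env.get("SLURM_ARRAY_TASK_ID", "0"))
    task_count = int(env.get("SLURM_ARRAY_TASK_COUNT", "1"))
    if task_count < 1 or not 0 <= task_id < task_count:
        raise ValueError(f"invalid task {task_id} of {task_count}")
    return list(all_jobs[task_id::task_count])
```

With ten runs and three tasks, task 1 receives runs 1, 4 and 7; with the single-job setting it receives all ten.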

Few-Shot Classification

Code for training image classification models using augmented images from DA-Fusion is located in train_classifier.py. This script accepts a number of arguments that control how the classifier is trained:

python train_classifier.py --logdir pascal-baselines/textual-inversion-0.5 \
--synthetic-dir "aug/textual-inversion-0.5/{dataset}-{seed}-{examples_per_class}" \
--dataset pascal --prompt "a photo of a {name}" \
--aug textual-inversion --guidance-scale 7.5 \
--strength 0.5 --mask 0 --inverted 0 \
--num-synthetic 10 --synthetic-probability 0.5 \
--num-trials 1 --examples-per-class 4

This example trains a classifier on the PASCAL VOC task with 4 images per class, using the prompt "a photo of a ClassX", where ClassX is a special token fine-tuned from scratch with textual inversion. Slurm scripts that reproduce the paper's experiments are located in scripts/textual_inversion. Results are logged to .csv files in the directory given by --logdir.
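Two of the flags above can be sketched under stated assumptions. We assume --synthetic-probability is the chance that a training example is swapped for one of its --num-synthetic generated variants, and that --strength follows the Diffusers image-to-image convention, where it sets how far into the noise schedule the real image is pushed before denoising. Both functions are hypothetical illustrations, not code from train_classifier.py.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def pick_training_image(real: T, synthetic: Sequence[T], synthetic_probability: float,
                        rng: random.Random) -> T:
    """Return a synthetic variant with probability synthetic_probability, else the real image."""
    if synthetic and rng.random() < synthetic_probability:
        return rng.choice(synthetic)
    return real


def img2img_denoising_steps(num_inference_steps: int, strength: float) -> int:
    """Denoising steps actually run when starting from a noised real image.

    Mirrors the Diffusers image-to-image convention: strength 1.0 runs the whole
    schedule (the input is almost entirely replaced), strength 0.0 runs none.
    """
    return min(int(num_inference_steps * strength), num_inference_steps)
```

Under these assumptions, the command above would show the classifier a synthetic image about half the time, and with a 50-step schedule --strength 0.5 would run 25 denoising steps starting partway through the schedule, so the generated image keeps much of the real image's layout.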

We used a custom plotting script to generate the figures in the main paper.

Citation

If you find our method helpful, consider citing our preprint!

@misc{https://doi.org/10.48550/arxiv.2302.07944,
  doi = {10.48550/ARXIV.2302.07944},
  url = {https://arxiv.org/abs/2302.07944},
  author = {Trabucco, Brandon and Doherty, Kyle and Gurinas, Max and Salakhutdinov, Ruslan},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Effective Data Augmentation With Diffusion Models},
  publisher = {arXiv},
  year = {2023},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
