  • Language: Python
  • License: MIT License
  • Created about 5 years ago
  • Updated almost 3 years ago



Fast AutoAugment (Accepted at NeurIPS 2019)

Official Fast AutoAugment implementation in PyTorch.

  • Fast AutoAugment learns augmentation policies using a more efficient search strategy based on density matching.
  • Fast AutoAugment reduces search time by orders of magnitude while maintaining comparable performance.
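The density-matching idea can be illustrated with a toy example. Roughly, a model is first trained on one split of the training data; each candidate policy is then scored by how well that fixed model handles augmented copies of a held-out split, and the best-scoring policies are kept. The sketch below covers only this scoring-and-selection step, on 1-D toy data with deterministic transforms, and scores by accuracy. The threshold "model", the transforms, and the helper names (`score_policy`, `select_policies`) are hypothetical; the paper explores the policy space with Bayesian optimization over several data folds rather than scoring a fixed candidate list.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Sample = Tuple[float, int]            # (feature, label) in a toy 1-D problem
Transform = Callable[[float], float]  # a toy, deterministic "augmentation policy"


def model(x: float) -> int:
    """Hypothetical classifier, assumed to be already trained on the first split."""
    return 1 if x > 0 else 0


def score_policy(policy: Transform, held_out: Sequence[Sample]) -> float:
    """Accuracy of the fixed model on the held-out split after augmenting it."""
    correct = sum(model(policy(x)) == y for x, y in held_out)
    return correct / len(held_out)


def select_policies(candidates: Dict[str, Transform],
                    held_out: Sequence[Sample], top_n: int) -> List[str]:
    """Keep the top_n policies whose augmented data the model still classifies best."""
    scores = {name: score_policy(fn, held_out) for name, fn in candidates.items()}
    # sorted() is stable (also with reverse=True), so ties keep candidate order.
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:top_n]


held_out_split = [(-2.0, 0), (-1.0, 0), (-0.3, 0), (0.3, 1), (1.0, 1), (2.0, 1)]
candidates = {
    "shift_0.5": lambda x: x + 0.5,
    "scale_2": lambda x: 2.0 * x,
    "negate": lambda x: -x,
    "shift_0.1": lambda x: x + 0.1,
}
print(select_policies(candidates, held_out_split, top_n=2))  # ['scale_2', 'shift_0.1']
```

Policies that push held-out samples across the model's decision boundary (such as `negate`) score lowest, which matches the intuition of preferring augmentations whose outputs stay close to the data distribution the model was trained on.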

Results

CIFAR-10 / 100

Search : 3.5 GPU Hours (1428x faster than AutoAugment), WResNet-40x2 on Reduced CIFAR-10

Test error rates (%); lower is better.

Model (CIFAR-10)        Baseline  Cutout  AutoAugment  Fast AutoAugment (transfer / direct)
Wide-ResNet-40-2        5.3       4.1     3.7          3.6 / 3.7    Download
Wide-ResNet-28-10       3.9       3.1     2.6          2.7 / 2.7    Download
Shake-Shake(26 2x32d)   3.6       3.0     2.5          2.7 / 2.5    Download
Shake-Shake(26 2x96d)   2.9       2.6     2.0          2.0 / 2.0    Download
Shake-Shake(26 2x112d)  2.8       2.6     1.9          2.0 / 1.9    Download
PyramidNet+ShakeDrop    2.7       2.3     1.5          1.8 / 1.7    Download

Model (CIFAR-100)       Baseline  Cutout  AutoAugment  Fast AutoAugment (transfer / direct)
Wide-ResNet-40-2        26.0      25.2    20.7         20.7 / 20.6  Download
Wide-ResNet-28-10       18.8      18.4    17.1         17.3 / 17.3  Download
Shake-Shake(26 2x96d)   17.1      16.0    14.3         14.9 / 14.6  Download
PyramidNet+ShakeDrop    14.0      12.2    10.7         11.9 / 11.7  Download

ImageNet

Search : 450 GPU Hours (33x faster than AutoAugment), ResNet-50 on Reduced ImageNet
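The speed-up factors in the CIFAR-10 and ImageNet search lines can be turned back into the search cost they imply for AutoAugment itself. This is simple arithmetic on the numbers stated above; the helper name is illustrative, and because the speed-up factors appear to be rounded, the results are approximate.

```python
def implied_autoaugment_hours(fast_aa_hours: float, speedup: float) -> float:
    """Search cost implied for AutoAugment, given Fast AutoAugment's cost and speed-up."""
    return fast_aa_hours * speedup


cifar10_hours = implied_autoaugment_hours(3.5, 1428)  # WResNet-40x2 on Reduced CIFAR-10
imagenet_hours = implied_autoaugment_hours(450, 33)   # ResNet-50 on Reduced ImageNet

print(f"CIFAR-10: ~{cifar10_hours:,.0f} GPU hours")   # CIFAR-10: ~4,998 GPU hours
print(f"ImageNet: ~{imagenet_hours:,.0f} GPU hours")  # ImageNet: ~14,850 GPU hours
```

So the implied AutoAugment search costs are on the order of 5,000 and 15,000 GPU hours, respectively.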

Top-1 / Top-5 error rates (%); lower is better.

Model         Baseline     AutoAugment  Fast AutoAugment
ResNet-50     23.7 / 6.9   22.4 / 6.2   22.4 / 6.3   Download
ResNet-200    21.5 / 5.8   20.0 / 5.0   19.4 / 4.7   Download

Notes

  • We evaluated ResNet-50 and ResNet-200 at resolutions of 224 and 320, respectively. According to the original ResNet paper, ResNet-200 was tested at a resolution of 320, and our ResNet-200 baseline achieved similar performance at that resolution.
  • However, after our recent code clean-up and bug fixes, we found that the ResNet-200 baseline performs similarly even at 224x224.
  • At 224x224, ResNet-200 achieves 20.0 / 5.2 (Top-1 / Top-5 error). Download link for the trained model is here.

We have conducted additional experiments with EfficientNet.

ImageNet Top-1 error (%):

Model            Baseline  AutoAugment  Our Baseline (Batch)  +Fast AA
EfficientNet-B0  23.2      22.7         22.96                 22.68

SVHN Test

Search : 1.5 GPU Hours

Test error (%):

Model              Baseline  AutoAugment  Fast AutoAugment (ours)
Wide-ResNet-28-10  1.5       1.1          1.1

Run

We conducted experiments with:

  • Python 3.6.9
  • PyTorch 1.2.0, torchvision 0.4.0, CUDA 10

Search an augmentation policy

Please read Ray's documentation (https://github.com/ray-project/ray) to set up a proper Ray cluster, then run search.py with the Redis address of the cluster's master node.

$ python search.py -c confs/wresnet40x2_cifar10_b512.yaml --dataroot ... --redis ...

Train a model with found policies

You can train network architectures on CIFAR-10 / 100 and ImageNet with our searched policies.

  • fa_reduced_cifar10: reduced CIFAR-10 (4k images), WResNet-40x2
  • fa_reduced_imagenet: reduced ImageNet (50k images, 120 classes), ResNet-50

$ export PYTHONPATH=$PYTHONPATH:$PWD
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
...
$ python FastAutoAugment/train.py -c confs/resnet50_b512.yaml --aug fa_reduced_imagenet
$ python FastAutoAugment/train.py -c confs/resnet200_b512.yaml --aug fa_reduced_imagenet
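The `--aug` values above select policies found by the search. AutoAugment-style policies are commonly stored as a list of sub-policies, each a short list of (operation, probability, magnitude) entries; at training time one sub-policy is drawn per image and each of its operations is applied with its own probability. The sketch below shows that application step under those assumptions only. The toy "image" (a flat list of 0-255 intensities), the two operations, and the magnitude-to-threshold mapping are hypothetical stand-ins, not the repository's implementation.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

Image = List[int]              # toy grayscale "image": flat list of 0-255 intensities
Op = Tuple[str, float, float]  # (operation name, probability, magnitude in [0, 1])


def invert(img: Image, magnitude: float) -> Image:
    # Inversion has no strength parameter, so the magnitude is ignored.
    return [255 - p for p in img]


def solarize(img: Image, magnitude: float) -> Image:
    # Hypothetical mapping: a higher magnitude lowers the threshold, inverting more pixels.
    threshold = int(256 * (1.0 - magnitude))
    return [255 - p if p >= threshold else p for p in img]


OPS: Dict[str, Callable[[Image, float], Image]] = {"Invert": invert, "Solarize": solarize}


def apply_policy(img: Image, policy: Sequence[Sequence[Op]], rng: random.Random) -> Image:
    """Draw one sub-policy, then apply each of its operations with its own probability."""
    sub_policy = rng.choice(policy)
    for name, prob, magnitude in sub_policy:
        if rng.random() < prob:  # rng.random() is in [0, 1), so prob=1.0 always fires
            img = OPS[name](img, magnitude)
    return img


# A hypothetical two-sub-policy policy; real searched policies are much longer.
policy = [
    [("Solarize", 1.0, 0.5), ("Invert", 0.0, 0.0)],
    [("Invert", 1.0, 0.0)],
]
result = apply_policy([0, 100, 128, 200], policy, random.Random(0))
print(result)
```

Drawing a single sub-policy per image keeps each augmented sample cheap to produce while still exposing the model to the whole policy over many epochs.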

By adding --only-eval and --save arguments, you can test trained models without training.

To train on multiple GPUs or nodes, use torch.distributed.launch, for example:

$ python -m torch.distributed.launch --nproc_per_node={num_gpu_per_node} --nnodes={num_node} --master_addr={master} --master_port={master_port} --node_rank={0,1,...,num_node-1} FastAutoAugment/train.py -c confs/efficientnet_b4.yaml --aug fa_reduced_imagenet
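The launcher starts `nproc_per_node` processes on each of `nnodes` machines, and torch.distributed.launch assigns each process the global rank `node_rank * nproc_per_node + local_rank`, so the world size is the product of the two. The sketch below reproduces only that bookkeeping in plain Python (the function name is illustrative); it can help when checking that each node gets a distinct `--node_rank` from 0 to num_node - 1.

```python
from typing import Dict, Tuple


def rank_layout(num_nodes: int, gpus_per_node: int) -> Tuple[int, Dict[Tuple[int, int], int]]:
    """Return the world size and a map (node_rank, local_rank) -> global rank.

    Follows the rule used by torch.distributed.launch:
    global_rank = node_rank * nproc_per_node + local_rank.
    """
    world_size = num_nodes * gpus_per_node
    ranks = {
        (node_rank, local_rank): node_rank * gpus_per_node + local_rank
        for node_rank in range(num_nodes)       # --node_rank takes 0 .. num_nodes - 1
        for local_rank in range(gpus_per_node)  # one process per GPU on each node
    }
    return world_size, ranks


world_size, ranks = rank_layout(num_nodes=2, gpus_per_node=4)
print(world_size)     # 8
print(ranks[(1, 0)])  # 4
```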

Citation

If you use this code in your research, please cite our paper.

@inproceedings{lim2019fast,
  title={Fast AutoAugment},
  author={Lim, Sungbin and Kim, Ildoo and Kim, Taesup and Kim, Chiheon and Kim, Sungwoong},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2019}
}

Contact for Issues

References & Opensources

We increase the batch size and adapt the learning rate accordingly to speed up training. Otherwise, we set the other hyperparameters equal to those of AutoAugment where possible. For unknown hyperparameters, we follow the values from the original references or tune them to match the baseline performance.
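The paragraph above does not say which rule is used to adapt the learning rate. A common choice is the linear scaling rule, in which the learning rate grows in proportion to the batch size; the sketch below assumes that rule, and the reference point (learning rate 0.1 at batch size 128) is hypothetical rather than taken from the configuration files.

```python
def scaled_lr(base_lr: float, base_batch: int, batch: int) -> float:
    """Linear scaling rule (assumed here): the learning rate grows with the batch size."""
    if base_batch <= 0 or batch <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * batch / base_batch


# Hypothetical reference point: lr 0.1 at batch 128, scaled to the b512 configs' batch size.
print(scaled_lr(0.1, 128, 512))  # 0.4
```

Large-batch setups often pair such scaling with a learning-rate warm-up; whether that applies here depends on the configuration files.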
