Deceive D: Adaptive Pseudo Augmentation for GAN Training with Limited Data (NeurIPS 2021)
This repository provides the official PyTorch implementation for the following paper:
Deceive D: Adaptive Pseudo Augmentation for GAN Training with Limited Data
Liming Jiang, Bo Dai, Wayne Wu and Chen Change Loy
In NeurIPS 2021.
Project Page | Paper | Poster | Slides | YouTube Demo
Abstract: Generative adversarial networks (GANs) typically require ample data for training in order to synthesize high-fidelity images. Recent studies have shown that training GANs with limited data remains formidable due to discriminator overfitting, the underlying cause that impedes the generator's convergence. This paper introduces a novel strategy called Adaptive Pseudo Augmentation (APA) to encourage healthy competition between the generator and the discriminator. As an alternative method to existing approaches that rely on standard data augmentations or model regularization, APA alleviates overfitting by employing the generator itself to augment the real data distribution with generated images, which deceives the discriminator adaptively. Extensive experiments demonstrate the effectiveness of APA in improving synthesis quality in the low-data regime. We provide a theoretical analysis to examine the convergence and rationality of our new training strategy. APA is simple and effective. It can be added seamlessly to powerful contemporary GANs, such as StyleGAN2, with negligible computational cost.
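The pseudo augmentation idea described in the abstract can be sketched as follows. With a deception probability p, each real image in a discriminator batch is replaced by a generated image, so D sees some fakes presented as real. This is a minimal stdlib-only sketch that assumes plain Python lists stand in for image tensors. The helper `pseudo_augment` is hypothetical and is not part of this repository's API.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def pseudo_augment(reals: Sequence[T], fakes: Sequence[T], p: float,
                   rng: random.Random) -> List[T]:
    """Return a copy of `reals` where each sample is swapped for the
    generated sample at the same index with probability p.

    Hypothetical illustration only: the actual training code operates on
    image tensors inside the training loop.
    """
    if len(reals) != len(fakes):
        raise ValueError("reals and fakes must have the same batch size")
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    # rng.random() is in [0, 1), so p=0 never swaps and p=1 always swaps.
    return [fake if rng.random() < p else real
            for real, fake in zip(reals, fakes)]


if __name__ == "__main__":
    rng = random.Random(0)
    reals = [f"real{i}" for i in range(8)]
    fakes = [f"fake{i}" for i in range(8)]
    print(pseudo_augment(reals, fakes, p=0.5, rng=rng))
```

Because only the inputs labeled "real" are changed, the discriminator's loss function stays the same; how p is chosen is what makes the scheme adaptive.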
Updates
- [11/2021] The code of APA is released.
- [09/2021] The paper of APA is accepted by NeurIPS 2021.
Requirements
- 1–8 high-end NVIDIA GPUs with at least 12 GB of memory. We have done all testing and development using 8 NVIDIA Tesla V100 PCIe 32 GB GPUs.
- CUDA toolkit 10.1 or later. Use at least version 11.1 if running on RTX 3090. We use CUDA toolkit 10.1.
- 64-bit Python 3.7 and PyTorch 1.7.1 with compatible CUDA toolkit. See https://pytorch.org/ for PyTorch install instructions. Using Anaconda to create a new Python virtual environment is recommended.
- Install the remaining Python dependencies by running:
  pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3 psutil scipy tensorboard
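Before training, a quick sanity check that the interpreter is 64-bit Python 3.7 or later and that the packages above are importable might look like this. It is a minimal stdlib-only sketch. `check_python` and `missing_modules` are hypothetical helpers, the pip-name to module-name mapping (e.g. imageio-ffmpeg imports as imageio_ffmpeg) is an assumption, and it does not verify the CUDA toolkit or the exact PyTorch version.

```python
import importlib.util
import struct
import sys
from typing import Iterable, List, Tuple

# Assumed import names for the packages installed above, plus PyTorch.
REQUIRED_MODULES = ["click", "requests", "tqdm", "pyspng", "ninja",
                    "imageio_ffmpeg", "psutil", "scipy", "tensorboard", "torch"]


def check_python(version: Tuple[int, ...] = tuple(sys.version_info[:2]),
                 pointer_bits: int = struct.calcsize("P") * 8) -> List[str]:
    """Return a list of problems; an empty list means the check passed."""
    problems = []
    if pointer_bits != 64:
        problems.append(f"expected 64-bit Python, found {pointer_bits}-bit")
    if tuple(version[:2]) < (3, 7):
        problems.append("expected Python 3.7 or later, found "
                        + ".".join(map(str, version[:2])))
    return problems


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """List the top-level modules that cannot be found on the current path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("interpreter problems:", check_python())
    print("missing modules:", missing_modules())
```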
Inference for Generating Images
Pretrained models can be downloaded from Google Drive:
Model | Description | FID |
---|---|---|
afhqcat5k256x256-apa.pkl | AFHQ-Cat-5k (inherently limited, 256x256), trained from scratch using APA | 4.876 |
ffhq5k256x256-apa.pkl | FFHQ-5k (~7% data, 256x256), trained from scratch using APA | 13.249 |
anime5k256x256-apa.pkl | Anime-5k (~2% data, 256x256), trained from scratch using APA | 13.089 |
cub12k256x256-apa.pkl | CUB-12k (inherently limited, 256x256), trained from scratch using APA | 12.889 |
ffhq70kfull256x256-apa.pkl | FFHQ-70k (full data, 256x256), trained from scratch using APA | 3.678 |
ffhq5k1024x1024-apa.pkl | FFHQ-5k (~7% data, 1024x1024), trained from scratch using APA | 9.545 |
The downloaded models are stored as *.pkl files that can be referenced using local filenames:
# Generate images with the truncation of 0.7
python generate.py --outdir=out --trunc=0.7 --seeds=1000-1199 --network=/path/to/checkpoint/pkl
# Generate images without truncation
python generate.py --outdir=out --trunc=1 --seeds=1000-1199 --network=/path/to/checkpoint/pkl
Outputs from the above commands will be placed under out/*.png, controlled by --outdir.
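For reference, --seeds=1000-1199 denotes an inclusive range, so each command above produces 200 images, and --trunc applies the truncation trick, which pulls each intermediate latent w toward the average latent. The stdlib-only sketch below mirrors this behavior under the assumption that it follows the stylegan2-ada-pytorch generate.py conventions (inclusive ranges, comma-separated lists, one seedNNNN.png file per seed). `parse_seeds` and `truncate` are illustrative names, not imports from this repository.

```python
import re
from typing import List, Sequence


def parse_seeds(spec: str) -> List[int]:
    """Parse '1000-1199' (inclusive range) or '1,5,9' (explicit list)."""
    m = re.fullmatch(r"(\d+)-(\d+)", spec)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    return [int(x) for x in spec.split(",")]


def truncate(w: Sequence[float], w_avg: Sequence[float], psi: float) -> List[float]:
    """Truncation trick: w' = w_avg + psi * (w - w_avg); psi=1 leaves w unchanged."""
    return [a + psi * (x - a) for x, a in zip(w, w_avg)]


if __name__ == "__main__":
    seeds = parse_seeds("1000-1199")
    # Assumed output naming: one file per seed, e.g. out/seed1000.png.
    print(len(seeds), f"seed{seeds[0]:04d}.png", f"seed{seeds[-1]:04d}.png")
```

A smaller psi (such as 0.7) trades diversity for fidelity, while --trunc=1 samples from the untruncated distribution.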
Dataset Preparation
The datasets we used can be downloaded from their official pages:
Datasets | Animal Faces-HQ Cat (AFHQ-Cat) | Flickr-Faces-HQ (FFHQ) | Danbooru2019 Portraits (Anime) | Caltech-UCSD Birds-200-2011 (CUB) |
---|---|---|---|---|
We use dataset_tool.py to prepare the downloaded datasets (run python dataset_tool.py --help for more information). The datasets will be stored as uncompressed ZIP archives containing uncompressed PNG files. Alternatively, a folder containing images can also be used directly as a dataset, but doing so may lead to suboptimal performance.
For instance, a ZIP archive containing a 5k-image subset of a custom dataset at a resolution of 256 × 256 can be created from a folder of images:
python dataset_tool.py --source=/path/to/image/folder --dest=/path/to/archive.zip \
--width=256 --height=256 --max-images=5000
More detailed steps can be found in the "Preparing datasets" section of stylegan2-ada-pytorch.
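The archive layout can be illustrated with a small sketch that packs already-encoded PNG images into an uncompressed ZIP (zipfile.ZIP_STORED), capped at max_images. It assumes the images are already resized to the target resolution, and the 00000/img00000000.png naming is our assumption of what dataset_tool.py produces. `pack_pngs` is hypothetical and skips the resizing, cropping, and label file that dataset_tool.py handles, so dataset_tool.py remains the tool to use for real datasets.

```python
import io
import zipfile
from typing import BinaryIO, Iterable, Optional, Union


def pack_pngs(pngs: Iterable[bytes], dest: Union[str, BinaryIO],
              max_images: Optional[int] = None) -> int:
    """Write PNG-encoded images into an uncompressed ZIP and return the count."""
    count = 0
    # ZIP_STORED keeps entries uncompressed, matching the archive format above.
    with zipfile.ZipFile(dest, mode="w", compression=zipfile.ZIP_STORED) as zf:
        for idx, data in enumerate(pngs):
            if max_images is not None and idx >= max_images:
                break
            idx_str = f"{idx:08d}"
            # Assumed layout: 00000/img00000000.png, 00000/img00000001.png, ...
            zf.writestr(f"{idx_str[:5]}/img{idx_str}.png", data)
            count += 1
    return count


if __name__ == "__main__":
    buffer = io.BytesIO()
    # Placeholder bytes stand in for real PNG data in this demo.
    written = pack_pngs([b"png-0", b"png-1", b"png-2"], buffer, max_images=2)
    print(written, zipfile.ZipFile(buffer).namelist())
```

To pack a folder, one could pass a generator such as (p.read_bytes() for p in sorted(Path(folder).glob("*.png"))) together with max_images=5000, using pathlib.Path.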
Training New Networks
To train a new model using the proposed APA, we recommend running the following command as a starting point to achieve desirable quality in most cases:
python train.py --outdir=./experiments --gpus=8 --data=/path/to/mydataset.zip \
--metricdata=/path/to/mydatasetfull.zip --mirror=1 \
--cfg=auto --aug=apa --with-dataaug=true
In this example, the results are saved to a newly created directory ./experiments/<ID>-mydataset-mirror-auto8-apa-wdataaug, controlled by --outdir. Here, auto8 indicates that the hyperparameters of the base configuration were selected automatically for training on 8 GPUs. The training exports network pickles (network-snapshot-<INT>.pkl) and example images (fakes<INT>.png) at regular intervals (controlled by --snap).
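To resume from or evaluate the most recent checkpoint, the exported pickles can be ordered by the integer in their file names. This sketch assumes the network-snapshot-<INT>.pkl pattern described above, where a larger <INT> means a later snapshot. `latest_snapshot` is a hypothetical helper that takes file names rather than reading the directory itself, which keeps it easy to test.

```python
import re
from typing import Iterable, Optional

# Assumed file-name pattern for exported pickles, e.g. network-snapshot-000400.pkl.
SNAPSHOT_RE = re.compile(r"network-snapshot-(\d+)\.pkl")


def latest_snapshot(names: Iterable[str]) -> Optional[str]:
    """Return the snapshot name with the largest <INT>, or None if there is none."""
    best_name, best_step = None, -1
    for name in names:
        m = SNAPSHOT_RE.fullmatch(name)
        if m and int(m.group(1)) > best_step:
            best_name, best_step = name, int(m.group(1))
    return best_name
```

For a run directory, this could be called as latest_snapshot(p.name for p in Path(run_dir).iterdir()) with pathlib.Path; example images such as fakes<INT>.png are ignored because they do not match the pattern.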
For each pickle, the training also evaluates FID (controlled by --metrics) and logs the resulting scores in metric-fid50k_full.jsonl (as well as TFEvents if TensorBoard is installed). Following stylegan2-ada, note that when training with artificially limited/amplified datasets, the quality metrics (e.g., fid50k_full) should still be evaluated against the corresponding original full datasets. We add this missing feature to stylegan2-ada-pytorch with a --metricdata argument that specifies a separate metric dataset, which can differ from the training dataset (specified by --data).
This command does not necessarily yield optimal results; the training can be further customized with additional command line options:
- --cfg (Default: auto) can be changed to other training configurations, e.g., paper256 for the 256x256 resolution and stylegan2 for the 1024x1024 resolution.
- --aug (Default: apa) specifies the augmentation mode. It can be set to noaug to disable augmentation when data is sufficient, or to fixed to use a fixed deception probability (controlled by --p).
- --with-dataaug (Default: false) controls whether to apply standard data augmentations to the discriminator inputs. This option can be set to false to train a model with APA alone, which is also effective and adds negligible computational cost. Setting it to true (as in the command above) is sometimes more desirable, since APA is complementary to standard data augmentations, and combining the two is important for further boosting performance in most cases.
- --target (Default: 0.6) sets the threshold for the APA heuristic. Empirically, a smaller value can be chosen when less data is available. A larger value (i.e., --target=0.8) is used for the Anime dataset.
- --gamma (Default: depends on --cfg) overrides the R1 gamma. Different values can be tried for a new dataset.
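The adaptive part of APA, which --target controls, can be sketched as a simple feedback loop: estimate how confidently D classifies real images with an overfitting heuristic, then nudge the deception probability p up when the heuristic exceeds the target and down when it falls below. This sketch assumes an ADA-style heuristic λ_r = E[sign(D(x_real))] computed on raw discriminator logits and a fixed, hypothetical step size. `overfitting_heuristic` and `update_p` are illustrative names; see the paper and train.py for the exact heuristic and update schedule.

```python
from typing import Sequence


def overfitting_heuristic(real_logits: Sequence[float]) -> float:
    """lambda_r = E[sign(D(x_real))]; 1.0 means D is confident on every real image."""
    if not real_logits:
        raise ValueError("need at least one logit")
    signs = [(x > 0) - (x < 0) for x in real_logits]
    return sum(signs) / len(signs)


def update_p(p: float, lambda_r: float, target: float = 0.6,
             step: float = 0.01) -> float:
    """Move p one step toward more deception if D looks overfitted, else less.

    The step size is a hypothetical constant; p is clipped to [0, 1].
    """
    if lambda_r > target:
        direction = 1.0
    elif lambda_r < target:
        direction = -1.0
    else:
        direction = 0.0
    return min(max(p + direction * step, 0.0), 1.0)


if __name__ == "__main__":
    p = 0.0
    for logits in ([2.0, 1.5, 0.3, 0.9], [0.4, -0.2, 1.1, -0.7]):
        lam = overfitting_heuristic(logits)
        p = update_p(p, lam)
        print(f"lambda_r={lam:.2f} -> p={p:.2f}")
```

Under this reading, a smaller --target makes p react earlier to signs of overfitting, which is consistent with the advice above to lower it when less data is available.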
Please refer to python train.py --help and the "Training new networks" section of stylegan2-ada-pytorch for other options.
Evaluation Metrics
By default, train.py automatically computes FID for each network pickle exported during training. We recommend inspecting metric-fid50k_full.jsonl (or TensorBoard) at regular intervals to monitor the training progress.
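A minimal way to monitor progress from the JSONL log is to pick the snapshot with the lowest FID. This sketch assumes each line is a JSON object with a "results" dictionary containing fid50k_full and a "snapshot_pkl" field, which is the format we believe stylegan2-ada-pytorch writes; adjust the keys if your log differs. `best_fid` is a hypothetical helper.

```python
import json
from typing import Iterable, Optional, Tuple


def best_fid(lines: Iterable[str],
             metric: str = "fid50k_full") -> Optional[Tuple[str, float]]:
    """Return (snapshot_pkl, score) with the lowest score, or None if no scores."""
    best = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        value = record.get("results", {}).get(metric)
        if value is None:
            continue  # skip records for other metrics
        if best is None or value < best[1]:
            best = (record.get("snapshot_pkl", "?"), float(value))
    return best


if __name__ == "__main__":
    import sys
    # Usage: python best_fid.py path/to/metric-fid50k_full.jsonl
    with open(sys.argv[1]) as f:
        print(best_fid(f))
```

Lower FID is better, so the returned pickle is a reasonable candidate for the evaluation below, keeping in mind that FID can vary slightly between runs.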
The metrics can also be computed after the training:
python calc_metrics.py --network=/path/to/checkpoint/pkl --gpus=8 \
--metrics=fid50k_full,is50k --metricdata=/path/to/mydatasetfull.zip --mirror=1
The command above calculates the fid50k_full and is50k metrics for a specified checkpoint pickle file (run python calc_metrics.py --help and refer to the "Quality metrics" section of stylegan2-ada-pytorch for more information). As during training, the metrics should be evaluated against the corresponding original full dataset.
Some metrics may incur a high one-off cost the first time they are calculated on a new dataset. Also note that the evaluation uses a different random seed each time, so the results may vary slightly if the same metric is computed multiple times.
Additional Features
Please refer to stylegan2-ada-pytorch for other usage and statistics of the codebase. Unlike ADA, which spends additional time applying external augmentations, applying APA alone adds negligible training cost (see our paper for details).
Results
Effectiveness on Various Datasets
Effectiveness Given Different Data Amounts
Overfitting and Convergence Analysis
Comparison with Other State-of-the-Art Solutions
Higher-Resolution Examples (1024 × 1024) on FFHQ-5k (~7% data)
Citation
If you find this work useful for your research, please cite our paper:
@inproceedings{jiang2021DeceiveD,
title={{Deceive D: Adaptive Pseudo Augmentation} for {GAN} Training with Limited Data},
author={Jiang, Liming and Dai, Bo and Wu, Wayne and Loy, Chen Change},
booktitle={NeurIPS},
year={2021}
}
Acknowledgments
The code is developed based on stylegan2-ada-pytorch. We appreciate the nice PyTorch implementation.
License
Copyright (c) 2021. All rights reserved.
The code is released under the NVIDIA Source Code License.