EqGAN-SA: Improving GAN Equilibrium by Raising Spatial Awareness
Jianyuan Wang, Ceyuan Yang, Yinghao Xu, Yujun Shen, Hongdong Li, Bolei Zhou
CVPR 2022
[Paper] [Project Page] [Demo]
In Generative Adversarial Networks (GANs), a generator (G) and a discriminator (D) are expected to reach an equilibrium where D cannot distinguish the generated images from the real ones. In practice, however, such an equilibrium is difficult to achieve in GAN training; instead, D almost always surpasses G. We attribute this phenomenon to an information asymmetry: D learns its own visual attention when determining whether an image is real or fake, but G has no explicit clue about which regions to focus on.
To alleviate the issue of D dominating the competition in GANs, we aim to raise the spatial awareness of G. We encode randomly sampled multi-level heatmaps into the intermediate layers of G as an inductive bias. We further propose to align the spatial awareness of G with the attention map induced from D. In this way, we effectively lessen the information gap between D and G. Extensive results show that our method pushes the two-player game in GANs closer to the equilibrium, leading to better synthesis performance. As a byproduct, the introduced spatial awareness facilitates interactive editing of the synthesized output.
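To make the idea of "randomly sampled multi-level heatmaps" concrete, here is a minimal sketch in plain Python. It is not the repository's implementation. The Gaussian shape, the number of centers, the choice of sharing centers across levels, the `sigma = res / 4` rule, and the function names are all assumptions made for illustration.

```python
import math
import random


def gaussian_heatmap(size, centers, sigma):
    """Return a size x size grid; each cell is the strongest Gaussian response over all centers."""
    heatmap = []
    for y in range(size):
        row = []
        for x in range(size):
            value = max(
                math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2.0 * sigma ** 2))
                for cx, cy in centers
            )
            row.append(value)
        heatmap.append(row)
    return heatmap


def sample_multilevel_heatmaps(resolutions=(4, 8, 16, 32), num_centers=2, seed=None):
    """Sample one heatmap per resolution (hypothetical helper, not part of the codebase)."""
    rng = random.Random(seed)
    # Assumption: centers are drawn once in normalized [0, 1) coordinates and
    # reused at every level, so the levels describe the same spatial layout.
    centers = [(rng.random(), rng.random()) for _ in range(num_centers)]
    heatmaps = {}
    for res in resolutions:
        scaled = [(cx * (res - 1), cy * (res - 1)) for cx, cy in centers]
        # Assumption: the Gaussian width grows with the resolution.
        heatmaps[res] = gaussian_heatmap(res, scaled, sigma=res / 4.0)
    return heatmaps


if __name__ == "__main__":
    maps = sample_multilevel_heatmaps(seed=0)
    for res, grid in maps.items():
        print(res, "x", res, "peak value:", round(max(max(row) for row in grid), 3))
```

In a generator, each such map would be injected at the intermediate layer with the matching feature resolution; how the injection is done is left out here because the text above does not specify it.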
Environment
This work was developed on the styleGAN2-ada-pytorch codebase. Please follow its requirements, listed below:
- Linux and Windows are supported, but Linux is recommended for performance and compatibility reasons.
- The original codebase used CUDA toolkit 11.0 and PyTorch 1.7.1. Our experiments were conducted with CUDA toolkit 9.0 and PyTorch 1.8.1. Both settings are acceptable, but you may observe a performance difference. Please also install torchvision along with PyTorch.
- Python libraries: pip install click requests tqdm pyspng ninja psutil scipy imageio-ffmpeg==0.4.3
The code relies heavily on custom PyTorch extensions that are compiled on the fly using NVCC. On Windows, the compilation requires Microsoft Visual Studio. We recommend installing Visual Studio Community Edition and adding it to PATH using "C:\Program Files (x86)\Microsoft Visual Studio\<VERSION>\Community\VC\Auxiliary\Build\vcvars64.bat".
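Before the first training run, it can help to confirm that the requirements above are in place. The following is a small sketch using only the Python standard library. The helper names are hypothetical, and the import names in `REQUIRED_MODULES` are assumed to match the pip packages listed above (for example, `imageio-ffmpeg` is assumed to import as `imageio_ffmpeg`).

```python
import importlib.util
import shutil

# Assumed import names for the packages listed in the requirements.
REQUIRED_MODULES = [
    "torch", "torchvision", "click", "requests", "tqdm",
    "pyspng", "ninja", "psutil", "scipy", "imageio_ffmpeg",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def nvcc_available():
    """Return True if an nvcc executable is on PATH (needed for the custom extensions)."""
    return shutil.which("nvcc") is not None


if __name__ == "__main__":
    missing = missing_modules()
    print("Missing Python modules:", ", ".join(missing) if missing else "none")
    print("nvcc on PATH:", nvcc_available())
```

This only checks that the modules and the compiler can be located; it does not verify versions or that the extensions actually compile.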
Dataset Preparation
Please refer to the original page for details on data processing.
All datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file, dataset.json, for labels. Please see dataset_tool.py for more information. Alternatively, a folder of images can be used directly as a dataset without first running it through dataset_tool.py, but doing so may lead to suboptimal performance.
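The archive layout described above can be sketched as follows. This is an illustration rather than a replacement for dataset_tool.py: the helper name is hypothetical, the image bytes are placeholders instead of real PNG data, and the `{"labels": [[name, label], ...]}` structure of dataset.json is an assumption about the format the upstream tool writes.

```python
import io
import json
import zipfile


def write_dataset_zip(dest, images, labels=None):
    """Write images (archive name -> PNG bytes) and a dataset.json into an uncompressed ZIP.

    labels is an optional dict mapping archive name -> integer class label.
    """
    # ZIP_STORED keeps every member uncompressed, matching the description above.
    with zipfile.ZipFile(dest, mode="w", compression=zipfile.ZIP_STORED) as archive:
        names = sorted(images)
        for name in names:
            archive.writestr(name, images[name])
        # Assumed metadata layout: a list of [name, label] pairs, or null when unlabeled.
        entries = [[name, labels[name]] for name in names] if labels else None
        archive.writestr("dataset.json", json.dumps({"labels": entries}))


if __name__ == "__main__":
    buffer = io.BytesIO()
    # Placeholder bytes stand in for real PNG data in this sketch.
    write_dataset_zip(buffer, {"00000/img00000000.png": b"placeholder"})
    with zipfile.ZipFile(buffer) as archive:
        print(archive.namelist())
```

For real data, dataset_tool.py also handles resizing and cropping, so it remains the recommended route.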
FFHQ:
Step 1: Download the Flickr-Faces-HQ dataset as TFRecords.
Step 2: Extract images from TFRecords using dataset_tool.py from the TensorFlow version of StyleGAN2-ADA:
# Using dataset_tool.py from TensorFlow version at
# https://github.com/NVlabs/stylegan2-ada/
python ../stylegan2-ada/dataset_tool.py unpack \
--tfrecord_dir=~/ffhq-dataset/tfrecords/ffhq --output_dir=/tmp/ffhq-unpacked
Step 3: Create a ZIP archive using dataset_tool.py from this repository:
# Scaled down 256x256 resolution.
python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq256x256.zip \
--width=256 --height=256
LSUN: Download the desired categories from the LSUN project page and convert to ZIP archive:
python dataset_tool.py --source=~/downloads/lsun/raw/cat_lmdb --dest=~/datasets/lsuncat200k.zip \
--transform=center-crop --width=256 --height=256 --max_images=200000
Training
Taking the LSUN Cat dataset as an example:
python ./train.py --outdir=/runs --data=/data/lsuncat200k.zip --gpus=8 --cfg=paper256 \
--aug=noaug --pl_w=0 --close_style_mixing=True \
--use_sel=True --align_loss=True
The flag --use_sel controls whether the spatial encoding layer is used, while --align_loss controls whether the alignment loss is applied.
You may replace the --data path with the path of another dataset. We set --aug to noaug to disable the ADA augmentation, i.e., to switch from StyleGAN2-ADA to StyleGAN2. We disable path length regularization (--pl_w=0) and style mixing (--close_style_mixing=True) because they have little effect on our method.
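To compare the two components, one might want to launch the same configuration with --use_sel and --align_loss toggled. The sketch below only assembles the command lines from the flags shown above. The helper name is hypothetical, and it assumes the boolean flags accept False in the same way they accept True.

```python
import itertools
import shlex


def build_train_command(data, outdir, gpus=8, use_sel=True, align_loss=True):
    """Return the train.py invocation as an argument list (hypothetical helper)."""
    return [
        "python", "./train.py",
        f"--outdir={outdir}",
        f"--data={data}",
        f"--gpus={gpus}",
        "--cfg=paper256",
        "--aug=noaug",                  # disable ADA augmentation
        "--pl_w=0",                     # disable path length regularization
        "--close_style_mixing=True",    # disable style mixing
        f"--use_sel={use_sel}",         # spatial encoding layer on/off
        f"--align_loss={align_loss}",   # alignment loss on/off
    ]


if __name__ == "__main__":
    # Print one command per combination of the two flags.
    for use_sel, align_loss in itertools.product([True, False], repeat=2):
        command = build_train_command(
            "/data/lsuncat200k.zip", "/runs", use_sel=use_sel, align_loss=align_loss
        )
        print(shlex.join(command))
```

The returned list can be passed to `subprocess.run` directly, which avoids shell quoting issues with dataset paths.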
Evaluation
During training, train.py automatically computes FID for each network pickle. To measure the synthesis quality of a pretrained model, you can specify the metric, data path, network pkl, and other settings for calc_metrics.py, like:
python calc_metrics.py --metrics=fid50k_full --data=data/lsuncat200k.zip --network=ckpt/cat.pkl
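For readers unfamiliar with the metric, the standard definition of the Fréchet Inception Distance is reproduced here for reference; it is not taken from this repository. The symbols are notation introduced for this note: $\mu_r, \Sigma_r$ and $\mu_g, \Sigma_g$ denote the mean and covariance of Inception features computed on real and generated images, respectively.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the real one. In the upstream codebase, fid50k_full compares 50k generated images against the full dataset.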
You can also generate some samples by:
python generate.py --outdir=out --trunc=1 --seeds=85,265,297 --network=ckpt/cat.pkl
Pretrained Models
The models for the LSUN Cat, LSUN Bedroom, and FFHQ datasets are available (25M training iterations). The LSUN Bedroom model was not discussed in the paper; we provide it here to show that our idea also works for indoor scenes.
The released code differs slightly from the version used at submission. For example, the provided LSUN Cat model achieves a slightly better FID than the result reported in the paper (6.62 vs. 6.81).
Model | FID | Link |
---|---|---|
LSUN Cat | 6.62 | link |
LSUN Bedroom | 2.95 | link |
FFHQ | 2.89 | link |
LSUN Church | 3.02 | link |
TODO
- [x] Training Code
- [x] Training Script
- [x] Check the Code
- [x] Pretrained Model
- [ ] User Interface
Acknowledgement
Thanks to Janne Hellsten and Tero Karras for their excellent codebase, styleGAN2-ada-pytorch.
BibTeX
@InProceedings{Wang_2022_CVPR,
author = {Wang, Jianyuan and Yang, Ceyuan and Xu, Yinghao and Shen, Yujun and Li, Hongdong and Zhou, Bolei},
title = {Improving GAN Equilibrium by Raising Spatial Awareness},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {11285-11293}
}