[ICML'23] StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis

[Project] [PDF] [Video]

This repository contains the training code for our paper "StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis". We do not provide pretrained checkpoints.

by Axel Sauer, Tero Karras, Samuli Laine, Andreas Geiger, Timo Aila

Requirements

  • Use the following commands with Miniconda3 to create and activate your environment:
    conda create --name sgt python=3.9
    conda activate sgt
    conda install pytorch=1.9.1 torchvision==0.10.1 pytorch-cuda=11.6 -c pytorch -c nvidia
    pip install -r requirements.txt
    
  • GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see, for example, the CUDA 11.4 system requirements.
  • If you run into problems when setting up the custom CUDA kernels, refer to the Troubleshooting docs of the StyleGAN3 repo.

Data Preparation

StyleGAN-T can be trained on unconditional and conditional datasets. For small-scale experiments, we recommend zip datasets. When training on datasets with more than 1 million images, we recommend using webdatasets.

Zip Dataset

Zip datasets are stored in the same format as in the previous iterations of StyleGAN: uncompressed ZIP archives containing uncompressed PNG files and a metadata file dataset.json for labels. Custom datasets can be created from a folder containing images.
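
As a rough illustration of that layout, the sketch below packs already-encoded PNG bytes into a stored (uncompressed) ZIP together with a dataset.json. The function name write_zip_dataset and the {"labels": [[filename, label], ...]} schema are assumptions based on the format used by earlier StyleGAN releases, not something this README specifies; dataset_tool.py remains the supported way to build datasets.

```python
import json
import zipfile


def write_zip_dataset(dest, images, labels=None):
    """Write PNG bytes and an optional dataset.json into an uncompressed ZIP.

    images: dict mapping archive file names to already-encoded PNG bytes.
    labels: optional dict mapping the same file names to a label
            (assumed schema: {"labels": [[fname, label], ...]}).
    """
    # ZIP_STORED keeps every archive member uncompressed.
    with zipfile.ZipFile(dest, mode="w", compression=zipfile.ZIP_STORED) as zf:
        for fname in sorted(images):
            zf.writestr(fname, images[fname])
        if labels is not None:
            pairs = [[fname, labels[fname]] for fname in sorted(images)]
            zf.writestr("dataset.json", json.dumps({"labels": pairs}))
```

Omitting labels in this sketch leaves out dataset.json entirely, which would correspond to an unconditional dataset.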

CIFAR-10: Download the CIFAR-10 Python version and convert it to a ZIP archive:

python dataset_tool.py --source downloads/cifar10/cifar-10-python.tar.gz \
    --dest data/cifar10-32x32.zip

FFHQ: Download the Flickr-Faces-HQ dataset as 1024x1024 images and convert to ZIP archive at the same resolution:

python dataset_tool.py --source downloads/ffhq/images1024x1024 \
    --dest data/ffhq1024.zip --resolution 1024x1024

COCO validation set: The COCO validation set is used for tracking zero-shot FID and CLIP score. First, download the COCO metadata. Then, run:

python dataset_tool.py --source downloads/captions_val2014.json \
  --dest data/coco_val256.zip --resolution 256x256 --transform center-crop
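
The README does not spell out what --transform center-crop does. A common reading, assumed in this sketch, is that the largest centered square is cut out of each image before it is resized to the target resolution. The helper name center_crop_box is hypothetical and only illustrates the arithmetic.

```python
def center_crop_box(width, height):
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side
```

For a 640x480 image this gives a 480x480 box starting 80 pixels from the left edge, which would then be resized to 256x256.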

It is recommended to prepare the dataset zip at the highest possible resolution; for FFHQ, for example, the zip should contain 1024x1024 images. When training lower-resolution models, the training script can downsample the images on the fly.

WebDataset

For preparing webdatasets, we used the excellent img2dataset tool. Documentation for downloading different datasets can be found here. For our experiments, we used data from the following sources: CC3M, CC12M, YFCC100M, RedCaps, LAION-aesthetic-6plus.

The joint dataset should have the following structure:

joint_dataset/cc3m/0000.tar
joint_dataset/cc3m/0001.tar
...
joint_dataset/laion_6plus/0000.tar
joint_dataset/laion_6plus/0001.tar
...
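
Before launching a long run, it can help to check that the folder really matches this layout. The following is a minimal sketch using only the standard library; list_shards is a hypothetical helper, not part of the repository, and it assumes every subset is a direct subdirectory holding .tar shards.

```python
from pathlib import Path


def list_shards(joint_root):
    """Return {subset_name: [sorted shard paths]} for a joint dataset folder."""
    shards = {}
    for subset in sorted(Path(joint_root).iterdir()):
        if subset.is_dir():
            tars = sorted(subset.glob("*.tar"))
            if tars:  # skip subdirectories without any shards
                shards[subset.name] = tars
    return shards
```

Printing len(paths) per subset gives a quick count of shards from each source.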

Training

Training StyleGAN-T with full capacity on MYDATASET.zip at a resolution of 64x64 pixels:

python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \
  --outdir ./training-runs/ --cfg full --data ./data/MYDATASET.zip \
  --img-resolution 64 --batch 128 --batch-gpu 8 --kimg 25000 --metrics fid50k_full
  • The above command can be parallelized across multiple GPUs by adjusting --nproc_per_node.
  • --batch specifies the overall batch size, --batch-gpu specifies the batch size per GPU. Be aware that --batch-gpu is also a hyperparameter, as the discriminator uses (local) BatchNorm; we generally recommend a --batch-gpu of 4 or 8. If you use fewer GPUs, the training loop will automatically accumulate gradients until the overall batch size is reached.
  • Samples and metrics are saved in outdir. You can inspect METRIC_NAME.json or run tensorboard in training-runs/ to monitor the training progress.
  • The generator will be conditional if the dataset contains text labels; otherwise, it will be unconditional.
  • For a webdataset composed of different subsets, the data path should point to the joint parent directory: --data path/to/joint_dataset/
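
The relationship between --batch, --batch-gpu, and the number of GPUs can be sketched as simple arithmetic. The function below is illustrative only; its name and the divisibility check are assumptions, not taken from train.py.

```python
def accumulation_rounds(batch, batch_gpu, num_gpus):
    """Number of forward/backward passes each GPU runs per optimizer step."""
    per_pass = batch_gpu * num_gpus  # images processed in one pass across all GPUs
    if batch % per_pass != 0:
        # Assumed constraint: the overall batch must split evenly.
        raise ValueError("batch must be a multiple of batch_gpu * num_gpus")
    return batch // per_pass


for gpus in (1, 2, 8):
    print(gpus, accumulation_rounds(batch=128, batch_gpu=8, num_gpus=gpus))
```

With --batch 128 and --batch-gpu 8, a single GPU accumulates over 16 passes, two GPUs over 8, and eight GPUs over 2.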

To use the same configuration we used for our ablation study, use --cfg lite. If you want direct control over network parameters, use a custom config. For example, to train a smaller model with 1 residual block, a capacity multiplier of 16384, and a maximum channel count of 256, run:

python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \
  --outdir ./training-runs/ --data ./data/MYDATASET.zip \
  --img-resolution 64 --batch 128 --batch-gpu 8 --kimg 25000 --metrics fid50k_full \
  --cfg custom --cbase 16384 --cmax 256 --res-blocks 1

For a description of all input arguments, run python train.py --help

Starting from pretrained checkpoints

If you want to use a previously trained model, you can start from a checkpoint by adding --resume PATH_TO_NETWORK_PKL. If you want to continue training from where you left off in a previous run, you can also specify the number of images processed in that run using --resume-kimg XXX, where XXX is that number.

Training modes

By default, all layers of the generator are trained and the CLIP text encoder is frozen. If you want to train only the text encoder, provide --train-mode text-encoder.

If you want to do progressive growing, first train a model at 64x64 pixels. Then provide the path to this pretrained network via --resume, set the new target resolution via --img-resolution, and use --train-mode freeze64 to freeze the blocks of the 64x64 model and train only the high-resolution layers. For example:

python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \
  --outdir ./training-runs/ --data ./data/MYDATASET.zip \
  --img-resolution 512 --batch 128 --batch-gpu 4 --kimg 25000 \
  --cfg lite --resume PATH_TO_NETWORK_64 --train-mode freeze64
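
When both stages are scripted, the two invocations can be assembled programmatically. This is a minimal sketch that uses only flags shown in this README; the train_command helper and its defaults are hypothetical, and the lists are built but not executed.

```python
import sys


def train_command(data, resolution, batch_gpu, cfg, resume=None, train_mode=None,
                  outdir="./training-runs/", batch=128, kimg=25000, nproc=1):
    """Assemble the argv list for one training stage."""
    cmd = [sys.executable, "-m", "torch.distributed.run", "--standalone",
           "--nproc_per_node", str(nproc), "train.py",
           "--outdir", outdir, "--data", data, "--cfg", cfg,
           "--img-resolution", str(resolution),
           "--batch", str(batch), "--batch-gpu", str(batch_gpu), "--kimg", str(kimg)]
    if resume is not None:
        cmd += ["--resume", resume]
    if train_mode is not None:
        cmd += ["--train-mode", train_mode]
    return cmd


# Stage 1: train the 64x64 base model.
stage1 = train_command("./data/MYDATASET.zip", 64, batch_gpu=8, cfg="lite")
# Stage 2: freeze the 64x64 blocks and train the higher-resolution layers.
stage2 = train_command("./data/MYDATASET.zip", 512, batch_gpu=4, cfg="lite",
                       resume="PATH_TO_NETWORK_64", train_mode="freeze64")
```

Each list could then be passed to subprocess.run(cmd, check=True) from the repository root, with PATH_TO_NETWORK_64 replaced by the snapshot produced in the first stage.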

Generating Samples

To generate samples with a given network, run

python gen_images.py --network PATH_TO_NETWORK_PKL \
 --prompt 'A painting of a fox in the style of starry night.' \
 --truncation 1 --outdir out --seeds 0-29

For a description of all input arguments, run python gen_images.py --help
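
The --seeds 0-29 argument in the example above reads as an inclusive range, one image per seed. The sketch below shows one plausible way such a string could be expanded; parse_seeds is a hypothetical helper, and the support for comma-separated lists is an assumption rather than documented behavior of gen_images.py.

```python
import re


def parse_seeds(spec):
    """Expand a string such as '0-29' or '1,4-6' into a list of integers."""
    seeds = []
    for part in spec.split(","):
        match = re.fullmatch(r"(\d+)-(\d+)", part.strip())
        if match:
            lo, hi = int(match.group(1)), int(match.group(2))
            seeds.extend(range(lo, hi + 1))  # ranges are treated as inclusive
        else:
            seeds.append(int(part))
    return seeds
```

Under this reading, '0-29' expands to the 30 seeds 0 through 29.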

Quality Metrics

To calculate metrics for a specific network snapshot, run

python calc_metrics.py --metrics METRIC_NAME --network PATH_TO_NETWORK_PKL

Metric computation is only supported on zip datasets, not webdatasets. The zero-shot COCO metrics expect a coco_val256.zip to be present in the same folder as the training dataset. Alternatively, you can explicitly set an environment variable as follows: export COCOPATH=path/to/coco_val256.zip.
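
The lookup described above can be sketched as follows. The helper name resolve_coco_zip is hypothetical, and the precedence (an explicitly set COCOPATH wins over the default location) is an assumption; the README does not say which one takes priority.

```python
import os
from pathlib import Path


def resolve_coco_zip(train_data_path, env=None):
    """Guess where coco_val256.zip lives for the zero-shot COCO metrics."""
    env = os.environ if env is None else env
    # Assumption: an explicitly set COCOPATH overrides the default location.
    if env.get("COCOPATH"):
        return Path(env["COCOPATH"])
    # Default from the text: same folder as the training dataset.
    return Path(train_data_path).parent / "coco_val256.zip"
```

Checking resolve_coco_zip("./data/MYDATASET.zip").exists() before training would catch a missing validation zip early.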

To see the available metrics, run python calc_metrics.py --help

License

Copyright © 2023, NVIDIA Corporation. All rights reserved.

This work is made available under the Nvidia Source Code License.

Exempt are the files training/diffaug.py and networks/vit_utils.py, which are partially or fully based on third-party GitHub repositories. These two files are copyright their respective authors and under their respective licenses; we include the original license and a link to the source at the beginning of each file.

Development

This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.

Citation

@InProceedings{Sauer2023ARXIV,
  author    = {Axel Sauer and Tero Karras and Samuli Laine and Andreas Geiger and Timo Aila},
  title     = {{StyleGAN-T}: Unlocking the Power of {GANs} for Fast Large-Scale Text-to-Image Synthesis},
  journal   = {{arXiv.org}},
  volume    = {abs/2301.09515},
  year      = {2023},
  url       = {https://arxiv.org/abs/2301.09515},
}
