  • Stars: 135
  • Rank 269,297 (Top 6 %)
  • Language: Python
  • License: MIT License
  • Created over 3 years ago
  • Updated almost 2 years ago

Repository Details

PyTorch implementation of VQ-VAE-2 from "Generating Diverse High-Fidelity Images with VQ-VAE-2"

Generating Diverse High-Fidelity Images with VQ-VAE-2 [Work in Progress]

PyTorch implementation of Hierarchical, Vector Quantized, Variational Autoencoders (VQ-VAE-2) from the paper "Generating Diverse High-Fidelity Images with VQ-VAE-2"

The original paper is available on arXiv (1906.00446).

The vector quantizing layer is based on the implementation by @rosinality.
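
For readers unfamiliar with the quantization step, the following is a minimal sketch of a vector quantizing layer, assuming a plain learnable codebook with nearest-neighbour lookup and a straight-through gradient. It is not the code used in this repository or in @rosinality's implementation, which may differ (for example in how the codebook is updated). The class name SimpleVectorQuantizer, its parameters, and the commitment weight beta are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleVectorQuantizer(nn.Module):
    """Illustrative nearest-neighbour quantizer (hypothetical, not the repo's class)."""

    def __init__(self, num_codes: int, code_dim: int, beta: float = 0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, code_dim)
        self.beta = beta  # assumed weight of the commitment loss

    def forward(self, z: torch.Tensor):
        # z: (batch, height, width, code_dim), i.e. channels last
        flat = z.reshape(-1, z.shape[-1])
        codes = self.codebook.weight
        # Squared Euclidean distance from every encoder output to every code
        dist = (
            flat.pow(2).sum(dim=1, keepdim=True)
            - 2 * flat @ codes.t()
            + codes.pow(2).sum(dim=1)
        )
        idx = dist.argmin(dim=1)
        q = self.codebook(idx).view_as(z)
        # Codebook loss moves codes toward encoder outputs;
        # commitment loss moves encoder outputs toward their codes
        loss = F.mse_loss(q, z.detach()) + self.beta * F.mse_loss(z, q.detach())
        # Straight-through estimator: forward pass uses q,
        # backward pass copies gradients straight to z
        q = z + (q - z).detach()
        return q, idx.view(z.shape[:-1]), loss
```

Calling the module on a tensor of shape (batch, height, width, code_dim) returns the quantized tensor, the integer code indices for each spatial position, and a scalar loss to add to the reconstruction loss.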

The focus is on supporting an arbitrary number of VQ-VAE "levels". Most PyTorch implementations use only 2, which is limiting at higher resolutions. This repository contains checkpoints for a 3-level and a 5-level VQ-VAE-2, both trained on FFHQ1024.

This project will not only contain the VQ-VAE-2 architecture, but also an example autoregressive prior and latent dataset extraction.

This project is very much a work in progress. The VQ-VAE-2 model is mostly complete. The PixelSnail prior models are still experimental and currently do not work.

Usage

VQ-VAE-2 Usage

Run VQ-VAE-2 training using the config task_name found in hps.py. Defaults to cifar10:

python main-vqvae.py --task task_name

Evaluate a VQ-VAE-2 loaded from the parameters at state_dict_path on task task_name. Defaults to cifar10:

python main-vqvae.py --task task_name --load-path state_dict_path --evaluate

Other useful flags:

--no-save       # disables saving of files during training
--cpu           # do not use GPU
--batch-size    # overrides batch size in cfg.py, useful for evaluating on larger batch size
--no-tqdm       # disable tqdm status bars
--no-amp        # disables using native AMP (Automatic Mixed Precision) operations
--save-jpg      # save all images as jpg instead of png, useful for extreme resolutions
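
As an illustration, flags of this kind could be declared with argparse roughly as follows. This is a sketch under assumptions, not the parser in main-vqvae.py: the function name build_parser, the types, and the defaults (other than the documented cifar10 task default) are guesses based on the flag descriptions above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented flags; the real script may differ.
    parser = argparse.ArgumentParser(description="VQ-VAE-2 training / evaluation (sketch)")
    parser.add_argument("--task", type=str, default="cifar10")
    parser.add_argument("--load-path", type=str, default=None)
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    # None here is assumed to mean "use the batch size from the config"
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--no-tqdm", action="store_true")
    parser.add_argument("--no-save", action="store_true")
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--save-jpg", action="store_true")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere
args = build_parser().parse_args(["--task", "cifar10", "--no-amp", "--batch-size", "8"])
print(args.task, args.no_amp, args.batch_size)  # prints: cifar10 True 8
```

argparse converts dashes in flag names to underscores, so --batch-size is read back as args.batch_size.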

Latent Dataset Generation

Generate a latent dataset using the VQ-VAE-2 checkpoint saved at path, which was trained on task task_name. Defaults to cifar10:

python main-latents.py path --task task_name

The result is saved in the latent-data directory.

Other useful flags:

--cpu           # do not use GPU
--batch-size    # overrides batch size in cfg.py, useful for evaluating on larger batch size
--no-tqdm       # disable tqdm status bars
--no-save       # disables saving of files
--no-amp        # disables using native AMP (Automatic Mixed Precision) operations

Discrete Prior Usage

Train the PixelSnail discrete prior for the level given by level, using the config task_name found in hps.py, the latent dataset saved at latent_dataset.pt, and the VQ-VAE at vqvae_path to dequantize conditioning variables. Defaults to cifar10:

python main-pixelsnail.py latent_dataset.pt vqvae_path.pt level --task task_name

Other useful flags:

--cpu           # do not use GPU
--load-path     # resume from saved state on disk
--batch-size    # overrides batch size in cfg.py, useful for evaluating on larger batch size
--save-jpg      # save all images as jpg instead of png, useful for extreme resolutions
--no-tqdm       # disable tqdm status bars
--no-save       # disables saving of files

Sample Generation

Run the sampling script on a trained VQ-VAE-2 and its PixelSnail priors using the config task_name (default cifar10) found in hps.py. The first positional argument is the path to the VQ-VAE-2 checkpoint. The remaining L positional arguments are the PixelSnail prior checkpoints, one per level, from level 0 to level L-1.

python main-sample.py vqvae_path.pt pixelsnail_0_path.pt pixelsnail_1_path.pt ... --task task_name

Other useful flags:

--cpu           # do not use GPU
--batch-size    # overrides batch size in cfg.py, useful for evaluating on larger batch size
--nb-samples    # number of samples to generate. defaults to 1.
--no-tqdm       # disable tqdm status bars
--no-save       # disables saving of files
--no-amp        # disables using native AMP (Automatic Mixed Precision) operations
--save-jpg      # save all images as jpg instead of png, useful for extreme resolutions
--temperature   # controls softmax temperature during sampling
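
The effect of --temperature can be sketched as follows, assuming the prior outputs unnormalised logits over codebook entries and that the flag divides those logits before the softmax. The helper sample_with_temperature is hypothetical and is not taken from main-sample.py.

```python
import torch


def sample_with_temperature(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Draw one index per row from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(logits / temperature, dim=-1)
    # torch.multinomial accepts a 1D or 2D tensor of probabilities;
    # here each row is one categorical distribution
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


logits = torch.tensor([[2.0, 1.0, 0.1]])
sharp = torch.softmax(logits / 0.5, dim=-1)  # lower temperature: more peaked
flat = torch.softmax(logits / 2.0, dim=-1)   # higher temperature: closer to uniform
sample = sample_with_temperature(logits, temperature=0.5)
```

Temperatures below 1 concentrate probability on the most likely codes, which tends to trade sample diversity for fidelity; temperatures above 1 do the opposite.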

Modifications

  • Replacing residual layers with ReZero layers.
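
ReZero replaces the residual update x + F(x) with x + alpha * F(x), where alpha is a learnable scalar initialised to zero, so every block starts out as the identity. A minimal sketch, assuming a small convolutional residual branch; the class ReZeroBlock and its layer choices are illustrative and not copied from this repository.

```python
import torch
import torch.nn as nn


class ReZeroBlock(nn.Module):
    """Illustrative ReZero residual block (hypothetical layer layout)."""

    def __init__(self, channels: int):
        super().__init__()
        # Residual branch F(x); the exact layers are an assumption
        self.branch = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=1),
        )
        # Learnable gate, initialised to zero so the block is the identity at the start
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.alpha * self.branch(x)
```

Because alpha starts at zero, a freshly constructed block returns its input unchanged, and the residual branch is blended in only as alpha is learned.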

Samples

Reconstructions from FFHQ1024 using a 3-level VQ-VAE-2

Checkpoints

FFHQ1024 - 3-level VQ-VAE-2

FFHQ1024 - 5-level VQ-VAE-2

Roadmap

  • Server mode (no fancy printing)
  • Experiment directories (containing logs / checkpoints / etc)
  • Accumulated gradient training (for larger batch sizes on limited resources)
  • Samples and checkpoints on FFHQ1024
  • Latent dataset generation
  • Autoregressive prior models / training scripts
  • Full system sampling
  • Prettier outputs
  • Output logging

Citations

Generating Diverse High-Fidelity Images with VQ-VAE-2

@misc{razavi2019generating,
      title={Generating Diverse High-Fidelity Images with VQ-VAE-2}, 
      author={Ali Razavi and Aaron van den Oord and Oriol Vinyals},
      year={2019},
      eprint={1906.00446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

PixelSNAIL: An Improved Autoregressive Generative Model

@misc{chen2017pixelsnail,
      title={PixelSNAIL: An Improved Autoregressive Generative Model}, 
      author={Xi Chen and Nikhil Mishra and Mostafa Rohaninejad and Pieter Abbeel},
      year={2017},
      eprint={1712.09763},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

ReZero is All You Need: Fast Convergence at Large Depth

@misc{bachlechner2020rezero,
      title={ReZero is All You Need: Fast Convergence at Large Depth}, 
      author={Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
      year={2020},
      eprint={2003.04887},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
