  • Stars: 1,134
  • Rank: 41,064 (top 0.9%)
  • Language: Python
  • License: MIT License
  • Created: over 8 years ago
  • Updated: 8 months ago


Repository Details


Variational Autoencoder in TensorFlow and PyTorch


Reference implementation for a variational autoencoder in TensorFlow and PyTorch.

I recommend the PyTorch version. It includes an example of a more expressive variational family, the inverse autoregressive flow.

Variational inference is used to fit the model to binarized MNIST handwritten digit images. An inference network (encoder) amortizes inference and shares parameters across datapoints. The likelihood is parameterized by a generative network (decoder).

Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
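The setup above (an encoder that outputs a mean-field Gaussian over latents, a decoder that outputs Bernoulli logits) can be sketched in a few lines of PyTorch. This is a minimal illustration with made-up layer sizes, not the repo's architecture:

```python
import torch
from torch import nn

class VAE(nn.Module):
    """Minimal mean-field VAE sketch (sizes are illustrative)."""
    def __init__(self, x_dim=784, z_dim=8):
        super().__init__()
        self.encoder = nn.Linear(x_dim, 2 * z_dim)  # outputs mean and log-variance
        self.decoder = nn.Linear(z_dim, x_dim)      # outputs Bernoulli logits

    def elbo(self, x):
        mu, log_var = self.encoder(x).chunk(2, dim=-1)
        # reparameterization trick: sample z differentiably
        z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()
        logits = self.decoder(z)
        log_px_z = -nn.functional.binary_cross_entropy_with_logits(
            logits, x, reduction="none").sum(-1)
        # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian
        kl = 0.5 * (mu ** 2 + log_var.exp() - log_var - 1).sum(-1)
        return log_px_z - kl

x = torch.bernoulli(torch.rand(16, 784))  # fake binarized batch
loss = -VAE().elbo(x).mean()              # maximize ELBO = minimize -ELBO
```

Training maximizes the per-example ELBO, a lower bound on log p(x); the real networks are deeper, but the objective is the same.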

PyTorch implementation

(The Anaconda environment is in environment-jax.yml.)

Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set was -97.10 nats, which is comparable to published numbers.

$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0          Train ELBO estimate: -558.027   Validation ELBO estimate: -384.432      Validation log p(x) estimate: -355.430  Speed: 2.72e+06 examples/s
Step 10000      Train ELBO estimate: -111.323   Validation ELBO estimate: -109.048      Validation log p(x) estimate: -103.746  Speed: 2.64e+04 examples/s
Step 20000      Train ELBO estimate: -103.013   Validation ELBO estimate: -107.655      Validation log p(x) estimate: -101.275  Speed: 2.63e+04 examples/s
Step 29999      Test ELBO estimate: -106.642    Test log p(x) estimate: -100.309
Total time: 2.49 minutes
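The log p(x) estimates above come from the standard importance-sampling estimator: draw K samples from q(z | x) and average the importance weights p(x, z) / q(z | x) in log space. A generic sketch (not the repo's exact function) looks like:

```python
import torch

def log_marginal_estimate(log_p_xz, log_qz_x):
    """Importance-sampled estimate of log p(x).

    log_p_xz, log_qz_x: [num_samples, batch_size] tensors holding
    log p(x, z_k) and log q(z_k | x) for samples z_k ~ q(z | x).
    """
    k = log_p_xz.shape[0]
    log_w = log_p_xz - log_qz_x  # log importance weights
    # log (1/K) sum_k w_k, computed stably with logsumexp
    return torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(k)))
```

If q were the exact posterior, every weight would equal p(x) and the estimate would be exact; in practice more samples tighten this lower bound on log p(x).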

Using a non-mean-field, more expressive variational posterior approximation (an inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to -95.33 nats:

$ python train_variational_autoencoder_pytorch.py --variational flow
step:   0       train elbo: -578.35
step:   0               valid elbo: -407.06     valid log p(x): -367.88
step:   10000   train elbo: -106.63
step:   10000           valid elbo: -110.12     valid log p(x): -104.00
step:   20000   train elbo: -101.51
step:   20000           valid elbo: -105.02     valid log p(x): -99.11
step:   30000   train elbo: -98.70
step:   30000           valid elbo: -103.76     valid log p(x): -97.71
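An inverse autoregressive flow step transforms a sample z into s * z + (1 - s) * m(z), where the shift m depends on z autoregressively, so the Jacobian is triangular and its log-determinant is cheap. A toy single-step sketch (the repo uses a MADE-style network; here a strictly lower-triangular masked linear layer plays that role, and the gate is a plain parameter):

```python
import torch
from torch import nn

class IAFStep(nn.Module):
    """One inverse autoregressive flow step: z' = s * z + (1 - s) * m(z)."""
    def __init__(self, z_dim):
        super().__init__()
        self.weight = nn.Parameter(0.01 * torch.randn(z_dim, z_dim))
        self.bias = nn.Parameter(torch.zeros(z_dim))
        self.gate = nn.Parameter(torch.ones(z_dim))  # pre-sigmoid gate logits
        # strictly lower-triangular mask enforces the autoregressive structure
        self.register_buffer("mask", torch.tril(torch.ones(z_dim, z_dim), -1))

    def forward(self, z):
        m = z @ (self.mask * self.weight).t() + self.bias  # shift uses only z_<i
        s = torch.sigmoid(self.gate)                       # elementwise scale in (0, 1)
        z_new = s * z + (1 - s) * m
        # triangular Jacobian: log|det| is just the sum of log scales
        log_det = torch.log(s).sum().expand(z.shape[0])
        return z_new, log_det

torch.manual_seed(0)
z = torch.randn(5, 4)
z_new, log_det = IAFStep(4)(z)
```

Stacking several such steps after the mean-field sample, and subtracting the accumulated log-determinants from log q(z | x), gives the more flexible posterior used in the flow runs above.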

JAX implementation

Using JAX (the Anaconda environment is in environment-jax.yml) gives a roughly 3x speedup over PyTorch:

$ python train_variational_autoencoder_jax.py --variational mean-field 
Step 0          Train ELBO estimate: -566.059   Validation ELBO estimate: -565.755      Validation log p(x) estimate: -557.914  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -98.560    Validation ELBO estimate: -105.725      Validation log p(x) estimate: -98.973   Speed: 7.03e+04 examples/s
Step 20000      Train ELBO estimate: -109.794   Validation ELBO estimate: -105.756      Validation log p(x) estimate: -97.914   Speed: 4.26e+04 examples/s
Step 29999      Test ELBO estimate: -104.867    Test log p(x) estimate: -96.716
Total time: 0.810 minutes
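Much of that speedup comes from jax.jit compiling the whole training step with XLA: the function is traced once, then later calls reuse the compiled kernel. As a minimal illustration (not the repo's code), jitting the closed-form Gaussian KL term from the ELBO looks like:

```python
import jax
import jax.numpy as jnp

@jax.jit  # traces once, then reuses the XLA-compiled kernel on later calls
def gaussian_kl(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)) per example."""
    return 0.5 * jnp.sum(mu ** 2 + jnp.exp(log_var) - log_var - 1.0, axis=-1)

mu = jnp.zeros((8, 16))
log_var = jnp.zeros((8, 16))
kl = gaussian_kl(mu, log_var)  # zero when q(z | x) equals the prior
```

Note that jitted JAX functions report implausibly high speeds on their first logged step (see the 2.56e+11 examples/s above) because timing overlaps with asynchronous dispatch and compilation.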

Inverse autoregressive flow in jax:

$ python train_variational_autoencoder_jax.py --variational flow 
Step 0          Train ELBO estimate: -727.404   Validation ELBO estimate: -726.977      Validation log p(x) estimate: -713.389  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -100.093   Validation ELBO estimate: -106.985      Validation log p(x) estimate: -99.565   Speed: 2.57e+04 examples/s
Step 20000      Train ELBO estimate: -113.073   Validation ELBO estimate: -108.057      Validation log p(x) estimate: -98.841   Speed: 3.37e+04 examples/s
Step 29999      Test ELBO estimate: -106.803    Test log p(x) estimate: -97.620
Total time: 2.350 minutes

(The small gap between the mean-field and inverse autoregressive flow results may be due to several factors, the chief one being the lack of convolutions in this implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.)

Generating the GIFs

  1. Run python train_variational_autoencoder_tensorflow.py
  2. Install ImageMagick (Homebrew on macOS: https://formulae.brew.sh/formula/imagemagick or Chocolatey on Windows: https://community.chocolatey.org/packages/imagemagick.app)
  3. Go to the directory where the .jpg files are saved and run the ImageMagick command to generate the GIF: convert -delay 20 -loop 0 *.jpg latent-space.gif

More Repositories

  1. food2vec: 🍔 (Jupyter Notebook, 221 stars)
  2. jaan.io: A Retina-ready Jekyll-powered blog with responsiveness, SEO, etc.; up at https://jaan.io (HTML, 55 stars)
  3. hierarchical-variational-models-physics: Hierarchical variational models for physics. (Jupyter Notebook, 18 stars)
  4. proximity_vi: This code accompanies the proximity variational inference paper. (Python, 18 stars)
  5. deep-exponential-families-gluon: Deep exponential family models in MXNet/Gluon. Layers o' latents 💤 (Python, 17 stars)
  6. sentence_word2vec: word2vec with a context based on sentences. (Python, 15 stars)
  7. american-community-survey: American Community Survey data on people and households (Jupyter Notebook, 15 stars)
  8. gamma-variational-autoencoder: Deep Latent Gamma Model / Gamma VAE (Python, 13 stars)
  9. couchometer: Instead of classifying activity, this app does one thing: tells you how much you sit, based on accelerometer data. (Java, 8 stars)
  10. thesis: Altosaar, Jaan (2020). Probabilistic Modeling of Structure in Science: Statistical Physics to Recommender Systems. Ph.D. Thesis, Princeton University. (TeX, 7 stars)
  11. vimco_tf: VIMCO in TensorFlow. (Python, 5 stars)
  12. exploring_american_community_survey_data: Using the Census Bureau's American Community Survey data with `dbt` (data build tool) for creating compressed parquet files for exploratory data analysis and downstream applications. (Jupyter Notebook, 5 stars)
  13. rankfromsets: RankFromSets SDSS submission code for reproducibility. (HTML, 4 stars)
  14. vae-lstm: Variational autoencoder LSTMs for time series data. (Python, 4 stars)
  15. citibike-stats: Calculating CitiBike personal statistics for the community leaderboard 📈 🚴 (Jupyter Notebook, 4 stars)
  16. gmm_cpp: Gaussian mixture model implementation in C++ with black box variational inference and control variates (C++, 3 stars)
  17. overleaf-curriculum-vitae-resume-cv-template (TeX, 3 stars)
  18. dotfiles: Main development environment is vim + tmux or emacs (Python, 3 stars)
  19. language-model-notebooks: Quickstart for interacting with language models and APIs via a notebook-like interface (Jupyter Notebook, 2 stars)
  20. nomen: 🐐 Lightweight configuration trees with command line flags 🐐 (Python, 2 stars)
  21. ctpf: User-artist-song Poisson Factorization (Python, 2 stars)
  22. physical-monotile-printing (1 star)
  23. jaan.li: Personal website using Observable Framework (JavaScript, 1 star)
  24. new-york-real-estate (Jupyter Notebook, 1 star)
  25. user-artist-song-poisson-factorization (TeX, 1 star)
  26. CumulantExpander: Automated expansions of Hamiltonian cumulants for analyzing Monte Carlo simulations. (1 star)