  • Language: Python
  • License: MIT License

Code for the Neural Spline Flows paper

Neural Spline Flows

A record of the code and experiments for the paper:

C. Durkan, A. Bekasov, I. Murray, G. Papamakarios, Neural Spline Flows, NeurIPS 2019. [arXiv] [bibtex]

Work in this repository has now stopped. Please go to nflows for an updated and pip-installable normalizing flows framework for PyTorch.

Dependencies

See environment.yml for the required Conda/pip packages, or run the following command to create a Conda environment with all dependencies:

conda env create -f environment.yml

Tested with Python 3.5 and PyTorch 1.1.

Data

Data for density-estimation experiments is available at https://zenodo.org/record/1161203#.Wmtf_XVl8eN.

Data for VAE and image-modeling experiments is downloaded automatically using either torchvision or custom data providers.

Usage

The DATAROOT environment variable must be set before running any experiments.
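A minimal sketch of failing early when DATAROOT is missing, for example in a wrapper script that launches several experiments. The helper name resolve_data_root is hypothetical: it is not part of the repository, whose scripts read DATAROOT in their own way.

```python
import os
from pathlib import Path


def resolve_data_root(env_var: str = "DATAROOT") -> Path:
    """Return the dataset root taken from the environment.

    Hypothetical helper (not part of the repository): raises immediately
    if the variable is unset or empty, instead of failing later inside
    a data loader.
    """
    value = os.environ.get(env_var)
    if not value:
        raise RuntimeError(
            "{0} is not set; run e.g. `export {0}=/path/to/data` "
            "before starting experiments".format(env_var)
        )
    # Expand a leading "~" so DATAROOT=~/data also works.
    return Path(value).expanduser()
```

In a shell, the equivalent is simply `export DATAROOT=/path/to/data` before invoking any script under experiments/.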

2D toy density experiments

Use experiments/face.py or experiments/plane.py.

Density-estimation experiments

Use experiments/uci.py.

VAE experiments

Use experiments/vae_.py.

Image-modeling experiments

Use experiments/images.py.

Sacred is used to organize the image experiments. See the Sacred documentation for more information.

The experiments/image_configs directory contains the .json configurations used for the RQ-NSF (C) experiments. For the affine baseline experiments, additionally set coupling_layer_type='affine'.

For example, to run RQ-NSF (C) on CIFAR-10 8-bit:

python experiments/images.py with experiments/image_configs/cifar-10-8bit.json

Corresponding affine baseline run:

python experiments/images.py with experiments/image_configs/cifar-10-8bit.json coupling_layer_type='affine'

To evaluate on the test set:

python experiments/images.py eval_on_test with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'

To sample:

python experiments/images.py sample with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'
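The invocations above share one pattern: an optional Sacred command (eval_on_test, sample), then `with`, the JSON config, then key=value overrides. A minimal sketch for scripting them from Python, assuming the commands are launched from the repository root. The helpers images_command and run are hypothetical, not part of the repository. The sketch also assumes Sacred parses each override value as a Python literal, which is why string values are passed quoted (as in coupling_layer_type='affine').

```python
import shlex
import subprocess
import sys
from typing import List, Optional

SCRIPT = "experiments/images.py"


def images_command(config: str, command: Optional[str] = None, **updates: str) -> List[str]:
    """Build argv for: images.py [command] with <config.json> key=value ...

    Hypothetical helper, not part of the repository.
    """
    argv = [sys.executable, SCRIPT]
    if command is not None:
        argv.append(command)  # e.g. "eval_on_test" or "sample"; omit to train
    argv += ["with", config]
    # repr() quotes string values so Sacred reads them as Python string literals.
    argv += ["{}={!r}".format(key, value) for key, value in updates.items()]
    return argv


def run(argv: List[str], dry_run: bool = True) -> None:
    """Print the shell-quoted command; execute it only when dry_run is False."""
    print(" ".join(shlex.quote(arg) for arg in argv))
    if not dry_run:
        subprocess.run(argv, check=True)
```

For example, `run(images_command("experiments/image_configs/cifar-10-8bit.json", coupling_layer_type="affine"))` prints the affine baseline command without executing it. Sacred also provides a built-in `print_config` command, so passing `"print_config"` as `command` should show the resolved configuration before committing to a long run.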