  • Stars: 146
  • Rank: 252,769 (Top 5%)
  • Language: Python
  • License: MIT License
  • Created: about 2 years ago
  • Updated: over 1 year ago


Repository Details

Normalizing flows in PyTorch

Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters().
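
As a minimal illustration of this limitation (this snippet is not part of Zuko), a torch distribution or transform is a plain Python object rather than a Module:

import torch
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform

# Distributions and transforms from torch are not Modules, so they expose
# neither .parameters() nor a recursive .to('cuda').
dist = Normal(torch.zeros(3), torch.ones(3))
print(isinstance(dist, torch.nn.Module))  # False
print(hasattr(dist, 'parameters'))        # False

transform = AffineTransform(torch.zeros(3), torch.ones(3))
print(isinstance(transform, torch.nn.Module))  # False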

To solve this problem, zuko defines two abstract classes: DistributionModule and TransformModule. The former is any Module whose forward pass returns a Distribution, and the latter is any Module whose forward pass returns a Transform. A normalizing flow is just a DistributionModule that contains a list of TransformModule instances and a base DistributionModule, as sketched below. This design allows flows to behave like distributions while retaining the benefits of Module. It also makes the implementations easier to understand and extend.
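
The sketch below illustrates the idea with a toy conditional Gaussian. It is a simplification for exposition, not Zuko's actual DistributionModule implementation: because the learnable tensors live inside a Module, .parameters() and .to('cuda') work as usual, while the forward pass returns a regular Distribution.

import torch
import torch.nn as nn
from torch.distributions import Distribution, Normal

# Simplified sketch of the DistributionModule idea (not Zuko's actual code):
# a Module whose forward pass builds and returns a Distribution.
class ConditionalGaussian(nn.Module):
    def __init__(self, features: int, context: int):
        super().__init__()
        self.hyper = nn.Linear(context, 2 * features)  # predicts mean and log-scale

    def forward(self, y: torch.Tensor) -> Distribution:
        loc, log_scale = self.hyper(y).chunk(2, dim=-1)
        return Normal(loc, log_scale.exp())

module = ConditionalGaussian(3, 5)
dist = module(torch.randn(5))       # a Distribution conditioned on the context
x = dist.rsample()                  # sampling works as usual
print(dist.log_prob(x).sum())       # so does log-probability evaluation
print(sum(p.numel() for p in module.parameters()))  # parameters are registered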

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Installation

The zuko package is available on PyPI and can be installed with pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Feeding a context y to the flow then returns a conditional distribution p(x | y), which can be evaluated and sampled from.

import torch
import zuko

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Train to maximize the log-likelihood
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)

for x, y in trainset:  # iterate over batches of samples x and contexts y
    loss = -flow(y).log_prob(x)  # -log p(x | y)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 points x ~ p(x | y*) for a context of interest y_star
x = flow(y_star).sample((64,))

For more information, check out the documentation at zuko.readthedocs.io.

Available flows

Class  Year  Reference
MAF    2017  Masked Autoregressive Flow for Density Estimation
NSF    2019  Neural Spline Flows
NCSF   2020  Normalizing Flows on Tori and Spheres
SOSPF  2019  Sum-of-Squares Polynomial Flow
NAF    2018  Neural Autoregressive Flows
UNAF   2019  Unconstrained Monotonic Neural Networks
CNF    2018  Neural Ordinary Differential Equations
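
All of these classes live in the zuko.flows module. As a rough sketch (the exact hyperparameters accepted by each class differ; check the documentation), most of them are built like the NSF in the example above, with the numbers of sample and context features first:

import zuko

# A sketch of instantiating other flows; see zuko.readthedocs.io for the
# exact hyperparameters accepted by each class.
maf = zuko.flows.MAF(3, 5, transforms=3)  # masked autoregressive flow
naf = zuko.flows.NAF(3, 5)                # neural autoregressive flow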

Contributing

If you have a question, have found an issue, or would like to contribute, please read our contributing guidelines.

More Repositories

1. piqa: PyTorch Image Quality Assessment package (Python, 400 stars)
2. sleek-template: Sleek Template for quick, easy and beautiful LaTeX documents (TeX, 140 stars)
3. lampe: Likelihood-free AMortized Posterior Estimation with PyTorch (Python, 72 stars)
4. torchist: NumPy-style histograms in PyTorch (Python, 51 stars)
5. sda: Official implementation of Score-based Data Assimilation (Python, 37 stars)
6. inox: Stainless neural networks in JAX (Python, 30 stars)
7. postr: A minimal poster template in Typst (Typst, 19 stars)
8. papers-101: Implementation of papers in 101 lines of code (Python, 18 stars)
9. dawgz: Unleash the true power of scheduling (Python, 17 stars)
10. diffusion-priors: Learning Diffusion Priors from Observations by Expectation Maximization (Python, 17 stars)
11. adopptrs: Automatic Detection Of Photovoltaic Panels Through Remote Sensing (Python, 16 stars)
12. sleek-beamer: LaTeX sleek beamer template (TeX, 9 stars)
13. sudoku: Sudoku grid and digits detection (Python, 4 stars)
14. vsop-compiler: Implementation of a VSOP compiler (C++, 4 stars)
15. uci-datasets: UCI datasets from the MAF paper (Python, 3 stars)
16. sleek-poster: LaTeX sleek poster template (TeX, 2 stars)
17. elen0060-2: Projects of information and coding theory (Python, 2 stars)
18. info2049-1: Sentiment analysis using deep learning methods (Python, 2 stars)
19. amnre: Arbitrary Marginal Neural Ratio Estimation for Likelihood-free Inference (Python, 2 stars)
20. benchmark_error (Python, 1 star)
21. info0054-1: Functional programming project (Scheme, 1 star)
22. proj0001-1: Project applying numerical methods (MATLAB, 1 star)
23. elen0016-2: Computer vision project (Python, 1 star)
24. whitespacy: Polyglot formatter for C and Whitespace (Python, 1 star)
25. info8003-1: Assignments of Reinforcement Learning (Python, 1 star)
26. math0462-1: Discrete optimization project (Julia, 1 star)
27. math0488-1: Project on stochastic processes (MATLAB, 1 star)
28. syst0002-2: Systems analysis project (MATLAB, 1 star)