Language: Python
License: MIT License

Code for "Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models", AAAI 2018.

Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models

This repository provides a reference implementation for learning Flow-GAN models as described in the paper:

Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models
Aditya Grover, Manik Dhar, and Stefano Ermon.
AAAI Conference on Artificial Intelligence (AAAI), 2018.
Paper: https://arxiv.org/pdf/1705.08868.pdf
Blog post: https://ermongroup.github.io/blog/flow-gan

Requirements

The codebase is implemented in Python 3.6. To install the necessary requirements, run the following command:

pip install -r requirements.txt

Datasets

The scripts for downloading and loading the MNIST and CIFAR10 datasets are included in the datasets_loader folder. These scripts will be called automatically the first time the main.py script is run.

Options

Learning and inference of Flow-GAN models are handled by the main.py script, which provides the following command-line arguments:

  --beta1 FLOAT           beta1 parameter for the Adam optimizer
  --epoch INT             Number of epochs to train
  --batch_size INT        Training batch size
  --learning_rate FLOAT   Learning rate
  --input_height INT      Height of the input images
  --input_width INT       Width of the input images (defaults to input_height if not given)
  --c_dim INT             Number of image color channels
  --dataset STR           Name of the dataset [mnist, svhn, cifar-10]
  --checkpoint_dir STR    Directory in which to save checkpoints
  --log_dir STR           Directory in which to save logs
  --sample_dir STR        Directory in which to save image samples
  --f_div STR             Divergence used to specify the GAN objective
  --prior STR             Prior distribution for the generator (e.g., logistic, gaussian)
  --alpha FLOAT           alpha used in the logit transform of the input
  --lr_decay FLOAT        Learning rate decay rate
  --min_lr FLOAT          Minimum learning rate allowed under decay
  --reg FLOAT             Regularization parameter for adversarial training
  --model_type STR        Flow architecture: real_nvp or nice
  --n_critic INT          Number of discriminator iterations per generator iteration
  --no_of_layers INT      Number of layers between input and output of the m function in a coupling layer
  --hidden_layers INT     Size of the hidden layers (applicable only to NICE)
  --like_reg FLOAT        Weight of the likelihood loss relative to the adversarial loss (hybrid training)
  --df_dim FLOAT          Depth of the discriminator's convolutional filters
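The `--alpha` flag relates to a logit transform of the input pixels. A common choice in flow models such as Real-NVP is to squeeze pixel values from [0, 1] into [alpha, 1 - alpha] and then take the logit, so the transformed data are unbounded yet always finite. The sketch below assumes that form; it is not taken from this repository, and the function names are hypothetical.

```python
import math


def logit_transform(x, alpha=1e-7):
    """Map a pixel value x in [0, 1] to the real line.

    Assumed form: squeeze x into [alpha, 1 - alpha] so the logit stays
    finite at x = 0 and x = 1, then apply logit(p) = log(p) - log(1 - p).
    """
    p = alpha + (1.0 - 2.0 * alpha) * x
    return math.log(p) - math.log(1.0 - p)


def inverse_logit_transform(y, alpha=1e-7):
    """Invert logit_transform: apply the sigmoid, then undo the squeeze."""
    p = 1.0 / (1.0 + math.exp(-y))
    return (p - alpha) / (1.0 - 2.0 * alpha)


def logit_log_det(x, alpha=1e-7):
    """log |dy/dx| for one pixel; this term enters the exact log-likelihood."""
    p = alpha + (1.0 - 2.0 * alpha) * x
    return math.log(1.0 - 2.0 * alpha) - math.log(p) - math.log(1.0 - p)
```

Because the transform is itself invertible, its log-derivative (summed over pixels) must be included when likelihoods are reported in pixel space.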

Examples

Training Flow-GAN models on the MNIST dataset with the NICE architecture.
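In NICE, each coupling layer splits its input into two halves x1 and x2, leaves x1 unchanged, and adds m(x1) to x2, where m can be an arbitrary neural network (the options list refers to it as the m function). The pure-Python sketch below shows one additive coupling layer and its exact inverse. `make_m`, `coupling_forward` and `coupling_inverse` are hypothetical names, m uses random (untrained) weights, and a full NICE model would stack several such layers, alternating which half is updated.

```python
import random


def make_m(d_in, d_out, no_of_layers=2, hidden=8, seed=0):
    """Build a hypothetical coupling function m: R^d_in -> R^d_out.

    m is a small ReLU MLP with `no_of_layers` hidden layers of width
    `hidden`. Weights are random here; in training they would be learned.
    """
    rng = random.Random(seed)
    sizes = [d_in] + [hidden] * no_of_layers + [d_out]
    # Each layer is (weight matrix with b rows and a columns, bias vector).
    layers = [
        ([[rng.gauss(0, 0.5) for _ in range(a)] for _ in range(b)], [0.0] * b)
        for a, b in zip(sizes[:-1], sizes[1:])
    ]

    def m(v):
        for i, (w, bias) in enumerate(layers):
            v = [sum(wij * vj for wij, vj in zip(row, v)) + bj
                 for row, bj in zip(w, bias)]
            if i < len(layers) - 1:  # ReLU on hidden layers only
                v = [max(0.0, t) for t in v]
        return v

    return m


def coupling_forward(x, m):
    """Additive coupling: y1 = x1, y2 = x2 + m(x1).

    The Jacobian is triangular with a unit diagonal, so log|det J| = 0.
    """
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    shift = m(x1)
    return x1 + [a + b for a, b in zip(x2, shift)]


def coupling_inverse(y, m):
    """Exact inverse: x1 = y1, x2 = y2 - m(y1). m itself is never inverted."""
    half = len(y) // 2
    y1, y2 = y[:half], y[half:]
    shift = m(y1)
    return y1 + [a - b for a, b in zip(y2, shift)]
```

Since m is only ever evaluated, never inverted, it can be made as deep or as wide as desired without affecting invertibility.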

Maximum Likelihood Estimation (MLE)

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/mle --sample_dir samples_mnist/mle --model_type nice --log_dir logs_mnist/mle \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --epoch 500 --batch_size 100 --like_reg 1.0 --n_critic 0 --no_of_layers 5
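With `--like_reg 1.0 --n_critic 0`, the flow is trained purely by maximum likelihood. Because a flow f is invertible, the log-likelihood is exact via the change-of-variables formula, log p_X(x) = log p_Z(f(x)) + log |det ∂f(x)/∂x|. The sketch below evaluates the average negative log-likelihood under a standard logistic prior (the `--prior logistic` setting) for a toy flow consisting of a single diagonal scaling layer, the kind NICE places at its output. The toy flow and all function names are assumptions for illustration, not the repository's model.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def logistic_log_prob(z):
    """Sum of standard logistic log-densities: log p(z) = -softplus(z) - softplus(-z)."""
    return sum(-softplus(t) - softplus(-t) for t in z)


def scaling_flow(x, log_scale):
    """Toy flow z = exp(log_scale) * x; its log|det J| is sum(log_scale)."""
    z = [math.exp(s) * xi for xi, s in zip(x, log_scale)]
    return z, sum(log_scale)


def negative_log_likelihood(batch, log_scale):
    """Average of -(log p_Z(f(x)) + log|det J_f(x)|) over a batch."""
    total = 0.0
    for x in batch:
        z, log_det = scaling_flow(x, log_scale)
        total += logistic_log_prob(z) + log_det
    return -total / len(batch)
```

For a single two-dimensional point at the origin with zero log-scales, each dimension contributes log 4 nats, so the NLL is 4 ln 2.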

Adversarial training (ADV)

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/gan --sample_dir samples_mnist/gan --model_type nice --log_dir logs_mnist/gan \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --reg 10.0 --epoch 500 --batch_size 100 --like_reg 0.0 --n_critic 5 --no_of_layers 5
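In the adversarial setting, `--n_critic 5` sets the number of discriminator iterations. Assuming the usual WGAN-style schedule, in which the discriminator (critic) is updated `n_critic` times before each generator update, the order of updates can be sketched as follows. `training_schedule` is a hypothetical helper, and the repository's actual training loop may differ in detail.

```python
def training_schedule(n_generator_steps, n_critic):
    """Return the sequence of update types for alternating adversarial training.

    Assumed schedule: before each generator (flow) update, the critic is
    updated n_critic times. With n_critic = 0, as in the MLE commands,
    only generator updates remain.
    """
    schedule = []
    for _ in range(n_generator_steps):
        schedule.extend(["critic"] * n_critic)
        schedule.append("generator")
    return schedule
```

For example, `training_schedule(2, 5)` yields five critic updates, one generator update, and then the same pattern again.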

Hybrid

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/flow --sample_dir samples_mnist/flow --model_type nice --log_dir logs_mnist/flow \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --reg 10.0 --epoch 500 --batch_size 100 --like_reg 1.0 --n_critic 5 --no_of_layers 5
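The hybrid mode combines both objectives. The paper's hybrid objective adds a likelihood term, weighted by a coefficient λ, to the adversarial minimax objective: min_θ max_φ V(G_θ, D_φ) - λ E_{x∼P_data}[log p_θ(x)]. For the flow parameters θ this amounts to adding λ times the negative log-likelihood to the adversarial loss. Reading `--like_reg` as λ (an assumption based on the flag description and the example settings), the generator loss is a weighted sum; the function below is a minimal sketch that operates on already-computed scalar losses.

```python
def generator_objective(adv_loss, nll, like_reg):
    """Hybrid loss for the flow (generator) parameters.

    adv_loss: the generator's adversarial loss for the current batch.
    nll:      average negative log-likelihood of a data batch under the flow.
    like_reg: weight on the likelihood term (assumed to be --like_reg).
    """
    return adv_loss + like_reg * nll
```

With `like_reg = 0.0` (the ADV commands) the likelihood term vanishes; the hybrid commands use 1.0 for MNIST and 20 for CIFAR, which shifts the balance toward likelihood.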

Training Flow-GAN models on the CIFAR dataset with the Real-NVP architecture.
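Real-NVP replaces NICE's additive coupling with an affine one: y1 = x1 and y2 = x2 ⊙ exp(s(x1)) + t(x1). The Jacobian is still triangular, so its log-determinant is simply the sum of s(x1), and the inverse only evaluates s and t, never inverts them. In this sketch, `toy_s` and `toy_t` are fixed hypothetical functions standing in for the learned scale and translation networks.

```python
import math


def affine_coupling_forward(x, s_fn, t_fn):
    """Affine coupling: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1).

    Returns (y, log_det). The Jacobian is triangular with diagonal
    (1, ..., 1, exp(s)), so log|det J| = sum(s(x1)).
    """
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    s, t = s_fn(x1), t_fn(x1)
    y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
    return x1 + y2, sum(s)


def affine_coupling_inverse(y, s_fn, t_fn):
    """Exact inverse: x1 = y1, x2 = (y2 - t(y1)) * exp(-s(y1))."""
    half = len(y) // 2
    y1, y2 = y[:half], y[half:]
    s, t = s_fn(y1), t_fn(y1)
    x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, s, t)]
    return y1 + x2


# Hypothetical stand-ins for the learned networks (2 inputs -> 2 outputs).
def toy_s(v):
    return [math.tanh(sum(v)), 0.5 * math.tanh(v[0])]


def toy_t(v):
    return [v[0] - v[1], 2.0 * v[1]]
```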

Maximum Likelihood Estimation (MLE)

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/mle --sample_dir samples_cifar/mle --model_type real_nvp --log_dir logs_cifar/mle \
  --prior gaussian --beta1 0.9 --learning_rate 1e-3 --alpha 1e-7 --epoch 300 --lr_decay 0.999995 --batch_size 64 --like_reg 1.0 --n_critic 0 --no_of_layers 8 --batch_norm_adaptive 0
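This run also passes `--lr_decay 0.999995`. Assuming the decay is applied multiplicatively once per training step and clipped from below by `--min_lr`, the learning rate after t steps is max(learning_rate · lr_decay^t, min_lr). Under that assumption, a decay of 0.999995 halves the learning rate roughly every ln 2 / 5e-6 ≈ 138,600 steps. The helper name is hypothetical, and the repository may apply the decay on a different schedule.

```python
def decayed_learning_rate(step, learning_rate, lr_decay, min_lr=0.0):
    """Exponential decay with a floor: max(learning_rate * lr_decay**step, min_lr).

    Assumption: the decay factor is applied once per training step.
    """
    return max(learning_rate * lr_decay ** step, min_lr)
```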

Adversarial training (ADV)

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/gan --sample_dir samples_cifar/gan --model_type real_nvp --log_dir logs_cifar/gan \
  --prior gaussian --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --epoch 300 --batch_size 64 --like_reg 0.0 --n_critic 5 --no_of_layers 8
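When a flow is trained adversarially it acts as a generator: draw z from the prior selected by `--prior` and map it through the inverse of the flow. The sketch below covers only the first step, drawing from a factorized standard Gaussian or standard logistic prior (the latter by inverse-CDF sampling, z = log u - log(1 - u)). The function name and its `rng` argument are assumptions, not the repository's API.

```python
import math
import random


def sample_prior(dim, prior="gaussian", rng=random):
    """Draw z from a factorized standard prior over `dim` dimensions.

    gaussian: z_i ~ N(0, 1)
    logistic: z_i = log(u) - log(1 - u) with u ~ Uniform(0, 1)
    """
    if prior == "gaussian":
        return [rng.gauss(0.0, 1.0) for _ in range(dim)]
    if prior == "logistic":
        z = []
        for _ in range(dim):
            u = rng.random()
            while u == 0.0:  # random() may return 0.0, and log(0) is undefined
                u = rng.random()
            z.append(math.log(u) - math.log1p(-u))
        return z
    raise ValueError(f"unknown prior: {prior!r}")
```

The standard logistic distribution has mean 0 and variance π²/3, so its samples are noticeably heavier-tailed than Gaussian ones.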

Hybrid

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/flow --sample_dir samples_cifar/flow --model_type real_nvp --log_dir logs_cifar/flow \
  --prior gaussian --beta1 0.5 --learning_rate 1e-3 --lr_decay 0.99999 --alpha 1e-7 --epoch 500 --batch_size 64 --like_reg 20. --n_critic 5 --no_of_layers 8

Portions of the codebase in this repository use code originally provided in the open-source DCGAN and Real-NVP repositories.

Citing

If you find Flow-GAN useful in your research, please consider citing the following paper:

@inproceedings{grover2018flowgan,
title={Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models},
author={Grover, Aditya and Dhar, Manik and Ermon, Stefano},
booktitle={AAAI Conference on Artificial Intelligence},
year={2018}}
