  • Stars: 512
  • Rank: 83,594 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created: almost 7 years ago
  • Updated: about 5 years ago


Repository Details

A Tensorflow implementation of Semi-supervised Learning Generative Adversarial Networks (NIPS 2016: Improved Techniques for Training GANs).

Semi-supervised learning GAN in Tensorflow

As part of the implementation series of Joseph Lim's group at USC, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our group's GitHub site for other projects.

This project was implemented by Shao-Hua Sun, and the code was reviewed by Jiayuan Mao before being published.

Descriptions

This project is a Tensorflow implementation of Semi-supervised Learning Generative Adversarial Networks proposed in the paper Improved Techniques for Training GANs. The intuition is to exploit the samples generated by the GAN generator to boost the performance of image classification tasks by improving generalization.

In short, the main idea is to train a single network that plays two roles: a classifier performing the image classification task, and a discriminator trained to distinguish samples produced by a generator from real data. More specifically, the discriminator/classifier takes an image as input and classifies it into n+1 classes, where n is the number of classes in the classification task. Real samples are classified into the first n classes and generated samples into the (n+1)-th class, as shown in the figure below.

The loss of this multi-task learning framework can be decomposed into the supervised loss

L_supervised = -E_{x,y ~ p_data} [ log p_model(y | x, y < n+1) ],

and the GAN loss of the discriminator

L_GAN = -E_{x ~ p_data} [ log (1 - p_model(y = n+1 | x)) ] - E_{x ~ G} [ log p_model(y = n+1 | x) ].

During the training phase, we jointly minimize the total loss obtained by simply combining the two losses together.

The implemented model is trained and tested on three publicly available datasets: MNIST, SVHN, and CIFAR-10.

Note that this implementation only follows the main idea of the original paper while differing a lot in implementation details such as model architectures, hyperparameters, applied optimizer, etc. Also, some useful training tricks applied to this implementation are stated at the end of this README.
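To make the two-part objective concrete, here is a minimal NumPy sketch of the (n+1)-class losses described above. This is an illustration of the idea only, not the repository's actual TensorFlow code; the function name `ssgan_losses` and the batch shapes are assumptions.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the class axis.
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def ssgan_losses(real_logits, fake_logits, labels, eps=1e-8):
    """Illustrative (n+1)-class SSGAN losses.

    real_logits: (batch, n+1) discriminator outputs for real images
    fake_logits: (batch, n+1) discriminator outputs for generated images
    labels:      (batch,) integer labels in [0, n) for the real images
    Column n (the last one) is the "generated sample" class.
    """
    p_real = softmax(real_logits)
    p_fake = softmax(fake_logits)
    n = real_logits.shape[1] - 1
    # Supervised loss: cross-entropy over the first n classes for labeled data.
    supervised = -np.mean(np.log(p_real[np.arange(len(labels)), labels] + eps))
    # GAN loss: real samples should avoid class n+1; fakes should land in it.
    d_loss_real = -np.mean(np.log(1.0 - p_real[:, n] + eps))
    d_loss_fake = -np.mean(np.log(p_fake[:, n] + eps))
    return supervised, d_loss_real + d_loss_fake

# The total discriminator objective simply sums the two terms.
rng = np.random.default_rng(0)
sup, gan = ssgan_losses(rng.normal(size=(4, 11)),
                        rng.normal(size=(4, 11)),
                        np.array([0, 3, 7, 9]))
total = sup + gan
```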

*This code is still being developed and subject to change.

Prerequisites

Usage

Download datasets with:

$ python download.py --dataset MNIST SVHN CIFAR10

Train models with downloaded datasets:

$ python trainer.py --dataset MNIST
$ python trainer.py --dataset SVHN
$ python trainer.py --dataset CIFAR10

Test models with saved checkpoints:

$ python evaler.py --dataset MNIST --checkpoint ckpt_dir
$ python evaler.py --dataset SVHN --checkpoint ckpt_dir
$ python evaler.py --dataset CIFAR10 --checkpoint ckpt_dir

The ckpt_dir should look like: train_dir/default-MNIST_lr_0.0001_update_G5_D1-20170101-194957/model-1001

Train and test your own datasets:

  • Create a directory:
$ mkdir datasets/YOUR_DATASET
  • Store your data as an h5py file datasets/YOUR_DATASET/data.hy, where each data point contains
    • 'image': an array of shape [h, w, c], where c is the number of channels (1 for grayscale images, 3 for color images)
    • 'label': a one-hot vector
  • Maintain a file datasets/YOUR_DATASET/id.txt listing the ids of all data points
  • Modify trainer.py accordingly, including args, data_info, etc.
  • Finally, train and test models:
$ python trainer.py --dataset YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET
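The steps above can be sketched as a small Python helper. The file names (data.hy, id.txt) and per-point keys ('image', 'label') follow the README; the exact group layout the trainer expects is an assumption, so treat this as a starting point rather than a drop-in utility.

```python
import os
import tempfile
import numpy as np
import h5py

def build_dataset(root, images, labels):
    """Write a dataset in the layout described above (illustrative sketch).

    images: array of shape [N, h, w, c]; labels: one-hot array [N, n].
    """
    os.makedirs(root, exist_ok=True)
    ids = [str(i) for i in range(len(images))]
    # data.hy: one group per data point holding 'image' and 'label'.
    with h5py.File(os.path.join(root, 'data.hy'), 'w') as f:
        for i, img, lab in zip(ids, images, labels):
            grp = f.create_group(i)
            grp['image'] = img   # shape [h, w, c]
            grp['label'] = lab   # one-hot vector
    # id.txt: one id per line.
    with open(os.path.join(root, 'id.txt'), 'w') as f:
        f.write('\n'.join(ids) + '\n')
    return ids

# Example: 4 blank 32x32 RGB images with 10-class one-hot labels.
demo_root = os.path.join(tempfile.mkdtemp(), 'YOUR_DATASET')
demo_ids = build_dataset(demo_root,
                         np.zeros((4, 32, 32, 3), dtype=np.float32),
                         np.eye(10, dtype=np.float32)[[0, 1, 2, 3]])
```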

Results

MNIST

  • Generated samples (100th epoch)

  • First 40 epochs

SVHN

  • Generated samples (100th epoch)

  • First 160 epochs

CIFAR-10

  • Generated samples (1000th epoch)

  • First 200 epochs

Training details

MNIST

  • The supervised loss
  • The discriminator losses: D_loss_real, D_loss_fake, and D_loss (total loss)
  • The generator loss: G_loss
  • Classification accuracy

SVHN

  • The supervised loss
  • The discriminator losses: D_loss_real, D_loss_fake, and D_loss (total loss)
  • The generator loss: G_loss
  • Classification accuracy

CIFAR-10

  • The supervised loss
  • The discriminator losses: D_loss_real, D_loss_fake, and D_loss (total loss)
  • The generator loss: G_loss
  • Classification accuracy

Training tricks

  • To avoid fast convergence of the discriminator network:
    • The generator network is updated more frequently.
    • A higher learning rate is applied to the training of the generator.
  • One-sided label smoothing is applied to the positive labels.
  • The gradient clipping trick is applied to stabilize training.
  • A reconstruction loss with an annealed weight is applied as an auxiliary loss to help the generator escape its initial local minimum.
  • The Adam optimizer is used with higher momentum.
  • Please refer to the code for more details.
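As an illustration of one-sided label smoothing, the sketch below softens only the positive (real) targets of a binary cross-entropy loss. This is not the repository's code; the function name, the smoothing value 0.9, and the use of plain NumPy are assumptions for demonstration.

```python
import numpy as np

def d_real_loss_smoothed(d_real_prob, smooth=0.9, eps=1e-8):
    """One-sided label smoothing (illustrative sketch).

    Targets for REAL samples are softened from 1.0 to `smooth`, while fake
    targets are left at 0, so the discriminator is never rewarded for
    pushing its output probability arbitrarily close to 1.
    """
    d_real_prob = np.asarray(d_real_prob, dtype=np.float64)
    # Binary cross-entropy against the smoothed target `smooth`.
    return -np.mean(smooth * np.log(d_real_prob + eps)
                    + (1.0 - smooth) * np.log(1.0 - d_real_prob + eps))
```

Note that this loss is minimized when the discriminator outputs a probability near `smooth` rather than near 1, which discourages the overconfident discriminator that makes GAN training collapse.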

Related works

Acknowledgement

Part of the code is from an unpublished project with Jongwook Choi.
