  • Stars: 117
  • Rank 301,828 (Top 6 %)
  • Language: Python
  • Created over 7 years ago
  • Updated about 7 years ago

Repository Details

An implementation of Decoupled Neural Interfaces using Synthetic Gradients in PyTorch

Disclaimer: this code is modified from pytorch-tutorial.

Image classification with synthetic gradients in PyTorch

This repository implements Decoupled Neural Interfaces using Synthetic Gradients in PyTorch. The paper uses synthetic gradients to decouple the layers of a network, which is interesting because the layers no longer suffer from update locking. I tested the model on MNIST, and it reaches almost the same performance as the same model trained with backpropagation.

Requirements

  • pytorch
  • python 3.5
  • torchvision
  • seaborn (optional)
  • matplotlib (optional)

TODO

  • use multi-threading on gpu to analyze the speed

What are synthetic gradients?

We often optimize neural networks by backpropagation, which is usually implemented in some well-known framework. However, is there another way for the layers in a network to communicate with each other? Here come synthetic gradients! They give neural networks a way to communicate, to learn to send messages between themselves, in a decoupled, scalable manner, paving the way for multiple neural networks to communicate with each other or for improving the long-term temporal dependency of recurrent networks.
Each layer automatically obtains an error signal (δa_head) from its synthetic layer and uses it for optimization. And how is that error signal generated? The network still does backpropagation. However, the error signal (δa) from the objective function is not used to optimize the neurons in the network; instead, it is used to optimize the error signal (δa_head) produced by the synthetic layer. The paper illustrates this setup with a figure.
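The update scheme described above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the layer sizes, the linear synthetic layer `synth`, the optimizers, and the helper `train_step` are all assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes; none of these come from the repository.
layer1 = nn.Linear(20, 32)   # lower layer, updated with the synthetic gradient
layer2 = nn.Linear(32, 10)   # output layer, updated with the true loss
synth = nn.Linear(32, 32)    # synthetic layer: activation -> estimated dL/da

opt_layer1 = torch.optim.SGD(layer1.parameters(), lr=0.01)
opt_layer2 = torch.optim.SGD(layer2.parameters(), lr=0.01)
opt_synth = torch.optim.SGD(synth.parameters(), lr=0.01)


def train_step(x, y):
    # 1) Update layer1 immediately, using delta_a_head instead of waiting
    #    for the true gradient to arrive from the loss.
    a = torch.relu(layer1(x))
    delta_a_head = synth(a.detach())
    opt_layer1.zero_grad()
    a.backward(delta_a_head.detach())
    opt_layer1.step()

    # 2) Update the output layer with the real loss. The graph is cut at
    #    `a_in`, so no gradient flows back into layer1 from here.
    a_in = a.detach().requires_grad_(True)
    loss = F.cross_entropy(layer2(a_in), y)
    opt_layer2.zero_grad()
    loss.backward()          # also stores the true delta_a in a_in.grad
    opt_layer2.step()

    # 3) Train the synthetic layer so that delta_a_head matches delta_a.
    grad_loss = F.mse_loss(delta_a_head, a_in.grad.detach())
    opt_synth.zero_grad()
    grad_loss.backward()
    opt_synth.step()
    return loss.item(), grad_loss.item()
```

Calling `train_step(torch.randn(8, 20), torch.randint(0, 10, (8,)))` returns the classification loss and the gradient loss for one batch, which correspond to the two kinds of curves reported in the results.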

Result

Feed-Forward Network

Achieves 96% accuracy (compared with 97% for the original model).

Plots: classification loss; gradient loss (log scale)
Plots (cDNI): classification loss; gradient loss (log scale)

Convolutional Neural Network

Achieves 96% accuracy (compared with 98% for the original model).

Plots: classification loss; gradient loss (log scale)

Usage

Currently only the fully connected (FCN) and CNN versions are implemented, and they are set as the default network structures.

Run network with synthetic gradient:

python main.py --model_type mlp

or

python main.py --model_type cnn

Run network with conditioned synthetic gradient:

python main.py --model_type mlp --conditioned True
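In the paper, the conditioned variant (cDNI) gives the synthetic gradient module the label as an additional input. A minimal sketch of such a module is shown below; the class name `ConditionedSynth`, the one-hot concatenation, and the sizes are assumptions for illustration and may differ from what `main.py` actually does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionedSynth(nn.Module):
    """Hypothetical cDNI module: estimates dL/da from activation and label."""

    def __init__(self, act_dim, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.fc = nn.Linear(act_dim + num_classes, act_dim)

    def forward(self, a, y):
        # Encode the integer labels as one-hot vectors and append them
        # to the activation before predicting the gradient.
        y_onehot = F.one_hot(y, self.num_classes).to(a.dtype)
        return self.fc(torch.cat([a, y_onehot], dim=1))


synth = ConditionedSynth(act_dim=32, num_classes=10)
a = torch.randn(8, 32)                # a batch of activations
y = torch.randint(0, 10, (8,))        # matching integer labels
delta_a_head = synth(a, y)            # same shape as `a`
```

The output has the same shape as the activation, so it can be passed to `a.backward(...)` in the same way as an unconditioned synthetic gradient.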

Run the vanilla network, from pytorch-tutorial:

python mlp.py

or

python cnn.py
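For readers who want to adapt the commands above, here is one way the two flags could be parsed. It is a hypothetical stand-in, not the actual argument handling in `main.py`: the helper `str2bool`, the defaults, and the accepted choices are assumptions. The explicit boolean parser is there because `type=bool` in `argparse` would treat the string "False" as true.

```python
import argparse


def str2bool(value):
    """Turn 'True'/'False' style strings into real booleans."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    # Hypothetical options mirroring the flags used in the commands above.
    parser = argparse.ArgumentParser(description="synthetic gradient demo")
    parser.add_argument("--model_type", choices=["mlp", "cnn"], default="mlp")
    parser.add_argument("--conditioned", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(
    ["--model_type", "mlp", "--conditioned", "True"]
)
print(args.model_type, args.conditioned)  # prints: mlp True
```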

Reference
