
NTM and MANN in TensorFlow

TensorFlow implementation of Neural Turing Machines (NTM), as well as its application on one-shot learning (MANN).

The models are ready to use: they are encapsulated in the classes NTMCell and MANNCell, whose usage is similar to LSTMCell in TensorFlow, so you can easily drop them into other programs. Sample code is also provided.

See my slides for more details about NTM and MANN.

Prerequisites

  • Python 3.5
  • TensorFlow 1.2.0
  • NumPy
  • Pillow (for MANN; preprocessing of the Omniglot dataset)
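
A minimal environment can be set up with pip; the pins below simply follow the versions listed above:

pip install tensorflow==1.2.0 numpy Pillow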

Implementation of NTM

Paper

Graves, Alex, Greg Wayne, and Ivo Danihelka. "Neural Turing Machines." arXiv preprint arXiv:1410.5401 (2014).

Usage

Class NTMCell()

The usage of the class NTMCell in ntm/ntm_cell.py is similar to tf.contrib.rnn.BasicLSTMCell in TensorFlow. Basic usage looks like this:

import tensorflow as tf
import ntm.ntm_cell as ntm_cell

cell = ntm_cell.NTMCell(
    rnn_size=200,           # Size of the controller's hidden state
    memory_size=128,        # Number of memory locations (N)
    memory_vector_dim=20,   # Vector size of each memory location (M)
    read_head_num=1,        # Number of read heads
    write_head_num=1,       # Number of write heads
    addressing_mode='content_and_location',  # Addressing mechanism: 'content_and_location' or 'content'
    reuse=False,            # Whether to reuse the model's variables; useful when sequence lengths vary and several models must share the same variables
)

# batch_size, seq_length and inputs (the per-step input tensors)
# are assumed to be defined by the surrounding program.
state = cell.zero_state(batch_size, tf.float32)
output_list = []
for t in range(seq_length):
    output, state = cell(inputs[t], state)
    output_list.append(output)
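
To train on the copy task yourself, a minimal (and hedged) sketch of turning the collected outputs into a loss might look like the following; output_dim and targets are placeholder names, not the repo's actual variables:

import tensorflow as tf

outputs = tf.stack(output_list, axis=1)        # (batch, time, rnn_size)
logits = tf.layers.dense(outputs, output_dim)  # per-step logits; output_dim = target vector size
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
train_op = tf.train.RMSPropOptimizer(1e-4).minimize(loss)  # RMSProp, as in Graves et al.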

Train and Test

To train the model, run:

python copy_task.py

You can specify training options, including model parameters, via flags such as --model (default: NTM), --batch_size, and so on. See the code for more details.

To test the model, run:

python copy_task.py --mode test

You can specify testing options via flags such as --test_seq_length.
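
For example (the flag values here are illustrative, not defaults):

python copy_task.py --batch_size 16
python copy_task.py --mode test --test_seq_length 40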

Result (Copy task)

[Figure: read (left) and write (right) weight vectors during the copy task; shift range: 1]
[Figure: training loss]

One-shot Learning with NTM (MANN)

Paper

Santoro, Adam, et al. "One-shot Learning with Memory-Augmented Neural Networks." arXiv preprint arXiv:1605.06065 (2016).

Usage

Class MANNCell()

The usage of the class MANNCell in ntm/mann_cell.py is similar to tf.contrib.rnn.BasicLSTMCell in TensorFlow. Basic usage looks like this:

import tensorflow as tf
import ntm.mann_cell as mann_cell

cell = mann_cell.MANNCell(
    rnn_size=200,           # Size of the controller's hidden state
    memory_size=128,        # Number of memory locations (N)
    memory_vector_dim=40,   # Vector size of each memory location (M)
    head_num=1,             # Number of read & write heads (in MANN, #(read heads) = #(write heads))
    gamma=0.95,             # Usage decay of the write weights (eq 20)
    k_strategy='separate'   # In the MANN paper, the key vector 'k' is used for both reading (eq 17) and writing (eq 23); set k_strategy='summary' for that behavior. In the NTM paper the two are separate; with k_strategy='separate', the controller emits an extra add vector 'a' that replaces 'k' in eq 23.
)

# batch_size, seq_length and inputs are assumed to be defined by the
# surrounding program, as in the NTM example above.
state = cell.zero_state(batch_size, tf.float32)
output_list = []
for t in range(seq_length):
    output, state = cell(inputs[t], state)
    output_list.append(output)
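
The cell consumes one input per time step; in the setup of Santoro et al., that input is the current image concatenated with the previous step's label. A hedged sketch of building such episode inputs (names and dimensions are assumptions, not the repo's code):

import numpy as np

def episode_inputs(images, labels, n_classes=5):
    # images: (seq_length, input_dim) floats; labels: (seq_length,) ints
    one_hot = np.eye(n_classes)[labels]                            # one-hot encode labels
    shifted = np.vstack([np.zeros((1, n_classes)), one_hot[:-1]])  # offset labels by one step
    return np.concatenate([images, shifted], axis=1)               # image + previous label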

There is another implementation of MANNCell, translated from tristandeleu's Theano version of MANN. You can find it in ntm/mann_cell_2.py, and its usage is the same. Its performance is not fully tested, but it seems to work fine as well.
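
Since the interface matches, a drop-in swap of the import is enough to try it:

import ntm.mann_cell_2 as mann_cell   # same interface as ntm.mann_cell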

Train and Test

To train the model, first prepare the Omniglot dataset. Download images_background.zip (964 classes) and images_evaluation.zip (679 classes), then combine them into a new data folder so that you have 1,623 classes. Your data folder should look like this:

/data
    /Alphabet_of_the_Magi
        /character01
            0709_01.png
            ...
            0709_20.png
        ...
        /character20
    ...
    /ULOG
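
Pillow (listed in the prerequisites) is used to preprocess these images. A minimal sketch of what loading one character might look like; the 20x20 size follows Santoro et al. and is an assumption about this repo:

from PIL import Image
import numpy as np

img = Image.open('data/Alphabet_of_the_Magi/character01/0709_01.png')
img = img.convert('L').resize((20, 20))                  # grayscale, downsample
x = np.asarray(img, dtype=np.float32).flatten() / 255.0  # flat vector in [0, 1]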

Then, run:

python one_shot_learning.py

You can specify training options, including model parameters, via flags such as --model (default: MANN), --batch_size, and so on. See the code for more details.

To test the model, run:

python one_shot_learning.py --mode test

You can specify testing options via flags such as --test_batch_num (default: 100), --n_train_classes (default: 1200) and --n_test_classes (default: 423).
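
For example, to test with the defaults made explicit:

python one_shot_learning.py --mode test --test_batch_num 100 --n_train_classes 1200 --n_test_classes 423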

Result

Omniglot Classification:

[Figure: LSTM, five random classes/episode, one-hot vector labels]
[Figure: MANN, five random classes/episode, one-hot vector labels]
[Figure: LSTM, fifteen random classes/episode, five-character string labels]
[Figure: MANN, fifteen random classes/episode, five-character string labels]

Test-set classification accuracies for LSTM and MANN trained on the Omniglot dataset, using one-hot encodings of labels and five classes presented per episode. Columns give the accuracy on the 1st through 10th presentation of a class within an episode:

Model   1st      2nd      3rd      4th      5th      6th      7th      8th      9th      10th     Loss
LSTM    0.2333   0.5897   0.6581   0.6810   0.7077   0.7156   0.7141   0.7305   0.7281   0.7233   42.6427
MANN    0.3558   0.8881   0.9497   0.9651   0.9734   0.9744   0.9794   0.9798   0.9780   0.9755   11.5310
