  • Stars: 1,166
  • Rank: 39,799 (top 0.8%)
  • Language: Python
  • Created: about 5 years ago
  • Updated: over 1 year ago


Repository Details

Single Headed Attention RNN - "Stop thinking with your head"

Single Headed Attention RNN

For full details see the paper Single Headed Attention RNN: Stop Thinking With Your Head.

In summary, "stop thinking with your (attention) head".

  • Obtain strong results on a byte-level language modeling dataset (enwik8) in under 24 hours on a single GPU (12GB Titan V)
  • Support long range dependencies (up to 5000 tokens) without increasing compute time or memory usage substantially by using a simpler attention mechanism
  • Avoid the fragile training process required by standard Transformer models such as a long warmup
  • Back off toward a standard LSTM allowing you to drop retained memory states (needed for a Transformer model) if memory becomes a major constraint
  • Provide a smaller model that features only standard components such as the LSTM, single headed attention, and feed-forward modules, so that it can easily be productionized using existing optimized tools and exported to various formats (e.g. ONNX); a schematic sketch of this layer layout follows the list
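
As a rough illustration of that layer layout (an LSTM cell, a lone single-headed attention step, and a feed-forward block), here is a minimal PyTorch sketch. It is a simplified schematic under assumed names and dimensions, not the repo's actual model.py.

  # Schematic sketch of one SHA-LSTM block (assumed names, not verbatim model.py)
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  class SHALSTMBlockSketch(nn.Module):
      def __init__(self, hidden_size, use_attn=False):
          super().__init__()
          self.lstm = nn.LSTM(hidden_size, hidden_size)
          self.use_attn = use_attn
          # A single attention head: one query/key/value projection, no multi-head split
          self.qkv = nn.Linear(hidden_size, 3 * hidden_size) if use_attn else None
          # Feed-forward expansion and projection back down
          self.ff = nn.Sequential(
              nn.Linear(hidden_size, 4 * hidden_size),
              nn.GELU(),
              nn.Linear(4 * hidden_size, hidden_size),
          )

      def forward(self, x):
          # x: (seq_len, batch, hidden)
          h, _ = self.lstm(x)
          if self.use_attn:
              q, k, v = self.qkv(h).chunk(3, dim=-1)
              scores = torch.einsum('ibh,jbh->bij', q, k) / (h.size(-1) ** 0.5)
              # Causal mask: each position attends only to itself and the past
              mask = torch.triu(torch.ones(h.size(0), h.size(0), dtype=torch.bool, device=h.device), 1)
              attn_weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
              h = h + torch.einsum('bij,jbh->ibh', attn_weights, v)
          return h + self.ff(h)

In the paper's minimal configuration only one of the four stacked layers carries the attention step, which is what keeps compute and memory close to a plain LSTM.
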
Model                               Test BPC   Params   LSTM Based
Krause mLSTM                        1.24       46M      ✔
AWD-LSTM                            1.23       44M      ✔
SHA-LSTM                            1.07       63M      ✔
12L Transformer-XL                  1.06       41M
18L Transformer-XL                  1.03       88M
Adaptive Span Transformer (Small)   1.02       38M

While the model is still some way from the state of the art (~0.98 bpc), it is low resource and highly efficient without yet having been optimized to be so. It was trained in under 24 hours on a single GPU, with the Adaptive Span Transformer (Small) being the only recent Transformer model to achieve a similar level of training efficiency.

To recreate

Setup

To get started:

  • Retrieve the data with ./getdata.sh
  • Install PyTorch version 1.2+
  • Install Nvidia's AMP
  • Install the minimum trust variant of LAMB from Smerity's PyTorch-LAMB (a quick environment check is sketched after this list)
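
As a quick sanity check that the pieces are in place, something like the following should import cleanly. The module paths for AMP and LAMB are assumptions and depend on how those packages were installed.

  # Minimal environment check (import paths assumed; adjust to your install)
  import torch

  print('PyTorch', torch.__version__)    # should be 1.2 or later
  assert torch.cuda.is_available(), 'training expects a CUDA GPU'

  from apex import amp                   # Nvidia's AMP (apex); path assumed
  from pytorch_lamb import Lamb          # min-trust LAMB fork; module name assumed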

Training the model

By default the code trains the minimal single headed attention model from the paper, inserting a lone attention mechanism in the second-to-last layer of a four-layer LSTM. This takes only half an hour per epoch on a Titan V or V100. If you want slightly better results at the cost of a longer training time (an hour per epoch), set use_attn to True for all layers in model.py and decrease the batch size until it fits in memory (e.g. 8). Sadly there are no command line options for running the other models - it's manual tinkering. The code is not kind. I'll be performing a re-write in the near future meant for long-term academic and industrial use - contact me if you're interested :)
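
Schematically, the per-layer choice described above looks something like this. It is purely an illustration of the pattern, not the literal contents of model.py.

  # Illustration of the layer pattern only; not verbatim model.py
  num_layers = 4

  # Default: a lone attention mechanism in the second-to-last layer
  use_attn = [layer == num_layers - 2 for layer in range(num_layers)]   # [False, False, True, False]

  # Full SHA-LSTM: attention in every layer (slower; drop the batch size, e.g. to 8, to fit in memory)
  # use_attn = [True] * num_layers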

Note: still shaking out bugs in the commands below. We have near third-party replication but a fix or two is still outstanding. Feel free to run them and note any discrepancies! If you fiddle with the hyper-parameters (which I've done very little of - it's a treasure chest of opportunity, with a lower than expected BPC as your reward!) do report that too :)

When running the training command below, continue until the validation bpc stops improving. Don't worry about letting it run longer, as the code will only save the model with the best validation bpc.

python -u main.py --epochs 32 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --data data/enwik8/ --save ENWIK8.pt --log-interval 10 --seed 5512 --optimizer lamb --bptt 1024 --warmup 800 --lr 2e-3 --emsize 1024 --nhid 4096 --nlayers 4 --batch_size 16

When training slows down, a second pass with a halved learning rate (run until the validation bpc again stops improving) will shave off a little more bpc. A smarter learning rate decay is likely the correct way to go here, but that's not what I did for my experiments.

python -u main.py --epochs 5 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --data data/enwik8/ --save ENWIK8.pt --log-interval 10 --optimizer lamb --bptt 1024 --warmup 800 --emsize 1024 --nhid 4096 --nlayers 4 --batch_size 16 --resume ENWIK8.pt --lr 1e-3 --seed 125

Most of the improvement will happen in the first few epochs of this final command.

The final test bpc should be approximately 1.07 for the full 4-layer SHA-LSTM or 1.08 for the single headed 4-layer SHA-LSTM.

More Repositories

1. keras_snli - Simple Keras model that tackles the Stanford Natural Language Inference (SNLI) corpus using summation and/or recurrent neural networks (Python, 264 stars)
2. trending_arxiv - Track trending arXiv papers on Twitter from within your circle (HTML, 169 stars)
3. bitflipped - Your computer is a cosmic ray detector. Literally. (C, 58 stars)
4. cc-warc-examples - CommonCrawl WARC/WET/WAT examples and processing code for Java + Hadoop (Java, 54 stars)
5. tf-ham - A partial TensorFlow implementation of "Learning Efficient Algorithms with Hierarchical Attentive Memory" (Python, 52 stars)
6. right_whale_hunt - Annotated faces for NOAA Right Whale Recognition Kaggle competition (Python, 35 stars)
7. keras_qa - Keras solution to the bAbI tasks using recurrent neural networks - merged as an example into Keras mainline (Python, 34 stars)
8. search_iclr_2019 (HTML, 32 stars)
9. bifurcate-rs - Zero dependency images (of chaos) in Rust (Rust, 32 stars)
10. govarint - A variable length integer compression library for Golang (Go, 24 stars)
11. montelight-cpp - Faster raytracing through importance sampling, rejection sampling, and variance reduction (C++, 21 stars)
12. texting_robots - Texting Robots: A Rust native `robots.txt` parser with thorough unit testing (Rust, 20 stars)
13. Snippets - Useful code snippets that I'd rather not lose (Python, 19 stars)
14. cs205_ga - How deep does Google Analytics go? Efficiently tackling Common Crawl using AWS & MapReduce (Python, 17 stars)
15. gzipstream - gzipstream allows Python to process multi-part gzip files from a streaming source (Python, 17 stars)
16. pubcrawl - *Deprecated* A short and sweet Python web crawler using Redis as the process queue, seen set and Memcache style rate limiter for robots.txt (Python, 16 stars)
17. cc-mrjob - Demonstration of using Python to process the Common Crawl dataset with the mrjob framework (Python, 8 stars)
18. smerity_flask - Smerity.com website generated using (naive) custom Python code, Flask & Frozen-Flask (Less, 7 stars)
19. yolo-cpp - YOLO C++: A crash course for those needing to learn street fighting C++ (C++, 6 stars)
20. cc-quick-scripts - Useful scripts for attacking the CommonCrawl dataset and WARC/WET/WAT files (Python, 6 stars)
21. gopagerank - PageRank implemented in Go for large graphs (billions of edges) (Go, 5 stars)
22. Hip-Flask - *Deprecated* (JavaScript, 4 stars)
23. glove-guante - Exploration of Global Vectors for Word Representation (GloVe) (Go, 3 stars)
24. graphx-prank - GraphX P[age]Rank -- PageRank runner for large graphs (Scala, 3 stars)
25. comp3109_assignment1 - Nick and Smerity's assignment (Common Lisp, 2 stars)
26. BoxOfPrimes - Fast pseudo-random prime number generator for n bits using the OpenSSL library (C, 2 stars)
27. grimrepo - Automatically create and set remote private Git repositories at BitBucket (Python, 2 stars)
28. texting_robots_cc_test - Texting Robots: Common Crawl `robots.txt` Test (Rust, 2 stars)
29. tableau - Group 2's Tableau app from NCSS 2014 (Python, 2 stars)
30. FacebookFriends - Plugin for Vanilla Forums: Shows the real name of any of your Facebook friends on the Vanilla Forum (PHP, 1 star)
31. kaggle_connectomics - Connectomics: Predicting the directed connections between 1,000 neurons using neural activity time series data (Python, 1 star)
32. stat183_madness - March Madness (R, 1 star)
33. vimfiles (Vim Script, 1 star)
34. smerity.github.com (HTML, 1 star)
35. rosettafight - Rosetta Fight: Quick Lookup and Comparison on Rosetta Code Languages (Python, 1 star)
36. real_world_algorithms - Notes from the Real World Algorithms course at Sydney University (1 star)
37. fknn - Fast KNN for large scale multiclass problems (C++, 1 star)
38. lockoutbot - Example for NCSS - making lockoutbot <3 (Python, 1 star)
39. cs281_edge_estimator - CS281 Final Project: Estimate edge weights given multiple page view samples (Python, 1 star)
40. gogorobot - Exploratory robots.txt crawler written in Go (Go, 1 star)