LSTM and QRNN Language Model Toolkit for PyTorch

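This repository contains the code used for two Salesforce Research papers:

  • Regularizing and Optimizing LSTM Language Models
  • An Analysis of Neural Language Modeling at Multiple Scales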

The model comes with instructions to train:

  • word level language models over the Penn Treebank (PTB), WikiText-2 (WT2), and WikiText-103 (WT103) datasets

  • character level language models over the Penn Treebank (PTBC) and Hutter Prize dataset (enwik8)

The model can be composed of an LSTM or a Quasi-Recurrent Neural Network (QRNN) which is two or more times faster than the cuDNN LSTM in this setup while achieving equivalent or better accuracy.

  • Install PyTorch 0.4
  • Run getdata.sh to acquire the Penn Treebank and WikiText-2 datasets
  • Train the base model using main.py
  • (Optionally) Finetune the model using finetune.py
  • (Optionally) Apply the continuous cache pointer to the finetuned model using pointer.py

If you use this code or our results in your research, please cite as appropriate:

@article{merityRegOpt,
  title={{Regularizing and Optimizing LSTM Language Models}},
  author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard},
  journal={arXiv preprint arXiv:1708.02182},
  year={2017}
}
@article{merityAnalysis,
  title={{An Analysis of Neural Language Modeling at Multiple Scales}},
  author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard},
  journal={arXiv preprint arXiv:1803.08240},
  year={2018}
}

Update (June/13/2018)

The codebase is now PyTorch 0.4 compatible for most use cases (a big shoutout to https://github.com/shawntan for a fairly comprehensive PR #43). Mild readjustments to hyperparameters may be necessary to obtain quoted performance. If you desire exact reproducibility (or wish to run on PyTorch 0.3 or lower), we suggest using an older commit of this repository. We are still working on pointer, finetune and generate functionalities.

Software Requirements

Python 3 and PyTorch 0.4 are required for the current codebase.

Included below are hyper parameters to get equivalent or better results to those included in the original paper.

If you need to use an earlier version of the codebase, the original code and hyper parameters are accessible at the PyTorch==0.1.12 release, which requires Python 3 and PyTorch 0.1.12. If you are using Anaconda, PyTorch 0.1.12 can be installed via: conda install pytorch=0.1.12 -c soumith.

Experiments

The codebase was modified during the writing of the paper, preventing exact reproduction due to minor differences in random seeds or similar. We have also seen the exact numbers change when the underlying GPU changes. The guide below produces results largely similar to the numbers reported.

For data setup, run ./getdata.sh. This script collects the Mikolov pre-processed Penn Treebank and the WikiText-2 datasets and places them in the data directory.

Next, decide whether to use the QRNN or the LSTM as the underlying recurrent neural network model. The QRNN is many times faster than even Nvidia's cuDNN optimized LSTM (and dozens of times faster than a naive LSTM implementation) yet achieves similar or better results than the LSTM for many word level datasets. At the time of writing, the QRNN models use the same number of parameters and are slightly deeper networks, but are two to four times faster per epoch and require fewer epochs to converge.

The QRNN model uses a QRNN with convolutional size 2 for the first layer, allowing the model to view discrete natural language inputs that span two tokens (e.g. "New York"), while all other layers use a convolutional size of 1.
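To illustrate what a convolutional size of 2 versus 1 means, here is a minimal pure-Python sketch of a single-channel QRNN layer with fo-pooling. It follows the general QRNN formulation rather than this repository's optimized implementation. The function name qrnn_layer, the scalar weights, and the zero padding before the first timestep are illustrative assumptions.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def qrnn_layer(xs, w_z, w_f, w_o, window=1):
    """Single-channel QRNN layer with fo-pooling (illustrative only).

    xs: list of scalar inputs, one per timestep.
    w_z, w_f, w_o: lists of `window` scalar weights for the candidate,
    forget, and output gates. Index 0 weights the current input,
    index 1 the previous input, and so on.
    """
    hidden, c = [], 0.0
    for t in range(len(xs)):
        # Current input plus up to window-1 previous inputs
        # (zero padded before the start of the sequence).
        ctx = [xs[t - k] if t - k >= 0 else 0.0 for k in range(window)]
        z = math.tanh(sum(w * x for w, x in zip(w_z, ctx)))
        f = sigmoid(sum(w * x for w, x in zip(w_f, ctx)))
        o = sigmoid(sum(w * x for w, x in zip(w_o, ctx)))
        # fo-pooling: the only sequential part of the layer.
        c = f * c + (1.0 - f) * z
        hidden.append(o * c)
    return hidden


# First layer sees pairs of adjacent inputs, upper layers see one at a time.
xs = [1.0, -0.5, 0.25]
first_layer = qrnn_layer(xs, [0.5, 0.3], [0.1, 0.2], [0.4, -0.1], window=2)
upper_layer = qrnn_layer(first_layer, [0.5], [0.1], [0.4], window=1)
```

The gates depend only on the inputs, so in a real implementation they can be computed for all timesteps in parallel; only the pooling loop is sequential.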

Finetuning note: Finetuning modifies the original saved model file (model.pt) in place. If you wish to keep the original weights, you must copy the file first.
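One way to keep the original weights is to copy the saved model before running finetune.py. The helper below is a small sketch using only the standard library; the function name backup_checkpoint and the ".orig" suffix are hypothetical and not part of this repository.

```python
import shutil
from pathlib import Path


def backup_checkpoint(path, suffix=".orig"):
    """Copy a saved model file so the original weights survive finetuning."""
    src = Path(path)
    dst = src.with_name(src.name + suffix)
    shutil.copy2(src, dst)  # copy2 also preserves file metadata
    return dst
```

For example, calling backup_checkpoint("PTB.pt") before finetuning would leave a PTB.pt.orig copy next to the file that finetune.py overwrites.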

Pointer note: BPTT just changes the length of the sequence pushed onto the GPU but won't impact the final result.

Character level enwik8 with LSTM

  • python -u main.py --epochs 50 --nlayers 3 --emsize 400 --nhid 1840 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.4 --wdrop 0.2 --wdecay 1.2e-6 --bptt 200 --batch_size 128 --optimizer adam --lr 1e-3 --data data/enwik8 --save ENWIK8.pt --when 25 35

Character level Penn Treebank (PTB) with LSTM

  • python -u main.py --epochs 500 --nlayers 3 --emsize 200 --nhid 1000 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.25 --dropouti 0.1 --dropout 0.1 --wdrop 0.5 --wdecay 1.2e-6 --bptt 150 --batch_size 128 --optimizer adam --lr 2e-3 --data data/pennchar --save PTBC.pt --when 300 400

Word level WikiText-103 (WT103) with QRNN

  • python -u main.py --epochs 14 --nlayers 4 --emsize 400 --nhid 2500 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --wdrop 0 --wdecay 0 --bptt 140 --batch_size 60 --optimizer adam --lr 1e-3 --data data/wikitext-103 --save WT103.12hr.QRNN.pt --when 12 --model QRNN

Word level Penn Treebank (PTB) with LSTM

The instruction below trains a PTB model that without finetuning achieves perplexities of approximately 61.2 / 58.8 (validation / testing), with finetuning achieves perplexities of approximately 58.8 / 56.5, and with the continuous cache pointer augmentation achieves perplexities of approximately 53.2 / 52.5.

  • python main.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt
  • python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt
  • python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000
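For reference, perplexity is the exponential of the average per-token negative log-likelihood, and the character-level models are usually reported in bits per character instead. The sketch below shows both conversions; the function names are illustrative and do not come from this repository.

```python
import math


def perplexity(token_nlls):
    """Perplexity from per-token negative log-likelihoods (in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def bits_per_character(char_nlls):
    """Bits per character from per-character NLLs (in nats)."""
    return sum(char_nlls) / len(char_nlls) / math.log(2)
```

A model that spreads its probability uniformly over four words has a perplexity of 4, so lower values mean the model is less surprised by the held-out text.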

Word level Penn Treebank (PTB) with QRNN

The instruction below trains a QRNN model that without finetuning achieves perplexities of approximately 60.6 / 58.3 (validation / testing), with finetuning achieves perplexities of approximately 59.1 / 56.7, and with the continuous cache pointer augmentation achieves perplexities of approximately 53.4 / 52.6.

  • python -u main.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 9001 --dropouti 0.4 --epochs 550 --save PTB.pt
  • python -u finetune.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 404 --dropouti 0.4 --epochs 300 --save PTB.pt
  • python pointer.py --model QRNN --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000 --save PTB.pt
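The continuous cache pointer mixes the model's softmax distribution with a distribution built from recently seen hidden states. The following is a minimal pure-Python sketch of that idea, not the code in pointer.py. It assumes that --lambdasm is the interpolation weight, --theta scales the dot-product scores, and --window bounds how many past positions are kept; these readings of the flags, and the function name, are assumptions.

```python
import math


def cache_pointer_probs(p_vocab, h_t, history, lambdasm, theta):
    """Mix a softmax distribution with a cache distribution.

    p_vocab: list of vocabulary probabilities from the language model.
    h_t: current hidden state (list of floats).
    history: list of (hidden_state, next_word_id) pairs from the window.
    """
    if not history:
        return list(p_vocab)
    # Attention over the cached hidden states, sharpened by theta.
    scores = [theta * sum(a * b for a, b in zip(h_t, h_i))
              for h_i, _ in history]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    # Each cached position votes for the word that followed it.
    p_cache = [0.0] * len(p_vocab)
    for e, (_, word_id) in zip(exps, history):
        p_cache[word_id] += e / total
    return [(1.0 - lambdasm) * pv + lambdasm * pc
            for pv, pc in zip(p_vocab, p_cache)]
```

A caller would pass only the most recent entries, for example history[-window:], so the window size trades memory for how far back the cache can reach.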

Word level WikiText-2 (WT2) with LSTM

The instruction below trains a WikiText-2 model that without finetuning achieves perplexities of approximately 68.7 / 65.6 (validation / testing), with finetuning achieves perplexities of approximately 67.4 / 64.7, and with the continuous cache pointer augmentation achieves perplexities of approximately 52.2 / 50.6.

  • python main.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882
  • python finetune.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882
  • python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2

Word level WikiText-2 (WT2) with QRNN

The instruction below trains a QRNN model that without finetuning achieves perplexities of approximately 69.3 / 66.8 (validation / testing), with finetuning achieves perplexities of approximately 68.5 / 65.9, and with the continuous cache pointer augmentation achieves perplexities of approximately 53.6 / 52.1. Better numbers are likely achievable, but the hyper parameters have not been extensively searched. These hyper parameters should nonetheless serve as a good starting point.

  • python -u main.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt
  • python finetune.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt
  • python -u pointer.py --save WT2.pt --model QRNN --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2

Speed

For speed regarding character-level PTB and enwik8 or word-level WikiText-103, refer to the relevant paper.

The default speeds for the models during training on an NVIDIA Quadro GP100:

  • Penn Treebank (batch size 20): LSTM takes 65 seconds per epoch, QRNN takes 28 seconds per epoch
  • WikiText-2 (batch size 20): LSTM takes 180 seconds per epoch, QRNN takes 90 seconds per epoch

The default QRNN models can be far faster than the cuDNN LSTM model, with the speed-ups depending on how much of a bottleneck the RNN is. The majority of the model time above is now spent in softmax or optimization overhead (see PyTorch QRNN discussion on speed).

Speeds are approximately three times slower on a K80. On a K80 or other cards with less memory, you may wish to enable the cap on the maximum sampled sequence length to prevent out-of-memory (OOM) errors, especially for WikiText-2.

If speed is a major issue, SGD converges more quickly than our non-monotonically triggered variant of ASGD, though it achieves a worse overall perplexity.
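The non-monotonic trigger can be sketched as a simple check on the validation history: training starts with SGD and switches to ASGD once the current validation loss is worse than the best loss recorded more than a few evaluations ago. The function below is an illustrative sketch of that rule; the name should_switch_to_asgd, the nonmono parameter, and its default of 5 are assumptions rather than a description of main.py.

```python
def should_switch_to_asgd(previous_losses, current_loss, nonmono=5):
    """Decide whether to switch from SGD to averaged SGD.

    previous_losses: validation losses from earlier evaluations,
    oldest first, not including the current one.
    """
    if len(previous_losses) <= nonmono:
        return False  # not enough history yet
    # Compare against the best loss seen at least `nonmono` checks ago.
    return current_loss > min(previous_losses[:-nonmono])
```

Requiring the comparison to look several evaluations back keeps a single noisy validation result from triggering the switch too early.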

Details of the QRNN optimization

For full details, refer to the PyTorch QRNN repository.

Details of the LSTM optimization

All the augmentations to the LSTM, including our variant of DropConnect (Wan et al. 2013) termed weight dropping which adds recurrent dropout, allow for the use of NVIDIA's cuDNN LSTM implementation. PyTorch will automatically use the cuDNN backend if run on CUDA with cuDNN installed. This ensures the model is fast to train even when convergence may take many hundreds of epochs.
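As a rough illustration of weight dropping, the sketch below applies DropConnect to a hidden-to-hidden weight matrix represented as a list of rows. It is a pure-Python stand-in, not the repository's implementation; the function name weight_drop and the list-based matrix are assumptions made for the example.

```python
import random


def weight_drop(weight, p, rng=random):
    """Apply DropConnect to a weight matrix (list of rows).

    Each weight is zeroed with probability p and the survivors are
    rescaled by 1 / (1 - p), as in inverted dropout.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    return [[0.0 if rng.random() < p else w * scale for w in row]
            for row in weight]
```

The mask would be sampled once per forward and backward pass and the same dropped matrix reused at every timestep. Because only the weights change and the recurrence itself is untouched, an optimized LSTM kernel can still be used.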
