TRIME: Training Language Models with Memory Augmentation

This is the repository for our EMNLP 2022 paper Training Language Models with Memory Augmentation (https://arxiv.org/abs/2205.12674), by Zexuan Zhong, Tao Lei, and Danqi Chen.

Updates

  • [2022/11/13] We have released the code and pre-trained models for the machine translation experiments here.
  • [2022/10/25] Our paper has been accepted to EMNLP 2022! Please check out the updated version. We have added domain adaptation and stronger results on machine translation and character-level language modeling.
  • [2022/07/31] We have released our training code and pre-trained models.
  • [2022/05/24] We have released the preprint of our TRIME paper on training LMs with memory augmentation.

Quick links

  • Overview
  • Setup
  • Run Pre-Trained Models
  • Performance of pre-trained models
  • Train TrimeLM
  • BM25 batching
  • Machine Translation
  • Bugs or Questions?
  • Citation
  • Acknowledgement

Overview

We propose TRIME, a new training objective for language modeling, which aligns model outputs with both token embeddings and in-batch memories. We also devise novel methods for data batching and for constructing training memories, so that our models can effectively leverage long-range contexts and external datastores.
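
To make the objective concrete, here is a minimal, hedged PyTorch sketch of a memory-augmented loss in this spirit (tensor names and shapes are assumptions for the example, not the repository's actual implementation):

import torch

def memory_augmented_loss(hidden, targets, token_emb, mem_keys, mem_values, temp=1.0):
    # hidden:     (B, d)  hidden states at the prediction positions
    # targets:    (B,)    gold next-token ids
    # token_emb:  (V, d)  output token embeddings
    # mem_keys:   (M, d)  hidden states of in-batch memory positions
    # mem_values: (M,)    token ids stored at those memory positions
    vocab_logits = hidden @ token_emb.t() / temp                     # (B, V)
    mem_logits = hidden @ mem_keys.t() / temp                        # (B, M)

    # Partition function: sum over all token embeddings and all memories.
    all_logits = torch.cat([vocab_logits, mem_logits], dim=-1)
    log_z = torch.logsumexp(all_logits, dim=-1)                      # (B,)

    # Numerator: the gold token's embedding score plus every memory slot that
    # stores the same token id (a real implementation also masks slots that
    # would leak the current or future positions).
    gold = vocab_logits.gather(1, targets[:, None])                  # (B, 1)
    match = mem_values[None, :] == targets[:, None]                  # (B, M)
    masked_mem = mem_logits.masked_fill(~match, float("-inf"))
    log_num = torch.logsumexp(torch.cat([gold, masked_mem], dim=-1), dim=-1)

    return (log_z - log_num).mean()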

Please find more details of this work in our paper.

Setup

Requirements and dependencies

The code is based on the following requirements/dependencies (we specify the version we used in our experiments in brackets):

  • Python (3.7.11)
  • pytorch (1.9.1+cu111)
  • faiss-gpu (1.7.1)
  • numpy (1.21.2)

You can install this project (based on Fairseq) as follows:

pip install --editable .

Datasets

We conduct experiments on the Wikitext-103 and enwik8 datasets. Please use get_data.sh to download and preprocess the datasets.

bash get_data.sh {wikitext-103 | enwik8}

The processed datasets will be stored in data-bin/wikitext-103 and data-bin/enwik8.

Run Pre-Trained Models

We show examples of running the pre-trained models on Wikitext-103 with model size 247M and segment length 3072. For other experiments (e.g., different datasets or model sizes), please refer to run_pretrained_models.md for the scripts covering all experimental settings.

TrimeLM (local memory)

TrimeLM uses only the local memory (constructed from the tokens in the input). It can be viewed as a lightweight replacement for vanilla language models.

# download the pre-trained TrimeLM
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime.zip;
unzip wiki103-247M-trime.zip; rm -f wiki103-247M-trime.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --softmax-temp 1.17

# the following output is expected:
# Loss (base 2): 4.0962, Perplexity: 17.10

Arguments:

  • --use-local specifies using local memory.
  • --softmax-temp specifies the temperature term used when computing the loss.

TrimeLM_long (local + long-term memory)

TrimeLM_long uses local memory and long-term memory during inference. The model is able to leverage long contexts, although it is trained with shorter ones.

# download the pre-trained TrimeLM_long
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_long.zip;
unzip wiki103-247M-trime_long.zip; rm -f wiki103-247M-trime_long.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime_long/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.22

# the following output is expected:
# Loss (base 2): 4.0879, Perplexity: 17.01

Arguments:

  • --use-long specifies using long-term memory.
  • --mem-size specifies the size of local + long-term memory.
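
To give intuition for how local and long-term memory can be combined at inference, here is a hedged, toy sketch (class and method names are assumptions for illustration; the repository's implementation differs):

import collections
import torch

class RollingMemory:
    # Keep the hidden states and token ids of the most recent mem_size positions
    # (previous segments + the current segment) and score them against a query.
    def __init__(self, mem_size=12288):
        self.keys = collections.deque(maxlen=mem_size)
        self.values = collections.deque(maxlen=mem_size)

    def add_segment(self, hiddens, tokens):
        # hiddens: (T, d) hidden states, tokens: (T,) token ids of a finished segment
        for h, t in zip(hiddens, tokens):
            self.keys.append(h)
            self.values.append(int(t))

    def scores(self, query, temp=1.22):
        # query: (d,) current hidden state -> similarity to every memory key
        keys = torch.stack(list(self.keys))      # (M, d), M <= mem_size
        return keys @ query / temp               # (M,)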

TrimeLM_ext (local + long-term + external memory)

TrimeLM_ext uses local memory, long-term memory, and external memory. During inference, we run the model on the training set to build the external memory and use the Faiss library to build an index for retrieving the top-K nearest neighbors from the external memory. We also calibrate a separate distribution over the memory and interpolate it with the model's output distribution, similarly to kNN-LM (see the paper for details).
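
Conceptually, the interpolation follows kNN-LM; a minimal, hedged sketch (function and argument names are assumptions for the example):

import torch

def interpolate_with_memory(lm_probs, knn_scores, knn_token_ids, vocab_size,
                            lmbda=0.3, interp_temp=10.5):
    # lm_probs:      (V,)  output distribution of the language model
    # knn_scores:    (K,)  similarity scores of the retrieved neighbors
    # knn_token_ids: (K,)  target token id stored with each neighbor
    knn_weights = torch.softmax(knn_scores / interp_temp, dim=-1)
    knn_probs = torch.zeros(vocab_size)
    knn_probs.scatter_add_(0, knn_token_ids, knn_weights)       # distribution over the memory
    return lmbda * knn_probs + (1.0 - lmbda) * lm_probs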

We first download the pre-trained TrimeLM_ext:

mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_ext.zip;
unzip wiki103-247M-trime_ext.zip; rm -f wiki103-247M-trime_ext.zip
cd ..

Next, we generate the external memory (keys and values) from the training set and build the Faiss index:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

# generate the external memory (keys and values) using the training set
python eval_lm.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode none --max-tokens 3072 \
    --softmax-batch 1024 --gen-subset train \
    --context-window 2560 --tokens-per-sample 512 \
    --dstore-mmap ${MODEL_PATH}/dstore --knn-keytype last_ffn_input \
    --dstore-size 103224461 \
    --save-knnlm-dstore --fp16 --dstore-fp16


# build Faiss index
python build_dstore.py \
    --dstore_mmap ${MODEL_PATH}/dstore \
    --dstore_size 103224461 --dimension 1024 \
    --faiss_index ${MODEL_PATH}/knn.index \
    --num_keys_to_add_at_a_time 500000 \
    --starting_point 0  --dstore_fp16  --dist ip
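
For intuition, build_dstore.py builds an approximate inner-product index over the memmapped keys; below is a simplified, hedged sketch with Faiss (the index string, chunk size, and file names are assumptions, not the script's exact settings):

import faiss
import numpy as np

d, n = 1024, 103224461
# keys written as a float16 memmap by the datastore step (file name assumed here)
keys = np.memmap("dstore_keys.npy", dtype=np.float16, mode="r", shape=(n, d))

index = faiss.index_factory(d, "IVF4096,PQ64", faiss.METRIC_INNER_PRODUCT)
train_sample = np.ascontiguousarray(keys[:1000000]).astype(np.float32)
index.train(train_sample)                       # train coarse quantizer + product quantizer

for start in range(0, n, 500000):               # add keys in chunks
    chunk = np.ascontiguousarray(keys[start:start + 500000]).astype(np.float32)
    index.add(chunk)

faiss.write_index(index, "knn.index.ip")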

Now, we are ready to evaluate the model:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

python eval_lm-trime.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.25 \
    --use-external --dstore-filename ${MODEL_PATH}/dstore --indexfile ${MODEL_PATH}/knn.index.ip \
    --probe 32 --dstore-fp16 --faiss-metric-type ip --no-load-keys --k 1024 \
    --use-interp --interp-temp 10.5 --lmbda 0.3 

# the following output is expected:
# Loss (base 2): 3.9580, Perplexity: 15.54

Arguments:

  • --use-external specifies using external memory.
  • --dstore-filename and --indexfile specify the paths of the datastore and the Faiss index.
  • --use-interp specifies using a linear interpolation between the two distributions to calibrate the final probability.
  • --lmbda and --interp-temp specify the interpolation weight and the temperature term used in the linear interpolation.

Performance of pre-trained models

We list the performance of the released pre-trained models on Wikitext-103 (perplexity) and enwik8 (bits per character), as well as their download links.

Dataset | Model | Dev | Test | Hyper-parameters
Wikitext-103 | TrimeLM (247M, L=3072) | 17.10 | 17.76 | --softmax-temp 1.17
Wikitext-103 | TrimeLM_long (247M, L=3072) | 17.01 | 17.64 | --softmax-temp 1.22 --mem-size 12288
Wikitext-103 | TrimeLM_ext (247M, L=3072) | 15.54 | 15.46 | --softmax-temp 1.25 --mem-size 12288 --interp-temp 10.5 --lmbda 0.3
Wikitext-103 | TrimeLM (150M, L=150) | 24.45 | 25.61 | --softmax-temp 1.03
Wikitext-103 | TrimeLM_long (150M, L=150) | 21.76 | 22.62 | --softmax-temp 1.07 --mem-size 15000
enwik8 | TrimeLM (38M, L=512) | 1.14 | 1.12 | --softmax-temp 1.05
enwik8 | TrimeLM_long (38M, L=512) | 1.08 | 1.05 | --softmax-temp 1.10 --mem-size 24576

Train TrimeLM

Trime loss functions

We follow Fairseq's training recipe (e.g., optimizer, learning rate, batch size) to train TrimeLM. The difference is that we use our own loss functions (specified by --criterion) and data batching methods.

We train three variants of TrimeLM by using different data batching and memory-construction methods; a short conceptual sketch of these memory constructions follows the list below.

  • TrimeLM is trained with --criterion trime_loss
    • During training, we use previous tokens in the same segment to construct the working memory.
  • TrimeLM_long is trained with either --criterion trime_long_loss or --criterion trime_long_loss_same_device
    • We batch consecutive segments into one mini-batch; argument --keep-order is needed to batch consecutive segments.
    • During training, we use all the tokens in previous segments and the previous tokens in the same segment to construct the working memory.
    • When using trime_long_loss, we need to specify the memory size via --train-mem-size (the number of consecutive segments will be args.train_mem_size/args.tokens_per_sample).
    • When using trime_long_loss_same_device, we assume all consecutive segments are loaded on the same GPU device (equivalently, args.mem_size == args.max_tokens). Using trime_long_loss_same_device is more efficient than using trime_long_loss, as it requires less cross-GPU communication.
  • TrimeLM_ext is trained with --criterion trime_ext_loss
    • We batch segments that have high BM25 scores into one mini-batch. The result of the BM25 batching is specified by --predefined-batches.
    • During training, we use all previous tokens in the same segment and all tokens in other segments to construct the working memory.
    • With a probability p, we disable the local memory (i.e., we only use tokens from other segments to construct the memory). The probability p is specified by --cross-sent-ratio.
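
As a rough illustration of the three memory constructions described above, here is a hedged, toy sketch (the function and variable names are assumptions for the example, not the repository's implementation):

import random

def build_memory_positions(batch, seg_id, pos, variant, cross_sent_ratio=0.9):
    # batch:   list of segments; each segment is a list of token positions
    # seg_id:  index of the current segment in the mini-batch
    # pos:     current prediction position inside that segment
    # Returns the positions whose (hidden state, token) pairs form the working memory.
    local = batch[seg_id][:pos]                              # previous tokens, same segment
    if variant == "trime":                                   # local memory only
        return local
    if variant == "trime_long":                              # consecutive segments in the batch
        previous_segments = [p for s in batch[:seg_id] for p in s]
        return previous_segments + local
    if variant == "trime_ext":                               # BM25-batched segments
        other_segments = [p for i, s in enumerate(batch) if i != seg_id for p in s]
        if random.random() < cross_sent_ratio:               # disable local memory w.p. p
            return other_segments
        return other_segments + local
    raise ValueError(f"unknown variant: {variant}")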

Training scripts

Here is an example of training a TrimeLM_ext model. You can find all training scripts we used in our experiments in train_scripts.

We train our models on 4 NVIDIA RTX3090 GPUs.

# download the results of bm25 batching
wget https://nlp.cs.princeton.edu/projects/trime/bm25_batch/wiki103-l3072-batches.json -P data-bin/wikitext-103/

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir output/wiki103-247M-trime_ext \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion trime_ext_loss --max-tokens 3072 --update-freq 6 --tokens-per-sample 3072 --seed 1 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --knn-keytype last_ffn_input --fp16 \
    --ce-warmup-epoch 9 --cross-sent-ratio 0.9 \
    --predefined-batches data-bin/wikitext-103/wiki103-l3072-batches.json

Important arguments:

  • --arch specifies the model architecture. In our experiments, we use the following architectures.
    • transformer_lm_wiki103 (a 247M model for wikitext-103)
    • transformer_lm_wiki103_150M (a 150M model for wikitext-103)
    • transformer_lm_enwik8 (a 38M model for enwik8)
  • --criterion specifies the loss function. See the description above for the loss functions we support.
  • --tokens-per-sample specifies the segment length.
  • --max-tokens specifies the number of tokens to be loaded in each GPU.
  • --update-freq specifies the gradient-accumulation steps.
  • --ce-warmup-epoch specifies how many epochs the original CE loss is used at the beginning to warm up the training.
  • --cross-sent-ratio specifies the probability p to disable the local memory.
  • --predefined-batches specifies the file path of the predefined batches (we use BM25 to batch segments).
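
For reference, with the example command above on 4 GPUs, each update processes --max-tokens × num_GPUs × --update-freq = 3072 × 4 × 6 = 73,728 tokens.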

BM25 batching

When training the TrimeLM_ext model with --criterion trime_ext_loss, we use BM25 scores to batch training data.

We use the Pyserini library to build the BM25 index. The library can be installed via pip.

pip install pyserini

We first save all the segments from the training set into a .json file.

mkdir -p bm25/wiki103-l3072/segments
CUDA_VISIBLE_DEVICES=0 python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --max-tokens 6144 --tokens-per-sample 3072 \
    --arch transformer_lm_wiki103 \
    --output-segments-to-file bm25/wiki103-l3072/segments/segments.json

# Modify --tokens-per-sample for different segment lengths

Then, we build the BM25 index using Pyserini.

python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input bm25/wiki103-l3072/segments \
  --index bm25/wiki103-l3072/bm25_index \
  --generator DefaultLuceneDocumentGenerator --threads 1 \
  --storePositions --storeDocvectors --storeRaw

Next, for each training segment, we search for similar segments using the BM25 index we built above.

python bm25_search.py \
    --index_path bm25/wiki103-l3072/bm25_index/ \
    --segments_path bm25/wiki103-l3072/segments/segments.json \
    --results_path bm25/wiki103-l3072/bm25_results

# Using --num_shards and --shard_id, you can parallelize the search over shards (e.g., --num_shards 20).
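
For reference, a hedged sketch of what the per-segment BM25 search might look like with Pyserini (the file layout and JSON field names are assumptions for illustration, not what bm25_search.py necessarily does):

import json
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher("bm25/wiki103-l3072/bm25_index/")
with open("bm25/wiki103-l3072/segments/segments.json") as f:
    segments = json.load(f)                        # assumed: [{"id": ..., "contents": ...}, ...]

results = {}
for seg in segments:
    hits = searcher.search(seg["contents"][:1024], k=20)   # truncate very long queries
    results[seg["id"]] = [(hit.docid, hit.score) for hit in hits]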

Finally, based on the retrieval results, we create batches by grouping similar segments.

python bm25_make_batches.py \
    --results_path bm25/wiki103-l3072/bm25_results \
    --batch_file data-bin/wikitext-103/wiki103-l3072-batches.json

The output file wiki103-l3072-batches.json contains a list of training-segment indices, ordered so that adjacent segments are likely to be similar.

The batch file wiki103-l3072-batches.json can be used when training TrimeLM_ext, via the argument --predefined-batches. During training, we simply form training batches by taking sub-lists sequentially from the file.
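
A toy sketch of how such a batch file might be consumed (assuming it is a flat list of segment indices, as described above, and that one training batch holds a fixed number of segments):

import json

def batches_from_file(path, segments_per_batch):
    # Cut the BM25-ordered list of segment indices into consecutive sub-lists;
    # adjacent indices are likely to correspond to lexically similar segments.
    with open(path) as f:
        order = json.load(f)
    for start in range(0, len(order), segments_per_batch):
        yield order[start:start + segments_per_batch]

# for batch in batches_from_file("data-bin/wikitext-103/wiki103-l3072-batches.json", 4):
#     ...  # feed these segment indices to the data loader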

Machine Translation

For machine translation code and experiments, please check out the subdirectory.

Bugs or Questions?

If you have any questions related to the code or the paper, or you encounter any problems when using the code, feel free to email Zexuan Zhong ([email protected]) or open an issue. Please describe the problem in detail so we can help you better and faster!

Citation

If you use our code in your research, please cite our work:

@inproceedings{zhong2022training,
   title={Training Language Models with Memory Augmentation},
   author={Zhong, Zexuan and Lei, Tao and Chen, Danqi},
   booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
   year={2022}
}

Acknowledgement

Our repo is based on the fairseq, knnlm, and adaptive-knn-mt projects. We thank the authors for open-sourcing their great code!
