[EMNLP 2022] Training Language Models with Memory Augmentation https://arxiv.org/abs/2205.12674

TRIME: Training Language Models with Memory Augmentation

This is the repository for our EMNLP 2022 paper Training Language Models with Memory Augmentation, by Zexuan Zhong, Tao Lei, and Danqi Chen.

Updates

  • [2022/11/13] We have released the code and pre-trained models for the machine translation experiments here.
  • [2022/10/25] Our paper has been accepted to EMNLP 2022! Please check out the updated version. We have added domain adaptation and stronger results on machine translation and character-level language modeling.
  • [2022/07/31] We have released our training code and pre-trained models.
  • [2022/05/24] We have released the preprint of our TRIME paper on training LMs with memory augmentation.

Overview

We propose TRIME, a new training objective for language modeling that aligns model outputs with both token embeddings and in-batch memories. We also devise new methods for data batching and for constructing training memories, so that our models can effectively leverage long-range contexts and an external datastore.

Please find more details of this work in our paper.
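To make the idea of aligning outputs with both token embeddings and in-batch memories concrete, here is a minimal sketch of an objective of that shape for a single prediction. It assumes that a memory counts as a positive when the token that followed it equals the target token; the function and variable names are hypothetical, and the exact scoring functions, scaling, and memory selection are defined in the paper, not here.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def trime_log_prob(target, token_scores, memory_scores, memory_targets):
    """Return log P(target | context) for one position.

    token_scores[w]   : score between the context and the output embedding of token w
    memory_scores[j]  : similarity between the context and the j-th in-batch memory
    memory_targets[j] : the token that followed the j-th memory (assumed label)
    """
    # Positives: the target's own embedding plus every memory sharing the target.
    positives = [token_scores[target]] + [
        s for s, t in zip(memory_scores, memory_targets) if t == target
    ]
    # Normalizer: all token embeddings and all memories.
    everything = list(token_scores) + list(memory_scores)
    return log_sum_exp(positives) - log_sum_exp(everything)


token_scores = [2.0, 0.5, -1.0]   # toy vocabulary of 3 tokens
memory_scores = [1.5, 0.2]        # two in-batch memories
memory_targets = [0, 2]           # tokens that followed those memories

loss = -trime_log_prob(0, token_scores, memory_scores, memory_targets)
print(f"negative log-likelihood: {loss:.4f}")
```

With an empty memory this reduces to the usual softmax cross-entropy, which is one way to see the objective as an extension of standard language-model training.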

Setup

Requirements and dependencies

The code is based on the following requirements/dependencies (we specify the version we used in our experiments in brackets):

  • Python (3.7.11)
  • pytorch (1.9.1+cu111)
  • faiss-gpu (1.7.1)
  • numpy (1.21.2)

You can install this project (based on Fairseq) as follows:

pip install --editable .

Datasets

We conduct experiments on the Wikitext-103 and enwik8 datasets. Please use get_data.sh to download and preprocess the datasets.

bash get_data.sh {wikitext-103 | enwik8}

The processed datasets will be stored in data-bin/wikitext-103 and data-bin/enwik8.

Run Pre-Trained Models

The examples below run pre-trained models on Wikitext-103 with model size = 247M and segment length = 3072. For other experiments (e.g., with different datasets or models), see run_pretrained_models.md for the scripts covering all experimental settings.

TrimeLM (local memory)

TrimeLM uses only the local memory (constructed using tokens in the input). It can be viewed as a lightweight replacement for vanilla language models.

# download the pre-trained TrimeLM
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime.zip;
unzip wiki103-247M-trime.zip; rm -f wiki103-247M-trime.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --softmax-temp 1.17

# the following output is expected:
# Loss (base 2): 4.0962, Perplexity: 17.10

Arguments:

  • --use-local specifies using local memory.
  • --softmax-temp specifies the temperature term used when computing the loss.
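The expected output above reports the loss in base 2 together with the perplexity. The two are related by perplexity = 2^loss, which the short check below reproduces; the helper name is hypothetical and is not part of the repository.

```python
def perplexity_from_base2_loss(loss_bits):
    """Convert an average per-token loss measured in bits into perplexity."""
    return 2.0 ** loss_bits


ppl = perplexity_from_base2_loss(4.0962)
print(f"Perplexity: {ppl:.2f}")  # prints "Perplexity: 17.10"
```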

TrimeLM_long (local + long-term memory)

TrimeLM_long uses local memory and long-term memory during inference. The model is able to leverage long contexts, although it is trained with shorter ones.

# download the pre-trained TrimeLM_long
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_long.zip;
unzip wiki103-247M-trime_long.zip; rm -f wiki103-247M-trime_long.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime_long/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.22

# the following output is expected:
# Loss (base 2): 4.0879, Perplexity: 17.01

Arguments:

  • --use-long specifies using long-term memory.
  • --mem-size specifies the size of local + long-term memory.
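One way to picture a memory with a fixed size is a first-in, first-out buffer of (key, next token) pairs that drops the oldest entries once the limit is reached. The sketch below illustrates only that idea; the class name and eviction policy are assumptions, and the repository's actual memory handling may differ. With --mem-size 12288 and 3072-token segments, such a buffer would span four segments.

```python
from collections import deque


class TokenMemory:
    """Bounded FIFO of (key, next_token) pairs (hypothetical illustration)."""

    def __init__(self, mem_size):
        # deque with maxlen silently discards the oldest items when full.
        self.entries = deque(maxlen=mem_size)

    def add_segment(self, keys, next_tokens):
        """Append one segment's keys and the tokens that followed them."""
        for key, token in zip(keys, next_tokens):
            self.entries.append((key, token))

    def next_tokens(self):
        return [token for _, token in self.entries]

    def __len__(self):
        return len(self.entries)


memory = TokenMemory(mem_size=4)
memory.add_segment(["k0", "k1", "k2"], [5, 6, 7])
memory.add_segment(["k3", "k4", "k5"], [8, 9, 10])
print(len(memory), memory.next_tokens())

segments_in_memory = 12288 // 3072  # 4 segments fit in the evaluation memory
```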

TrimeLM_ext (local + long-term + external memory)

TrimeLM_ext uses local memory, long-term memory, and external memory. During inference, we run the model on the training set to build the external memory and use the Faiss library to build an index for retrieving the top-K nearest neighbors from the external memory. We also calibrate a separate distribution over the memory and interpolate the output distribution with the memory distribution, similarly to kNN-LM (see details in the paper).

We first download the pre-trained TrimeLM_ext:

mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_ext.zip;
unzip wiki103-247M-trime_ext.zip; rm -f wiki103-247M-trime_ext.zip
cd ..

Then, we generate the external memory (keys and values) from the training set and build the Faiss index:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

# generate the external memory (keys and values) using the training set
python eval_lm.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode none --max-tokens 3072 \
    --softmax-batch 1024 --gen-subset train \
    --context-window 2560 --tokens-per-sample 512 \
    --dstore-mmap ${MODEL_PATH}/dstore --knn-keytype last_ffn_input \
    --dstore-size 103224461 \
    --save-knnlm-dstore --fp16 --dstore-fp16


# build Faiss index
python build_dstore.py \
    --dstore_mmap ${MODEL_PATH}/dstore \
    --dstore_size 103224461 --dimension 1024 \
    --faiss_index ${MODEL_PATH}/knn.index \
    --num_keys_to_add_at_a_time 500000 \
    --starting_point 0  --dstore_fp16  --dist ip
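Before running the two commands above, it can help to estimate the disk space the datastore keys need. The back-of-the-envelope calculation below assumes one fp16 vector of --dimension entries per datastore item, as the --dstore-fp16 and --dimension 1024 flags suggest; the stored values and the Faiss index take additional space, and the actual file layout is not confirmed here.

```python
def dstore_key_bytes(num_keys, dimension, bytes_per_value=2):
    """Bytes needed to store num_keys vectors of the given dimension (fp16 = 2 bytes)."""
    return num_keys * dimension * bytes_per_value


size = dstore_key_bytes(num_keys=103_224_461, dimension=1024)
print(f"{size / 2**30:.1f} GiB")  # prints "196.9 GiB"
```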

Now, we are ready to evaluate the model:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

python eval_lm-trime.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.25 \
    --use-external --dstore-filename ${MODEL_PATH}/dstore --indexfile ${MODEL_PATH}/knn.index.ip \
    --probe 32 --dstore-fp16 --faiss-metric-type ip --no-load-keys --k 1024 \
    --use-interp --interp-temp 10.5 --lmbda 0.3 

# the following output is expected:
# Loss (base 2): 3.9580, Perplexity: 15.54

Arguments:

  • --use-external specifies using external memory.
  • --dstore-filename and --indexfile specify the datastore and the Faiss index paths.
  • --use-interp specifies using a linear interpolation between two distributions to calibrate the final probability.
  • --interp-temp and --lmbda specify the temperature term and the interpolation weight, respectively, when using the linear interpolation.
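The interpolation step can be sketched as follows. This is a kNN-LM-style illustration under stated assumptions, not the repository's implementation: it assumes retrieved similarity scores are divided by the temperature before a softmax, and that lmbda is the weight on the memory distribution. All function names are hypothetical.

```python
import math


def memory_distribution(neighbor_scores, neighbor_tokens, vocab_size, interp_temp):
    """Turn retrieved neighbors into a distribution over the vocabulary.

    Each neighbor votes for the token that followed it, weighted by a
    temperature-scaled softmax over the similarity scores (assumed form).
    """
    scaled = [s / interp_temp for s in neighbor_scores]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    probs = [0.0] * vocab_size
    for weight, token in zip(weights, neighbor_tokens):
        probs[token] += weight / total
    return probs


def interpolate(p_model, p_memory, lmbda):
    """Linear interpolation; lmbda is assumed to weight the memory distribution."""
    return [lmbda * pm + (1.0 - lmbda) * pl for pl, pm in zip(p_model, p_memory)]


p_model = [0.7, 0.2, 0.1]
p_memory = memory_distribution([12.0, 11.0, 3.0], [1, 1, 2], vocab_size=3, interp_temp=10.5)
p_final = interpolate(p_model, p_memory, lmbda=0.3)
print([round(p, 4) for p in p_final])
```

Because both inputs are probability distributions, the interpolated result also sums to one for any lmbda between 0 and 1.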

Performance of pre-trained models

We list the performance of the released pre-trained models on Wikitext-103 and enwik8, as well as their download links. Dev and Test results are perplexity on Wikitext-103 and bits per character on enwik8.

| Dataset | Model | Dev | Test | Hyper-parameters |
|---|---|---|---|---|
| Wikitext-103 | TrimeLM (247M, L=3072) | 17.10 | 17.76 | --softmax-temp 1.17 |
| Wikitext-103 | TrimeLM_long (247M, L=3072) | 17.01 | 17.64 | --softmax-temp 1.22 --mem-size 12288 |
| Wikitext-103 | TrimeLM_ext (247M, L=3072) | 15.54 | 15.46 | --softmax-temp 1.25 --mem-size 12288 --interp-temp 10.5 --lmbda 0.3 |
| Wikitext-103 | TrimeLM (150M, L=150) | 24.45 | 25.61 | --softmax-temp 1.03 |
| Wikitext-103 | TrimeLM_long (150M, L=150) | 21.76 | 22.62 | --softmax-temp 1.07 --mem-size 15000 |
| enwik8 | TrimeLM (38M, L=512) | 1.14 | 1.12 | --softmax-temp 1.05 |
| enwik8 | TrimeLM_long (38M, L=512) | 1.08 | 1.05 | --softmax-temp 1.10 --mem-size 24576 |

Train TrimeLM

Trime loss functions

We follow Fairseq's training recipe (e.g., optimizer, learning rate, batch size) to train TrimeLM. The differences are that we use our own loss functions (specified by --criterion) and our own data batching methods.

We trained three varieties of TrimeLM by using different data batching and memory construction methods.

  • TrimeLM is trained with --criterion trime_loss
    • During training, we use previous tokens in the same segment to construct the working memory.
  • TrimeLM_long is trained with either --criterion trime_long_loss or --criterion trime_long_loss_same_device
    • We batch consecutive segments into one mini-batch; argument --keep-order is needed to batch consecutive segments.
    • During training, we use all the tokens in previous segments and the previous tokens in the same segment to construct the working memory.
    • When using trime_long_loss, we need to specify the memory size through --train-mem-size (the number of consecutive segments will be args.train_mem_size/args.tokens_per_sample).
    • When using trime_long_loss_same_device, we assume all consecutive segments are loaded on the same GPU device (equivalently args.mem_size == args.max_tokens). Using trime_long_loss_same_device is more efficient than using trime_long_loss, as it requires less cross-GPU communication.
  • TrimeLM_ext is trained with --criterion trime_ext_loss
    • We batch segments that have high BM25 scores into one mini-batch. The results of BM25 batching are specified by --predefined-batches.
    • During training, we use all previous tokens in the same segment and all tokens in other segments to construct the working memory.
    • With probability p, we disable the local memory (i.e., we use only tokens from other segments to construct the memory). The probability p is specified by --cross-sent-ratio.
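The memory construction described for TrimeLM_ext can be sketched as below: for a token at a given position, the working memory holds every token of the other segments in the batch, plus the earlier tokens of its own segment unless the local memory is disabled with probability p. This is a simplified illustration with hypothetical names; it works on (segment, position) pairs instead of hidden states and is not the repository's loss code.

```python
import random


def build_working_memory(segments, seg_idx, pos, cross_sent_ratio, rng):
    """Return the (segment, position) pairs usable as memory for one token.

    segments         : list of token lists batched together
    seg_idx, pos     : location of the token being predicted
    cross_sent_ratio : probability p of disabling the local memory
    rng              : random.Random instance, passed in for reproducibility
    """
    # Tokens from all other segments are always available.
    memory = [
        (s, i)
        for s, segment in enumerate(segments)
        if s != seg_idx
        for i in range(len(segment))
    ]
    disable_local = rng.random() < cross_sent_ratio
    if not disable_local:
        # Local memory: earlier tokens of the same segment.
        memory += [(seg_idx, i) for i in range(pos)]
    return memory


segments = [["a", "b", "c"], ["d", "e"]]
rng = random.Random(0)
print(build_working_memory(segments, seg_idx=0, pos=2, cross_sent_ratio=0.9, rng=rng))
```

Passing the random generator explicitly keeps the behavior reproducible, which is convenient when checking the two cases (local memory on or off) separately.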

Training scripts

Here is an example of training a TrimeLM_ext model. You can find all training scripts we used in our experiments in train_scripts.

We train our models on 4 NVIDIA RTX3090 GPUs.

# download the results of bm25 batching
wget https://nlp.cs.princeton.edu/projects/trime/bm25_batch/wiki103-l3072-batches.json -P data-bin/wikitext-103/

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir output/wiki103-247M-trime_ext \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion trime_ext_loss --max-tokens 3072 --update-freq 6 --tokens-per-sample 3072 --seed 1 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --knn-keytype last_ffn_input --fp16 \
    --ce-warmup-epoch 9 --cross-sent-ratio 0.9 \
    --predefined-batches data-bin/wikitext-103/wiki103-l3072-batches.json

Important arguments:

  • --arch specifies the model architecture. In our experiments, we have been using the following architectures.
    • transformer_lm_wiki103 (a 247M model for wikitext-103)
    • transformer_lm_wiki103_150M (a 150M model for wikitext-103)
    • transformer_lm_enwik8 (a 38M model for enwik8)
  • --criterion specifies the function used to compute loss values. See the description above for the supported loss functions.
  • --tokens-per-sample specifies the segment length.
  • --max-tokens specifies the number of tokens to be loaded in each GPU.
  • --update-freq specifies the gradient-accumulation steps.
  • --ce-warmup-epoch specifies for how many epochs the original CE loss is used at the beginning to warm up training.
  • --cross-sent-ratio specifies the probability p to disable the local memory.
  • --predefined-batches specifies the file path of the predefined batches (we use BM25 to batch segments).
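Taken together, --max-tokens, --update-freq, and the number of GPUs determine how many tokens contribute to each parameter update. The arithmetic below uses the values from the example command and the 4-GPU setup mentioned above, and assumes every GPU fills its --max-tokens budget at every step; the helper name is hypothetical.

```python
def tokens_per_update(max_tokens, update_freq, num_gpus):
    """Tokens seen per optimizer step under data-parallel gradient accumulation."""
    return max_tokens * update_freq * num_gpus


segments_per_gpu = 3072 // 3072  # --max-tokens // --tokens-per-sample
print(segments_per_gpu)                                                # prints 1
print(tokens_per_update(max_tokens=3072, update_freq=6, num_gpus=4))  # prints 73728
```

When training on a different number of GPUs, adjusting --update-freq so that this product stays the same is a common way to keep the effective batch size comparable.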

BM25 batching

When training the TrimeLM_ext model with --criterion trime_ext_loss, we use BM25 scores to batch training data.

We use the Pyserini library to build the BM25 index. The library can be installed via pip.

pip install pyserini

We first save all the segments from the training set into a .json file.

mkdir -p bm25/wiki103-l3072/segments
CUDA_VISIBLE_DEVICES=0 python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --max-tokens 6144 --tokens-per-sample 3072 \
    --arch transformer_lm_wiki103 \
    --output-segments-to-file bm25/wiki103-l3072/segments/segments.json

# Modify --tokens-per-sample for different segment lengths

Then, we build the BM25 index using Pyserini.

python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input bm25/wiki103-l3072/segments \
  --index bm25/wiki103-l3072/bm25_index \
  --generator DefaultLuceneDocumentGenerator --threads 1 \
  --storePositions --storeDocvectors --storeRaw

Next, for each training segment, we search for similar segments using the BM25 index we built above.

python bm25_search.py \
    --index_path bm25/wiki103-l3072/bm25_index/ \
    --segments_path bm25/wiki103-l3072/segments/segments.json \
    --results_path bm25/wiki103-l3072/bm25_results

# Use --num_shards and --shard_id to parallelize the nearest-neighbor search (e.g., --num_shards 20).

Finally, based on the retrieval results, we create batches by grouping similar segments.

python bm25_make_batches.py \
    --results_path bm25/wiki103-l3072/bm25_results \
    --batch_file data-bin/wikitext-103/wiki103-l3072-batches.json

The output file wiki103-l3072-batches.json contains a list of training-segment indices, ordered so that adjacent segments are likely to be similar.

The batch file wiki103-l3072-batches.json can be used during the training of TrimeLM_ext via the --predefined-batches argument. During training, we obtain training batches by taking sub-lists sequentially from the file.
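Taking sub-lists sequentially can be pictured as slicing the index list into consecutive chunks. The sketch below assumes the batch file is a flat JSON list of integer segment indices and uses a hypothetical segments_per_batch parameter; the actual file layout and batch size used by the training code may differ.

```python
import json


def sequential_batches(indices, segments_per_batch):
    """Yield consecutive sub-lists of segment indices, in file order."""
    for start in range(0, len(indices), segments_per_batch):
        yield indices[start:start + segments_per_batch]


# Stand-in for the contents of a batch file (assumed flat-list format).
batch_file_text = "[12, 907, 45, 3, 88, 610, 7]"
indices = json.loads(batch_file_text)

batches = list(sequential_batches(indices, segments_per_batch=3))
print(batches)  # prints [[12, 907, 45], [3, 88, 610], [7]]
```

Since neighboring indices in the list are likely to be similar, each consecutive chunk groups segments that can serve as useful memories for one another.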

Machine Translation

For machine translation code and experiments, please check out the subdirectory.

Bugs or Questions?

If you have any questions related to the code or the paper, or you encounter any problems when using the code, feel free to email Zexuan Zhong or open an issue. Please describe the problem in detail so we can help you better and faster!

Citation

If you use our code in your research, please cite our work:

@inproceedings{zhong2022training,
   title={Training Language Models with Memory Augmentation},
   author={Zhong, Zexuan and Lei, Tao and Chen, Danqi},
   booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
   year={2022}
}

Acknowledgement

Our repo is based on the fairseq, knnlm, and adaptive-knn-mt projects. We thank the authors for open-sourcing the great code!
