TRIME: Training Language Models with Memory Augmentation
This is the repository for our EMNLP2022 paper Training Language Models with Memory Augmentation, by Zexuan Zhong, Tao Lei, and Danqi Chen.
Updates
- [2022/11/13] We have released the code and pre-trained models for the machine translation experiments here.
- [2022/10/25] Our paper has been accepted to EMNLP 2022! Please check out the updated version. We have added domain adaptation and stronger results on machine translation and character-level language modeling.
- [2022/07/31] We have released our training code and pre-trained models.
- [2022/05/24] We have released the preprint of our TRIME paper on training LMs with memory augmentation.
Quick links
- Overview
- Setup
- Run Pre-Trained Models
- Train TrimeLM
- Machine Translation
- Bugs or Questions?
- Citation
- Acknowledgement
Overview
We propose a new training objective TRIME for language modeling, which aligns model outputs with both token embeddings and in-batch memories. We also devise novel ways for data batching and constructing training memories, so that our models can leverage long-range contexts and an external datastore effectively.
Please find more details of this work in our paper.
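To make "aligns model outputs with both token embeddings and in-batch memories" concrete, here is a minimal sketch of a TRIME-style loss for a single position, assuming plain Python lists as vectors and a single temperature dividing every dot-product score (in the spirit of --softmax-temp). The function name `trime_loss` and its signature are hypothetical; this is not the repository's implementation, which works on batched tensors and supports several memory types.

```python
import math

def trime_loss(query, target, vocab_emb, memory, temp=1.0):
    """Negative log-likelihood of `target` under a TRIME-style output distribution.

    query     -- hidden state at the current position (list of floats)
    target    -- id of the gold next token
    vocab_emb -- dict: token id -> output embedding (list of floats)
    memory    -- list of (key, token_id) pairs: key is a hidden state from the
                 memory, token_id is the token that followed it
    temp      -- temperature dividing every similarity score (assumption)
    """
    def sim(a, b):
        return sum(x * y for x, y in zip(a, b)) / temp

    # Score every vocabulary entry and every memory entry against the query.
    vocab_scores = {w: sim(query, e) for w, e in vocab_emb.items()}
    mem_scores = [(sim(query, k), w) for k, w in memory]

    # Log-sum-exp over all candidates (vocabulary + memory) for numerical stability.
    all_scores = list(vocab_scores.values()) + [s for s, _ in mem_scores]
    m = max(all_scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in all_scores))

    # Numerator: the target's embedding plus every memory entry labelled with the target.
    pos = [vocab_scores[target]] + [s for s, w in mem_scores if w == target]
    log_num = m + math.log(sum(math.exp(s - m) for s in pos))
    return log_z - log_num
```

With an empty memory this reduces to the usual softmax cross-entropy; memory entries whose next token matches the target lower the loss, while entries with other tokens raise it.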
Setup
Requirements and dependencies
The code is based on the following requirements/dependencies (we specify the version we used in our experiments in brackets):
- Python (3.7.11)
- pytorch (1.9.1+cu111)
- faiss-gpu (1.7.1)
- numpy (1.21.2)
You can install this project (based on Fairseq) as follows:
pip install --editable .
Datasets
We conduct experiments on the Wikitext-103 and enwik8 datasets. Please use get_data.sh to download and preprocess the datasets.
bash get_data.sh {wikitext-103 | enwik8}
The processed datasets will be stored in data-bin/wikitext-103 and data-bin/enwik8.
Run Pre-Trained Models
We show examples of running the pre-trained models on Wikitext-103 with model size = 247M and segment length = 3072. For other experiments (e.g., with different datasets or models), please refer to run_pretrained_models.md, which contains the scripts for all experimental settings.
TrimeLM (local memory)
TrimeLM uses only the local memory (constructed using tokens in the input). It can be viewed as a lightweight replacement for vanilla language models.
# download the pre-trained TrimeLM
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime.zip;
unzip wiki103-247M-trime.zip; rm -f wiki103-247M-trime.zip
cd ..
# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
--path pretrained_models/wiki103-247M-trime/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 \
--softmax-batch 1024 --gen-subset valid --fp16 \
--max-sentences 1 --knn-keytype last_ffn_input \
--use-local --softmax-temp 1.17
# the following output is expected:
# Loss (base 2): 4.0962, Perplexity: 17.10
Arguments:
- --use-local specifies using local memory.
- --softmax-temp specifies the temperature term used when computing the loss.
TrimeLM_long (local + long-term memory)
TrimeLM_long uses local memory and long-term memory during inference. The model is able to leverage long contexts, although it is trained with shorter ones.
# download the pre-trained TrimeLM_long
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_long.zip;
unzip wiki103-247M-trime_long.zip; rm -f wiki103-247M-trime_long.zip
cd ..
# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
--path pretrained_models/wiki103-247M-trime_long/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 \
--softmax-batch 1024 --gen-subset valid --fp16 \
--max-sentences 1 --knn-keytype last_ffn_input \
--use-local --use-long --mem-size 12288 --softmax-temp 1.22
# the following output is expected:
# Loss (base 2): 4.0879, Perplexity: 17.01
Arguments:
- --use-long specifies using long-term memory.
- --mem-size specifies the size of local + long-term memory.
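One way to picture a memory whose total size is capped by --mem-size is a bounded first-in-first-out buffer of (hidden state, next token) pairs that keeps only the most recent entries. The sketch below assumes exactly that FIFO eviction policy; the class name `FifoMemory` is hypothetical and the repository may manage its memory differently.

```python
from collections import deque

class FifoMemory:
    """Bounded memory of (key, next_token) pairs; the oldest entries are evicted first.

    mem_size plays the role of --mem-size (local + long-term entries combined).
    FIFO eviction is an assumption made for illustration.
    """

    def __init__(self, mem_size):
        # deque(maxlen=...) silently drops the oldest item once the cap is reached
        self.entries = deque(maxlen=mem_size)

    def add_segment(self, keys, next_tokens):
        # Append the (hidden state, following token) pair of every processed position.
        for key, tok in zip(keys, next_tokens):
            self.entries.append((key, tok))

    def snapshot(self):
        # Current memory contents, oldest first.
        return list(self.entries)
```

For the command above, --mem-size 12288 corresponds to four times the 3072-token segment length.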
TrimeLM_ext (local + long-term + external memory)
TrimeLM_ext uses local memory, long-term memory, and external memory. During inference, we run the model on the training set to build the external memory and use the Faiss library to build an index for retrieving the top-K nearest neighbors from the external memory. We also calibrate a separate distribution over the memory and interpolate the output distribution and the memory distribution, similarly to kNN-LM (see details in the paper).
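The retrieval step can be illustrated with an exact, brute-force inner-product search over a tiny set of keys. This is only a stand-in for Faiss (the index used below is typically approximate, as the --probe flag suggests); the function name `topk_inner_product` is hypothetical.

```python
import heapq

def topk_inner_product(query, keys, k):
    """Return the indices of the k keys with the largest inner product with `query`.

    Exact brute-force search, O(len(keys) * dim); a Faiss index replaces this
    for a datastore with roughly 10^8 keys.
    """
    scores = ((sum(q * x for q, x in zip(query, key)), i) for i, key in enumerate(keys))
    return [i for _, i in heapq.nlargest(k, scores)]
```

For example, `topk_inner_product([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]], 2)` returns `[2, 0]`.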
We first download the pre-trained TrimeLM_ext:
mkdir pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_ext.zip;
unzip wiki103-247M-trime_ext.zip; rm -f wiki103-247M-trime_ext.zip
cd ..
Next, we generate the external memory (keys and values) from the training set and build the Faiss index:
MODEL_PATH=pretrained_models/wiki103-247M-trime_ext
# generate the external memory (keys and values) using the training set
python eval_lm.py data-bin/wikitext-103 \
--path ${MODEL_PATH}/checkpoint_best.pt \
--sample-break-mode none --max-tokens 3072 \
--softmax-batch 1024 --gen-subset train \
--context-window 2560 --tokens-per-sample 512 \
--dstore-mmap ${MODEL_PATH}/dstore --knn-keytype last_ffn_input \
--dstore-size 103224461 \
--save-knnlm-dstore --fp16 --dstore-fp16
# build Faiss index
python build_dstore.py \
--dstore_mmap ${MODEL_PATH}/dstore \
--dstore_size 103224461 --dimension 1024 \
--faiss_index ${MODEL_PATH}/knn.index \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0 --dstore_fp16 --dist ip
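Before running these commands it may help to estimate the disk space needed for the keys. The back-of-the-envelope calculation below assumes the keys are stored as dense fp16 vectors of dimension 1024 (as implied by --dimension 1024 and --dstore-fp16); it ignores the values file and the Faiss index itself.

```python
DSTORE_SIZE = 103_224_461   # --dstore-size / --dstore_size (one key per training token)
DIMENSION = 1024            # --dimension
BYTES_PER_FP16 = 2          # keys saved with --dstore-fp16

key_bytes = DSTORE_SIZE * DIMENSION * BYTES_PER_FP16
print(f"{key_bytes:,} bytes = {key_bytes / 1e9:.1f} GB = {key_bytes / 2**30:.1f} GiB")
# 211,403,696,128 bytes = 211.4 GB = 196.9 GiB
```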
Now, we are ready to evaluate the model:
MODEL_PATH=pretrained_models/wiki103-247M-trime_ext
python eval_lm-trime.py data-bin/wikitext-103 \
--path ${MODEL_PATH}/checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 \
--softmax-batch 1024 --gen-subset valid --fp16 \
--max-sentences 1 --knn-keytype last_ffn_input \
--use-local --use-long --mem-size 12288 --softmax-temp 1.25 \
--use-external --dstore-filename ${MODEL_PATH}/dstore --indexfile ${MODEL_PATH}/knn.index.ip \
--probe 32 --dstore-fp16 --faiss-metric-type ip --no-load-keys --k 1024 \
--use-interp --interp-temp 10.5 --lmbda 0.3
# the following output is expected:
# Loss (base 2): 3.9580, Perplexity: 15.54
Arguments:
- --use-external specifies using external memory.
- --dstore-filename and --indexfile specify the datastore and the Faiss index paths.
- --use-interp specifies using a linear interpolation between two distributions to calibrate the final probability.
- --interp-temp and --lmbda specify the temperature term and the weight used in the linear interpolation, respectively.
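A minimal sketch of the kNN-LM-style interpolation controlled by --lmbda and --interp-temp, assuming the memory distribution is a softmax over the retrieved neighbors' similarity scores divided by the temperature, with probabilities of neighbors that share a token summed together. The function `interpolate` is hypothetical, and the exact score normalization in the repository may differ.

```python
import math
from collections import defaultdict

def interpolate(p_model, neighbors, lmbda=0.3, interp_temp=10.5):
    """Return lmbda * p_mem + (1 - lmbda) * p_model as a dict token -> probability.

    p_model   -- dict: token -> probability from the model's output distribution
    neighbors -- list of (similarity score, token) pairs for the retrieved entries
    """
    # Softmax over neighbor scores (shifted by the max for stability), then
    # aggregate the mass of neighbors that point to the same token.
    m = max(s for s, _ in neighbors)
    weights = defaultdict(float)
    for s, tok in neighbors:
        weights[tok] += math.exp((s - m) / interp_temp)
    z = sum(weights.values())
    p_mem = {tok: w / z for tok, w in weights.items()}

    tokens = set(p_model) | set(p_mem)
    return {t: lmbda * p_mem.get(t, 0.0) + (1 - lmbda) * p_model.get(t, 0.0) for t in tokens}
```

A larger --interp-temp flattens the memory distribution across neighbors, while --lmbda sets how much of the final probability comes from the memory.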
Performance of pre-trained models
We list the performance of the released pre-trained models on Wikitext-103 (perplexity) and enwik8 (bits per character), as well as their download links.
Dataset | Model | Dev | Test | Hyper-parameters |
---|---|---|---|---|
Wikitext-103 | TrimeLM (247M, L=3072) | 17.10 | 17.76 | --softmax-temp 1.17 |
Wikitext-103 | TrimeLM_long (247M, L=3072) | 17.01 | 17.64 | --softmax-temp 1.22 --mem-size 12288 |
Wikitext-103 | TrimeLM_ext (247M, L=3072) | 15.54 | 15.46 | --softmax-temp 1.25 --mem-size 12288 --interp-temp 10.5 --lmbda 0.3 |
Wikitext-103 | TrimeLM (150M, L=150) | 24.45 | 25.61 | --softmax-temp 1.03 |
Wikitext-103 | TrimeLM_long (150M, L=150) | 21.76 | 22.62 | --softmax-temp 1.07 --mem-size 15000 |
enwik8 | TrimeLM (38M, L=512) | 1.14 | 1.12 | --softmax-temp 1.05 |
enwik8 | TrimeLM_long (38M, L=512) | 1.08 | 1.05 | --softmax-temp 1.10 --mem-size 24576 |
Train TrimeLM
Trime loss functions
We follow Fairseq's training recipe (e.g., optimizer, learning rate, batch size) to train TrimeLM. Unlike the standard recipe, we use our own loss functions (specified by --criterion) and data batching methods.
We trained three variants of TrimeLM by using different data batching and memory construction methods.
- TrimeLM is trained with --criterion trime_loss
  - During training, we use previous tokens in the same segment to construct the working memory.
- TrimeLM_long is trained with either --criterion trime_long_loss or --criterion trime_long_loss_same_device
  - We batch consecutive segments into one mini-batch; the argument --keep-order is needed to batch consecutive segments.
  - During training, we use all the tokens in previous segments and previous tokens in the same segment to construct the working memory.
  - When using trime_long_loss, we need to specify the memory size through --train-mem-size (the number of consecutive segments will be args.train_mem_size/args.tokens_per_sample).
  - When using trime_long_loss_same_device, we assume all consecutive segments are loaded on the same GPU device (equivalently, args.mem_size == args.max_tokens). Using trime_long_loss_same_device is more efficient than using trime_long_loss, as it requires less cross-GPU communication.
- TrimeLM_ext is trained with --criterion trime_ext_loss
  - We batch segments that have high BM25 scores into one mini-batch. The results of BM25 batching are specified by --predefined-batches.
  - During training, we use all previous tokens in the same segment and all tokens in other segments to construct the working memory.
  - With a probability p, we disable the local memory (i.e., only tokens from other segments are used to construct the memory). The probability p is specified by --cross-sent-ratio.
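The three memory-construction rules above can be summarized as a boolean mask over the tokens of a mini-batch: entry [q][k] says whether token k may serve as a memory entry when predicting at position q. The sketch below assumes the segments of a mini-batch are laid out one after another (token t belongs to segment t // seg_len) and ignores the --train-mem-size bound; `memory_mask` and the variant names are hypothetical and only loosely mirror trime_loss, trime_long_loss and trime_ext_loss.

```python
def memory_mask(num_segments, seg_len, variant, disable_local=False):
    """Return mask[q][k] == True if token k is in the working memory for position q.

    variant       -- "trime", "long" or "ext" (hypothetical names)
    disable_local -- for "ext": drop same-segment tokens (the --cross-sent-ratio case)
    """
    n = num_segments * seg_len
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            same_seg = q // seg_len == k // seg_len
            if variant == "trime":
                # only previous tokens in the same segment
                mask[q][k] = same_seg and k < q
            elif variant == "long":
                # segments are consecutive, so "previous segments + previous tokens
                # in the same segment" is every token before q
                mask[q][k] = k < q
            elif variant == "ext":
                # all tokens of other segments, plus previous same-segment tokens
                # unless local memory is disabled for this step
                mask[q][k] = (not same_seg) or (k < q and not disable_local)
            else:
                raise ValueError(f"unknown variant: {variant}")
    return mask
```

For TrimeLM_ext, `disable_local` would be drawn once per training step, e.g. as `random.random() < p` with p taken from --cross-sent-ratio.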
Training scripts
Here is an example of training a TrimeLM_ext model. You can find all training scripts we used in our experiments in train_scripts.
We train our models on 4 NVIDIA RTX3090 GPUs.
# download the results of bm25 batching
wget https://nlp.cs.princeton.edu/projects/trime/bm25_batch/wiki103-l3072-batches.json -P data-bin/wikitext-103/
python train.py --task language_modeling data-bin/wikitext-103 \
--save-dir output/wiki103-247M-trime_ext \
--arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion trime_ext_loss --max-tokens 3072 --update-freq 6 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --knn-keytype last_ffn_input --fp16 \
--ce-warmup-epoch 9 --cross-sent-ratio 0.9 \
--predefined-batches data-bin/wikitext-103/wiki103-l3072-batches.json
Important arguments:
- --arch specifies the model architecture. In our experiments, we have been using the following architectures:
  - transformer_lm_wiki103 (a 247M model for Wikitext-103)
  - transformer_lm_wiki103_150M (a 150M model for Wikitext-103)
  - transformer_lm_enwik8 (a 38M model for enwik8)
- --criterion specifies the function used to compute loss values. See the description above for the functions we support.
- --tokens-per-sample specifies the segment length.
- --max-tokens specifies the number of tokens to be loaded on each GPU.
- --update-freq specifies the number of gradient-accumulation steps.
- --ce-warmup-epoch specifies for how many epochs the original CE loss is used at the beginning to warm up training.
- --cross-sent-ratio specifies the probability p of disabling the local memory.
- --predefined-batches specifies the file path of the predefined batches (we use BM25 to batch segments).
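Combining --max-tokens (per GPU), --update-freq (gradient accumulation) and the 4 GPUs used above gives the number of tokens per optimizer update in the example script. This assumes every per-GPU batch is completely filled, which is plausible here because --sample-break-mode none is used and --tokens-per-sample equals --max-tokens.

```python
max_tokens = 3072         # --max-tokens: tokens loaded on each GPU
update_freq = 6           # --update-freq: gradient-accumulation steps
num_gpus = 4              # 4 NVIDIA RTX3090 GPUs
tokens_per_sample = 3072  # --tokens-per-sample: segment length

tokens_per_update = max_tokens * update_freq * num_gpus
segments_per_update = tokens_per_update // tokens_per_sample
print(tokens_per_update, segments_per_update)  # 73728 24
```

When training on a different number of GPUs, --update-freq can be scaled inversely to keep this product unchanged.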
BM25 batching
When training the TrimeLM_ext model with --criterion trime_ext_loss, we use BM25 scores to batch training data.
We use the Pyserini library to build a BM25 index. The library can be installed via pip.
pip install pyserini
We first save all the segments from the training set into a .json file.
mkdir -p bm25/wiki103-l3072/segments
CUDA_VISIBLE_DEVICES=0 python train.py --task language_modeling \
data-bin/wikitext-103 \
--max-tokens 6144 --tokens-per-sample 3072 \
--arch transformer_lm_wiki103 \
--output-segments-to-file bm25/wiki103-l3072/segments/segments.json
# Modify --tokens-per-sample for different segment lengths
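Pyserini's JsonCollection indexes documents that carry an "id" and a "contents" field, and to our understanding it also accepts a .json file holding an array of such documents. The sketch below only illustrates that document shape; we have not checked the exact layout that train.py writes, and the helper `to_json_collection` as well as the choice of the segment index as "id" are hypothetical.

```python
import json

def to_json_collection(segments):
    """Serialize text segments as a JSON array of Pyserini-style documents.

    segments -- list of strings, one training segment each
    The "id" is the segment's position in the list (a hypothetical choice).
    """
    docs = [{"id": str(i), "contents": text} for i, text in enumerate(segments)]
    return json.dumps(docs)

# Hypothetical usage:
# with open("bm25/wiki103-l3072/segments/segments.json", "w") as f:
#     f.write(to_json_collection(["text of segment 0", "text of segment 1"]))
```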
Then, we build the BM25 index using Pyserini.
python -m pyserini.index.lucene \
--collection JsonCollection \
--input bm25/wiki103-l3072/segments \
--index bm25/wiki103-l3072/bm25_index \
--generator DefaultLuceneDocumentGenerator --threads 1 \
--storePositions --storeDocvectors --storeRaw
Next, for each training segment, we search for similar segments using the BM25 index built above.
python bm25_search.py \
--index_path bm25/wiki103-l3072/bm25_index/ \
--segments_path bm25/wiki103-l3072/segments/segments.json \
--results_path bm25/wiki103-l3072/bm25_results
# Use --num_shards and --shard_id to parallelize the nearest-neighbor search (e.g., --num_shards 20).
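Conceptually, the search step uses each training segment as a query and ranks the other segments by their Okapi BM25 score. The pure-Python scorer below shows that scoring function on tokenized documents; it is an illustration, not Pyserini's code. We assume k1 = 0.9 and b = 0.4 (commonly used Anserini/Pyserini defaults) and the Lucene-style idf log(1 + (N - df + 0.5) / (df + 0.5)); `bm25_scores` is a hypothetical name.

```python
import math
from collections import Counter

def bm25_scores(query_terms, docs, k1=0.9, b=0.4):
    """Score every document (a list of terms) against `query_terms` with Okapi BM25."""
    n = len(docs)
    avgdl = sum(len(d) for d in docs) / n
    # document frequency: in how many documents each term occurs
    df = Counter(t for d in docs for t in set(d))
    scores = []
    for d in docs:
        tf = Counter(d)
        s = 0.0
        for t in query_terms:
            if tf[t] == 0:
                continue
            idf = math.log(1 + (n - df[t] + 0.5) / (df[t] + 0.5))
            # term-frequency saturation (k1) with document-length normalization (b)
            s += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(s)
    return scores
```

Documents sharing no terms with the query score 0, so segments with overlapping vocabulary rise to the top of each result list.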
Finally, based on the retrieval results, we create batches by grouping similar segments.
python bm25_make_batches.py \
--results_path bm25/wiki103-l3072/bm25_results \
--batch_file data-bin/wikitext-103/wiki103-l3072-batches.json
The output file wiki103-l3072-batches.json contains a list of indices of training segments, in which adjacent segments are likely to be similar.
The batch file wiki103-l3072-batches.json can be used when training TrimeLM_ext, via the argument --predefined-batches.
During training, we simply obtain training batches by taking sub-lists sequentially from the file.
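The two steps described above (ordering segments so that similar ones are adjacent, then taking consecutive sub-lists as batches) can be sketched as follows. The greedy nearest-neighbor chaining in `order_by_neighbors` is a hypothetical strategy chosen for illustration, since the algorithm in bm25_make_batches.py is not described here; only the sequential sub-list step in `make_batches` follows the text directly.

```python
def order_by_neighbors(neighbors):
    """Greedily chain segments so that similar ones end up adjacent (hypothetical strategy).

    neighbors -- dict: segment index -> list of segment indices, most similar first
    """
    visited, order = set(), []
    for start in sorted(neighbors):
        if start in visited:
            continue
        cur = start
        while cur is not None:
            visited.add(cur)
            order.append(cur)
            # move to the most similar segment that has not been placed yet, if any
            cur = next((nb for nb in neighbors.get(cur, []) if nb not in visited), None)
    return order

def make_batches(order, segments_per_batch):
    """Take consecutive sub-lists of the ordering as training batches."""
    return [order[i:i + segments_per_batch] for i in range(0, len(order), segments_per_batch)]
```

For example, with `neighbors = {0: [2, 1], 1: [3], 2: [0, 1], 3: [1]}` the ordering is `[0, 2, 1, 3]` and two-segment batches are `[[0, 2], [1, 3]]`.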
Machine Translation
For machine translation code and experiments, please check out the subdirectory.
Bugs or Questions?
If you have any questions related to the code or the paper, or you encounter any problems when using the code, feel free to email Zexuan Zhong or open an issue. Please try to specify the problem with details so we can help you better and quicker!
Citation
If you use our code in your research, please cite our work:
@inproceedings{zhong2022training,
title={Training Language Models with Memory Augmentation},
author={Zhong, Zexuan and Lei, Tao and Chen, Danqi},
booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
year={2022}
}
Acknowledgement
Our repo is based on the fairseq, knnlm, and adaptive-knn-mt projects. We thank the authors for open-sourcing the great code!