• Stars
    star
    249
  • Rank 162,987 (Top 4 %)
  • Language
    Python
  • License
    MIT License
  • Created over 2 years ago
  • Updated about 2 years ago

Reviews

There are no reviews yet. Be the first to send feedback to the community and the maintainers!

Repository Details

PyTorch + HuggingFace code for RetoMaton: "Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval" (ICML 2022), including an implementation of kNN-LM and kNN-MT

kNN-transformers: Nearest-Neighbor Language and Machine Translation Models based on Hugging Face's 🤗 transformers library

This is a Hugging Face's 🤗 transformers implementation of k-nearest-neighbor-based language models and machine translation models, designed to be easy and useful in research, and for experimenting with new ideas in kNN-based models.

All previous kNN-LM based implementations are implemented in the fairseq library, and they forked/duplicated the library's entire codebase to implement their modification. These include the official kNN-LM repository https://github.com/urvashik/knnlm, the kNN-MT repository https://github.com/urvashik/knnmt, the official RetoMaton repository https://github.com/neulab/retomaton, and others.

We implement k-nearest-neighbor language model (kNN-LM) (Khandelwal et al., ICLR'2020), k-nearest-neighbor machine translation (kNN-MT) (Khandelwal et al., ICLR'2021) and this is also an official implementation of the RetoMaton model described in: Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval (ICML'2022). Most importantly, we implement these models in 🤗 transformers, and without modifying the transformers library itself.

To use this repository, all you need to do is copy its *.py files into your project. You can load any language model from Hugging Face's hub (such as gpt2), by model = AutoModelForCausalLM.from_pretrained(...), build a datastore or download ours (need to be performed only once), and then:

knn_wrapper = KNNWrapper(...) # or: RetomatonWrapper(...)
knn_wrapper.break_into(model)

That's it! The model now internally uses kNN-LM or RetoMaton (see a concrete example at run_clm.py)

The files knnlm.py and retomaton.py are standalone and can be copied to any project. The file run_clm.py is a modified version of this example by huggingface which shows an example of how to load and run kNN-LM and RetoMaton. The file run_translation.py is a modified version of this translation example by huggingface, which shows how to use our code for kNN-MT and RetoMaton.

This repository is maintained by Uri Alon. Please let us know if anything is not working as expected, and feel free to create new issues or email [email protected] with any questions. Contributions are welcome!

Table of Contents

Background

kNN-LM and kNN-MT

The k-nearest neighbor language model takes an already-trained model, performs a single forward pass over the entire training set, and creates a datastore of (key,value) pairs, where key is a hidden representation of the trained model after reading a training example, and value is the token that should be predicted next.

At test time, for every predicted token, the model performs a k-nearest neighbors search in the datastore, retrieves the (key,value) pairs that are closest to the test hidden representation, and normalizes their distances using softmax. Finally, the model interpolates the base LM's probability with the probability formed by the retrieved nearest neighbors and their normalized distances. For more details, see the paper by Khandelwal et al., ICLR'2020

kNN-MT is similar to kNN-LM, but it addresses seq2seq/encoder-decoder models, mainly for machine translation. For more details, see the paper by Khandelwal et al., ICLR'2021

RetoMaton

RetoMaton extends kNN-LM and kNN-MT, by (1) saving a pointer in every datastore entry; and (2) clustering entries according to their keys. That is, every datastore entry is now a tuple (key,value,pointer), and it belongs to a cluster.

These two changes create an automaton from the datastore, where states are clusters, edges are pointers (shared among examples in the same cluster), and transition weights are the normalized distances between the test representation and each key.

At test time, the model traverses the automaton, and follows the edges according to the token that was predicted. This allows to save up to 80% of the kNN searches by following pointers instead of performing the expensive search, or reducing perplexity without saving searches.

For more details, see the paper by Alon et al., ICML'2022

Available Models

kNN-LM and RetoMaton datastores depend on the LM that was used to create them. We fine-tuned a few gpt2-based models on the training set of Wikitext-103 (because Wikitext-103 was not included in GPT2's pretraining data):

  • neulab/distilgpt2-finetuned-wikitext103
  • neulab/gpt2-finetuned-wikitext103
  • neulab/gpt2-med-finetuned-wikitext103
  • neulab/gpt2-large-finetuned-wikitext103

All these models are available at the Hugging Face Hub and can be loaded by (for example):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('neulab/gpt2-finetuned-wikitext103')
model = AutoModelForCausalLM.from_pretrained('neulab/gpt2-finetuned-wikitext103')

This project is not limited to these models, and can work with any language model or seq2seq model.

We fine-tuned all language models using:

python run_clm.py --model_name_or_path <base_model_name> \
    --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
    --do_train --do_eval --output_dir finetune_gpt2_wikitext103/ \
    --save_total_limit 5 --per_device_train_batch_size 2

Where <base_model_name> is, for example, gpt2, distilgpt2, gpt2-med, gpt2-large, or gpt2-xl.

We have not yet released finetuned machine translation models, but the code in this repository works for machine translation as well, using the run_translation.py script.

Results - Wikitext-103

The exact results from the RetoMaton papers can be reproduced using the code at https://github.com/neulab/retomaton (based on fairseq).

The following results were obtained using the code in this repository:

Base LM: distilgpt2 gpt2 gpt2-medium
base perplexity 18.25 14.84 11.55
kNN-LM 15.03 12.57 10.59
RetoMaton 14.71 12.44 10.59

And when varying the fraction of saved searches:

These are the results from the RetoMaton paper, on a model that was trained on Wikitext-103 from scratch:

Results - Translation

On the validation set of --dataset_name wmt16 --dataset_config_name ro-en.

Base model: t5-small t5-base
base BLEU 26.15 27.70
+ kNN-MT 26.42 27.92
--knn_temp=50 --k=32 --lmbda=0.25 --knn_temp=200 --k=512 --lmbda=0.2

If you perform additional experiments with our code, we would love to learn more about your results and share them here!

Quickstart - Language Modeling

Step 0: Clone this repository:

git clone https://github.com/neulab/knn-transformers
cd knn-transformers

Requirements

Run:

pip install requirements.txt`
  • The project also depends on the faiss library. In MacOS, use the Anaconda installation instead:
conda install -c conda-forge faiss-cpu

Step 1: Evaluating the base Language Model

To evaluate the fine-tuned model (for example, neulab/gpt2-finetuned-wikitext103) on the validation set (without any retrieval):

MODEL=neulab/gpt2-finetuned-wikitext103

python -u run_clm.py \
  --model_name_or_path ${MODEL} \
  --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
  --output_dir checkpoints/${MODEL} \
  --do_eval --eval_subset validation

Step 2: Saving a Datastore

You can either download our preprocessed Wikitext-103 datastores, or preprocess them yourself.

To download a datastore for Wikitext-103 that we created for the finetuned gpt2 model (neulab/gpt2-finetuned-wikitext103):

wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/gpt2/dstore_gpt2_116988150_768_vals.npy

Similarly, we created datastores using the distilgpt2-finetuned-wikitext103, gpt2-med-finetuned-wikitext103 and gpt2-large-finetuned-wikitext103. For all available datastores, see: https://knn-transformers.s3.amazonaws.com/index.html

To save a datastore, run:

MODEL=neulab/gpt2-finetuned-wikitext103

python -u run_clm.py \
  --model_name_or_path ${MODEL} \
  --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
  --do_eval --eval_subset train \
  --output_dir checkpoints/${MODEL} \
  --dstore_dir checkpoints/${MODEL} \
  --save_knnlm_dstore

Step 3: Building the FAISS index

The FAISS index requires a training phase where it learns an index for accessing the keys quickly. This step does not require a GPU.

To download an index for the finetuned gpt2 model (neulab/gpt2-finetuned-wikitext103):

wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/gpt2/index_gpt2_116988150_768.indexed

Similarly, we trained faiss indexes for distilgpt2-finetuned-wikitext103, gpt2-med-finetuned-wikitext103 and gpt2-large-finetuned-wikitext103, see: https://knn-transformers.s3.amazonaws.com/index.html

To build the FAISS index yourself (not needed if you already downloaded ours):

MODEL=neulab/gpt2-finetuned-wikitext103

python -u run_clm.py \
  --model_name_or_path ${MODEL} \
  --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
  --output_dir checkpoints/${MODEL} \
  --dstore_dir checkpoints/${MODEL} \
  --build_index

Step 4: Evaluating Models

To evaluate kNN-LM and RetoMaton on the validation set:

MODEL=neulab/gpt2-finetuned-wikitext103

python -u run_clm.py \
  --model_name_or_path ${MODEL} \
  --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
  --output_dir checkpoints/${MODEL} \
  --do_eval --eval_subset validation \
  --dstore_dir checkpoints/${MODEL} --retomaton

To use kNN-LM, use the --knn flag instead of --retomaton.

To encourage the RetoMaton model to perform a full kNN search more frequently and thus increase accuracy and reduce perplexity, use a larger value of --min-knns such as 100. Using --min-knns 9999999 makes the model perform kNN search at every step (FoSS = 0 in Figure 3 of the paper), and achieves the best results at the cost of slower speed.

Additional possible test-time tunable hyperparameters are --lmbda (the interpolation factor between the datastore and the base LM), --k (the number of retrieved nearest neighbors) and --knn_temp (the softmax temperature when converting the nearest-neighbor distances into a probability distribution).

Step 5: Adding clustering

RetoMaton can work without clusters, in which is utilizes its pointers only. Using clustering allows it to save more nearest-neighbor searches and further reduce perplexity.

To download our processed clusters the finetuned gpt2 model (neulab/gpt2-finetuned-wikitext103):

wget -P checkpoints/gpt2/ https://knn-transformers.s3.amazonaws.com/index.htmlgpt2/members_gpt2_116988150_768_20000000_500000.pkl

Similarly, we also provide clusters for the distilgpt2 model (neulab/distilgpt2-finetuned-wikitext103) at https://knn-transformers.s3.amazonaws.com/index.html.

To cluster similar keys for RetoMaton yourself:

MODEL=neulab/gpt2-finetuned-wikitext103

python -u run_clm.py \
  --model_name_or_path ${MODEL} \
  --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
  --output_dir checkpoints/${MODEL} \
  --dstore_dir checkpoints/${MODEL}/ \
  --cluster_dstore --num_clusters 500000 --sample_size 20000000

Once the clustering file is saved in the directory pointed to by --dstore_dir, it will automatically be picked up when running evaluation (as in the previous step)

Optional clustering hyperparameters are --num_clusters (typically 1/100 or 1/200 of the datastore size) and --sample_size (ideally as high as possible, but higher values consume more memory and take longer to run).

Machine Translation (kNN-MT)

Using our code for machine translation and kNN-MT is very similar to language modeling, using the file run_translation.py instead of run_clm.py, and following the example instructions from huggingface: https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation.

Importantly, the --knn_temp flag should be used and tuned for kNN-MT. As shown in the kNN-MT paper, the optimal temperature for kNN-MT can be 10 to 100.

The --lmbda interpolation factor is also typically larger in kNN-MT, and can be 0.4-0.8.

Evaluating the base MT model

MODEL=t5-small

python -u run_translation.py  \
  --model_name_or_path ${MODEL} \
  --dataset_name wmt16 --dataset_config_name ro-en \
  --per_device_eval_batch_size=4 \
  --output_dir checkpoints-translation/${MODEL} \
  --source_lang en --target_lang ro \
  --do_eval \
  --predict_with_generate \
  --source_prefix "translate English to Romanian: "

Note that the flag --source_prefix "translate English to Romanian: " is slightly different for every model and task, and may be unneeded for other models, as detailed at https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation.

Saving a datastore for kNN-MT

Examples datastores for t5-small and t5-base on wmt16 en-ro are available at https://knn-transformers.s3.amazonaws.com/index.html.

MODEL=t5-small

python -u run_translation.py  \
  --model_name_or_path ${MODEL} \
  --dataset_name wmt16 --dataset_config_name ro-en \
  --per_device_train_batch_size 4 --per_device_eval_batch_size=4 \
  --output_dir checkpoints-translation/${MODEL} \
  --source_lang en --target_lang ro \
  --dstore_size 26565876 \
  --dstore_dir checkpoints-translation/${MODEL} \
   --save_knnlm_dstore --do_eval --eval_subset train \
   --source_prefix "translate English to Romanian: "

Building the FAISS index for kNN-MT

MODEL=t5-small

python -u run_translation.py  \
  --model_name_or_path ${MODEL} \
  --dataset_name wmt16 --dataset_config_name ro-en \
  --per_device_train_batch_size 4 --per_device_eval_batch_size=4 \
  --output_dir checkpoints-translation/${MODEL} \
  --source_lang en --target_lang ro \
  --dstore_size 26565876 \
  --dstore_dir checkpoints-translation/${MODEL} \
  --build_index

Evaluating kNN-MT

MODEL=t5-small

python -u run_translation.py  \
  --model_name_or_path ${MODEL} \
  --dataset_name wmt16 --dataset_config_name ro-en \
  --per_device_eval_batch_size=4 \
  --output_dir checkpoints-translation/${MODEL} \
  --source_lang en --target_lang ro \
  --do_eval \
  --predict_with_generate \
  --source_prefix "translate English to Romanian: " \
  --dstore_size 26565876 \
  --dstore_dir checkpoints-translation/${MODEL} \
  --knn_temp 50 --k 32 --lmbda 0.25 \
  --retomaton

To use kNN-MT, use the --knn flag instead of --retomaton.

All files:

Datastores and indexes can be downloaded from https://knn-transformers.s3.amazonaws.com/index.html

All fine-tuned models are available on Hugging Face Hub: https://huggingface.co/neulab

Differences from the kNN-LM implementation

  • The original kNN-LM repository uses faiss CPU to perform retrieval. However, we added the flag --knnlm-gpu that allows performing retrieval much faster on the GPU.
  • After each retrieval, the original kNN-LM repository loads the found keys and re-computes the distance from the query to each nearest neighbor. This is much more time consuming, unless loading all the keys (200GB) into memory. We thus use the distances returned by faiss when performing search, or reconstructing the vectors from their index in RetoMaton, without loading the huge keys.npy file into memory.

Citation

If you use our code for research, please cite:

Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval

@inproceedings{alon2022neuro,
  title={Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval},
  author={Alon, Uri and Xu, Frank and He, Junxian and Sengupta, Sudipta and Roth, Dan and Neubig, Graham},
  booktitle={International Conference on Machine Learning},
  pages={468--485},
  year={2022},
  organization={PMLR}
}

This repository also implements: Generalization through Memorization: Nearest Neighbor Language Models

@inproceedings{khandelwal20generalization,
  title={{Generalization through Memorization: Nearest Neighbor Language Models}},
  author={Khandelwal, Urvashi and Levy, Omer and Jurafsky, Dan and Zettlemoyer, Luke and Lewis, Mike},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

More Repositories

1

prompt2model

prompt2model - Generate Deployable Models from Natural Language Instructions
Python
1,953
star
2

Text-Summarization-Papers

An Exhaustive Paper List for Text Summarization
HTML
500
star
3

compare-mt

A tool for holistic analysis of language generations systems
Python
450
star
4

nn4nlp-concepts

A repository of concepts related to neural networks for NLP
Python
447
star
5

ExplainaBoard

Interpretable Evaluation for AI Systems
Python
361
star
6

awesome-align

A neural word aligner based on multilingual BERT
Python
319
star
7

BARTScore

BARTScore: Evaluating Generated Text as Text Generation
Python
317
star
8

InterpretEval

Interpretable Evaluation for (Almost) All NLP Tasks
HTML
193
star
9

ReviewAdvisor

Heavy Workload on Reviewing Papers? ReviewAdvisor Helps out
Python
192
star
10

xnmt

eXtensible Neural Machine Translation
Python
185
star
11

gemini-benchmark

Jupyter Notebook
150
star
12

RIPPLe

Code for the paper "Weight Poisoning Attacks on Pre-trained Models" (ACL 2020)
Jupyter Notebook
138
star
13

SpanNER

SpanNER: Named EntityRe-/Recognition as Span Prediction
Python
124
star
14

word-embeddings-for-nmt

Supplementary material for "When and Why Are Pre-trained Word Embeddings Useful for Neural Machine Translation?" at NAACL 2018
Python
119
star
15

guided_summarization

GSum: A General Framework for Guided Neural Abstractive Summarization
Python
112
star
16

code-bert-score

CodeBERTScore: an automatic metric for code generation, based on BERTScore
Jupyter Notebook
111
star
17

external-knowledge-codegen

Code and data for ACL20 paper "Incorporating External Knowledge through Pre-training for Natural Language to Code Generation"
Python
96
star
18

cmu-multinlp

Generalizing Natural Language Analysis through Span-relation Representations
Python
88
star
19

REALSumm

REALSumm: Re-evaluating Evaluation in Text Summarization
Python
71
star
20

langrank

A program to choose transfer languages for cross-lingual learning
Python
66
star
21

retomaton

PyTorch code for the RetoMaton paper: "Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval" (ICML 2022)
Python
60
star
22

dynet-benchmark

Benchmarks for DyNet
Python
56
star
23

newlang-tech

A guide to building language technology in new languages.
56
star
24

ragged

Retrieval Augmented Generation Generalized Evaluation Dataset
Jupyter Notebook
51
star
25

contextual-mt

A repository with the code related to experiments around context-aware machine translation
Python
48
star
26

extreme-adaptation-for-personalized-translation

Code for the paper "Extreme Adaptation for Personalized Neural Machine Translation"
Python
43
star
27

lrlm

Code for the paper "Latent Relation Language Models" at AAAI-20.
Python
41
star
28

incremental_tree_edit

Code for "Learning Structural Edits via Incremental Tree Transformations" (ICLR'21)
Python
40
star
29

wikiasp

Code for WikiAsp: Multi-document aspect-based summarization.
Shell
39
star
30

tranX-plugin

A plugin for code generation in PyCharm/IntelliJ using tranX
Java
35
star
31

neural-lpcfg

The Return of Lexical Dependencies: Neural Lexicalized PCFGs (TACL)
Python
33
star
32

covid19-datashare

A repo for sharing language resources related to the outbreak (in machine readable format)
GLSL
27
star
33

ToM-Language-Acquisition

Code used to run experiments for the ICLR 2023 paper "Computational Language Acquisition with Theory of Mind".
Python
14
star
34

cmulab

CMU Linguistic Annotation Backend
Python
14
star
35

AfricanVoices

Hosts text-to-speech corpus and speech synthesizers for African languages.
Shell
12
star
36

cmu-ner

NER System Developed at CMU
Python
12
star
37

lti-llm-deployment

Python
12
star
38

explainaboard_web

Mustache
8
star
39

KGxBoard

Explainable and Interactive Leaderboard for Evaluation of Knowledge Graph Completion Models
6
star
40

DGT

WNGT 2019, DGT Task.
Python
6
star
41

tranx-study

HTML
5
star
42

Reliable-NLPPP

Jupyter Notebook
5
star
43

cord19

cord19 related stuff
Python
5
star
44

globalbench

GlobalBench: A Benchmark for Global Progress in Language Technology
Python
5
star
45

jsalt2019-informal

A repository for random things from the JSALT informal translation group
Python
5
star
46

cmu-edl

Python
3
star
47

code-mining

Stuff for code mining
OpenEdge ABL
2
star
48

ocr-web-interface

OCR web interface using CMULAB backend
JavaScript
1
star
49

explainaboard_client

Python
1
star