Lite Transformer with Long-Short Range Attention
@inproceedings{Wu2020LiteTransformer,
title={Lite Transformer with Long-Short Range Attention},
author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
booktitle={International Conference on Learning Representations (ICLR)},
year={2020}
}
Overview
We release the PyTorch code for the Lite Transformer. [Paper | Website | Slides]
- Consistent improvement across the tradeoff curves
- Saves 20,000x the search cost of the Evolved Transformer
- Further compresses the Transformer by 18.2x
How to Use
Prerequisite
- Python version >= 3.6
- PyTorch version >= 1.0.0
- configargparse >= 0.14
- For training new models, you'll also need an NVIDIA GPU and NCCL (a quick environment check is sketched below)
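Before installing, you can sanity-check the prerequisites with a short Python snippet. This is a minimal sketch, not part of the repo; it only verifies that the interpreter, PyTorch, CUDA, and NCCL are visible.

```python
# check_env.py -- minimal environment sanity check (not part of the repo)
import sys

import configargparse  # noqa: F401 -- only verifies the package is installed
import torch

# Python >= 3.6 and PyTorch >= 1.0.0 are required
assert sys.version_info >= (3, 6), "Python >= 3.6 is required"
print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__)

# Training new models needs an NVIDIA GPU and NCCL
print("CUDA available:", torch.cuda.is_available())
print("NCCL available:", torch.distributed.is_nccl_available())
```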
Installation
- Codebase
To install fairseq from source and develop locally:
pip install --editable .
- Customized Modules
We also need to build the lightconv and dynamicconv modules for GPU support.
Lightconv_layer:
cd fairseq/modules/lightconv_layer
python cuda_function_gen.py
python setup.py install
Dynamicconv_layer:
cd fairseq/modules/dynamicconv_layer
python cuda_function_gen.py
python setup.py install
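After both builds finish, you can check that fairseq and the compiled CUDA kernels import cleanly. This is a minimal sketch; the extension names lightconv_cuda and dynamicconv_cuda are what the two setup.py builds are assumed to install, so adjust them if your build produces different names.

```python
# verify_build.py -- minimal import check after installation (not part of the repo)
import fairseq
print("fairseq:", fairseq.__version__)

# The two setup.py builds are assumed to install these CUDA extensions;
# the imports will fail if the kernels were not compiled for your GPU setup.
import lightconv_cuda    # built in fairseq/modules/lightconv_layer
import dynamicconv_cuda  # built in fairseq/modules/dynamicconv_layer
print("lightconv_cuda and dynamicconv_cuda imported successfully")
```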
Data Preparation
IWSLT'14 De-En
We follow the data preparation in fairseq. To download and preprocess the data, one can run
bash configs/iwslt14.de-en/prepare.sh
WMT'14 En-Fr
We follow the data pre-processing in fairseq. To download and preprocess the data, one can run
bash configs/wmt14.en-fr/prepare.sh
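If you want to confirm that the preprocessing succeeded, you can list the binarized data directory. This is only a sketch: the path matches the one used by the training commands later in this README, while the expected file names follow the usual fairseq convention and are an assumption here.

```python
# a minimal sketch for inspecting the binarized WMT'14 En-Fr data
import os

data_dir = "data/binary/wmt14_en_fr"  # same path as in the training commands below
for name in sorted(os.listdir(data_dir)):
    print(name)

# fairseq-style binarized data typically contains dict.<lang>.txt plus
# train/valid/test .bin and .idx files, e.g. dict.en.txt, train.en-fr.en.bin
```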
WMT'16 En-De
We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive provided by Google. To binarize the data, one can run
bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]
WIKITEXT-103
As the language modeling task requires many additional changes to the code, we place it in a separate branch: language-model.
We follow the data pre-processing in fairseq. To download and preprocess the data, one can run
git checkout language-model
bash configs/wikitext-103/prepare.sh
Testing
For example, to test the models on WMT'14 En-Fr, one can run
configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]
For instance, to evaluate Lite Transformer on GPU 0 (reporting the BLEU score on the WMT'14 En-Fr test set), one can run
configs/wmt14.en-fr/test.sh embed496/ 0 test
We provide several pretrained models at the bottom. You can download a model and extract the file by
tar -xzvf [filename]
Training
We provide several examples for training Lite Transformer with this repo:
To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
In general, to train a model, one can run
python train.py [path to the data binary] --configs [path to config file] [override options]
Note that --update-freq should be adjusted according to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs).
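The rule behind these numbers is to keep the effective batch size constant: the product of the GPU count and --update-freq stays the same (8 × 16 = 4 × 32 = 128; the 16-GPU distributed example below likewise uses 16 × 8 = 128). A small sketch of that arithmetic:

```python
# keep (num_gpus * update_freq) constant across hardware setups
TARGET = 8 * 16  # reference setting from this README: 8 GPUs with --update-freq 16

def update_freq(num_gpus, target=TARGET):
    """Return the --update-freq that preserves the effective batch size."""
    assert target % num_gpus == 0, "choose a GPU count that divides the target"
    return target // num_gpus

for gpus in (8, 4, 16):
    print(f"{gpus} GPUs -> --update-freq {update_freq(gpus)}")
# 8 GPUs -> --update-freq 16
# 4 GPUs -> --update-freq 32
# 16 GPUs -> --update-freq 8
```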
Distributed Training (optional)
To train Lite Transformer in a distributed manner, e.g. on two GPU nodes with 16 GPUs in total:
# On host1
python -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=2 --node_rank=0 \
--master_addr=host1 --master_port=8080 \
train.py data/binary/wmt14_en_fr \
--configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
--distributed-no-spawn \
--update-freq 8
# On host2
python -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=2 --node_rank=1 \
--master_addr=host1 --master_port=8080 \
train.py data/binary/wmt14_en_fr \
--configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
--distributed-no-spawn \
--update-freq 8
Models
We provide the checkpoints for our Lite Transformer reported in the paper:
| Dataset | #Mult-Adds | Test Score | Model and Test Set |
|---|---|---|---|
| WMT'14 En-Fr | 90M | 35.3 | download |
| | 360M | 39.1 | download |
| | 527M | 39.6 | download |
| WMT'16 En-De | 90M | 22.5 | download |
| | 360M | 25.6 | download |
| | 527M | 26.5 | download |
| CNN / DailyMail | 800M | 38.3 (R-L) | download |
| WIKITEXT-103 | 1147M | 22.2 (PPL) | download |
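To take a quick look at a downloaded checkpoint before running the test script, you can load it with torch.load. This is a minimal sketch: the file name is a placeholder for whichever archive you extracted, and the key names reflect the usual fairseq checkpoint layout, which is an assumption here.

```python
# a minimal sketch for inspecting a downloaded checkpoint (not part of the repo)
import torch

# replace with the actual path inside the extracted archive
ckpt = torch.load("checkpoint_best.pt", map_location="cpu")

print("top-level keys:", list(ckpt.keys()))

# fairseq checkpoints usually carry the training args and the model weights
if "args" in ckpt:
    print("arch:", getattr(ckpt["args"], "arch", None))
if "model" in ckpt:
    n_params = sum(p.numel() for p in ckpt["model"].values())
    print("parameters:", n_params)
```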