Prefix Tuning
Files:
.
├── gpt2                         # Code for GPT-2 style autoregressive LM
│   ├── train_e2e.py             # high-level script to train
│   ├── train_control.py         # code that implements prefix-tuning
│   ├── trainer_prefix.py        # trainer code for the training loop
│   ├── run_language_modeling.py # training code (contains data loading, model loading, and calls the trainer)
│   ├── gen.py                   # high-level script to decode
│   └── run_generation.py        # decoding code
│
├── seq2seq                      # Code for encoder-decoder architectures (e.g., BART)
│   ├── train_bart.py            # high-level script to train
│   ├── prefixTuning.py          # code that implements prefix-tuning
│   ├── finetune.py              # training code (contains data loading, model loading, and calls the trainer)
│   ├── lightning_base.py        # helper code
│   ├── utils.py                 # helper code
│   └── callbacks.py             # helper code
└── ...
The code for GPT-2 style autoregressive LMs is in gpt2/; this corresponds to the table-to-text experiments in the paper.
The code for encoder-decoder architectures such as BART is in seq2seq/; this corresponds to the summarization experiments in the paper.
The two primary scripts I used to run my code are gpt2/train_e2e.py (for table-to-text) and seq2seq/train_bart.py (for summarization).
Their defaults are set to good hyperparameters, and they can also be used to tune hyperparameters. :)
Setup (installs the local transformers package in editable mode):
cd transformer; pip install -e .
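Optionally, verify that Python resolves transformers to the editable install from this directory rather than a PyPI version:
python -c "import transformers; print(transformers.__version__, transformers.__file__)"  # the printed path should point into this repo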
Train via prefix-tuning:
cd gpt2;
python train_e2e.py --optim_prefix yes --preseqlen 5 --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
cd seq2seq;
python train_bart.py --mode xsum --preseqlen 200 --do_train yes --fp16 yes --bsz 16 --epoch 30 --gradient_accumulation_step 3 --learning_rate 0.00005 --mid_dim 800
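Here --preseqlen is the number of trainable prefix positions and --mid_dim is the hidden width of the MLP that reparameterizes the prefix. The snippet below is a minimal, illustrative sketch of that idea for a GPT-2 style LM, assuming a recent Hugging Face transformers release; it is not the repo's actual implementation (see gpt2/train_control.py and seq2seq/prefixTuning.py for that). The pretrained LM stays frozen, and only a small prefix module is trained; the prefix is fed to the model as per-layer past key/value activations.

# Minimal sketch of prefix-tuning for GPT-2 (illustrative only, not train_control.py).
# The LM is frozen; only the prefix module (embedding + MLP of width mid_dim) is trained.
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

class PrefixModule(nn.Module):
    def __init__(self, config, preseqlen=5, mid_dim=512):
        super().__init__()
        self.preseqlen = preseqlen
        self.n_layer, self.n_head = config.n_layer, config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.input_tokens = torch.arange(preseqlen)              # prefix "positions"
        self.wte = nn.Embedding(preseqlen, config.n_embd)        # small prefix embedding
        self.control_trans = nn.Sequential(                      # MLP reparameterization
            nn.Linear(config.n_embd, mid_dim),
            nn.Tanh(),
            nn.Linear(mid_dim, self.n_layer * 2 * config.n_embd),
        )

    def forward(self, batch_size, device):
        tokens = self.input_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        kv = self.control_trans(self.wte(tokens))                # (B, preseqlen, n_layer*2*n_embd)
        kv = kv.view(batch_size, self.preseqlen, self.n_layer * 2, self.n_head, self.head_dim)
        kv = kv.permute(2, 0, 3, 1, 4).split(2)                  # n_layer chunks of (2, B, heads, preseqlen, head_dim)
        return tuple((chunk[0], chunk[1]) for chunk in kv)       # per-layer (key, value) pairs

model = GPT2LMHeadModel.from_pretrained("gpt2")
for p in model.parameters():
    p.requires_grad = False                                      # pretrained LM stays frozen

prefix = PrefixModule(model.config, preseqlen=5, mid_dim=512)
input_ids = torch.tensor([[314, 588, 4171, 20334]])              # arbitrary GPT-2 token ids
past = prefix(batch_size=input_ids.size(0), device=input_ids.device)
out = model(input_ids=input_ids, past_key_values=past, labels=input_ids)
out.loss.backward()                                              # gradients reach only the prefix parameters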
Other baseline approaches:
cd gpt2;
python train_e2e.py --tuning_mode {finetune/adaptertune} --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
cd seq2seq;
python train_bart.py --mode xsum --tuning_mode finetune --do_train yes --fp16 yes --bsz 16 --epoch 30 --gradient_accumulation_step 3 --learning_rate 0.00005
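For reference, adaptertune trains small bottleneck modules inserted into the frozen LM instead of a prefix. The sketch below shows that general idea on a Hugging Face GPT-2 (purely illustrative; it is not the repo's adapter implementation, and the bottleneck size of 64 is an arbitrary choice):

# Illustrative adapter-tuning sketch (not the repo's adaptertune code):
# freeze GPT-2 and wrap each block's feed-forward layer with a trainable bottleneck adapter.
import torch.nn as nn
from transformers import GPT2LMHeadModel

class Adapter(nn.Module):
    def __init__(self, dim, bottleneck=64):
        super().__init__()
        self.down, self.up, self.act = nn.Linear(dim, bottleneck), nn.Linear(bottleneck, dim), nn.ReLU()

    def forward(self, hidden_states):
        return hidden_states + self.up(self.act(self.down(hidden_states)))  # residual bottleneck

model = GPT2LMHeadModel.from_pretrained("gpt2")
for p in model.parameters():
    p.requires_grad = False                                                 # freeze the pretrained LM
for block in model.transformer.h:
    block.mlp = nn.Sequential(block.mlp, Adapter(model.config.n_embd))      # adapters stay trainable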
Decode:
cd gpt2;
python gen.py {data2text/webnlg/...} yes test {checkpoint_path} no
cd seq2seq;
python train_bart.py --mode xsum --do_train no --prefix_model_path {checkpoint_path} --preseqlen {same as training} --mid_dim {same as training}
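To show how a trained prefix conditions generation, here is a minimal greedy-decoding sketch for the GPT-2 case (illustrative only; gen.py and run_generation.py are the actual decoding code). It reuses the PrefixModule class from the training sketch above, and the prompt string is just a toy example:

# Greedy decoding with a trained prefix (sketch; reuses PrefixModule from above).
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
prefix = PrefixModule(model.config, preseqlen=5, mid_dim=512)    # load trained prefix weights in practice

input_ids = tokenizer("name : Blue Spice | type : coffee shop", return_tensors="pt").input_ids
past = prefix(batch_size=1, device=input_ids.device)             # prefix activations serve as the initial cache
generated, next_input = input_ids, input_ids
with torch.no_grad():
    for _ in range(30):                                          # decode up to 30 tokens greedily
        out = model(input_ids=next_input, past_key_values=past, use_cache=True)
        past = out.past_key_values                               # cache now holds prefix + processed tokens
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        next_input = next_token                                  # feed one new token per step
print(tokenizer.decode(generated[0], skip_special_tokens=True))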
For details of the methods and results, please refer to our paper.
@misc{li2021prefixtuning,
title={Prefix-Tuning: Optimizing Continuous Prompts for Generation},
author={Xiang Lisa Li and Percy Liang},
year={2021},
eprint={2101.00190},
archivePrefix={arXiv},
primaryClass={cs.CL}
}