Prefix-Tuning: Optimizing Continuous Prompts for Generation

Prefix Tuning

Files:

.
├── gpt2                          # Code for GPT-2 style autoregressive LMs
│   ├── train_e2e.py              # high-level script to train
│   ├── train_control.py          # code that implements prefix-tuning
│   ├── trainer_prefix.py         # trainer code for the training loop
│   ├── run_language_modeling.py  # training code (contains data loading and model loading, and calls the trainer)
│   ├── gen.py                    # high-level script to decode
│   └── run_generation.py         # decoding code
│
├── seq2seq                       # Code for encoder-decoder architectures
│   ├── train_bart.py             # high-level script to train
│   ├── prefixTuning.py           # code that implements prefix-tuning
│   ├── finetune.py               # training code (contains data loading and model loading, and calls the trainer)
│   ├── lightning_base.py         # helper code
│   ├── utils.py                  # helper code
│   └── callbacks.py              # helper code
└── ...
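As a rough illustration of what the prefix-tuning modules (train_control.py, prefixTuning.py) parameterize, the sketch below counts trainable parameters under the MLP reparametrization described in the paper: a small embedding of preseqlen prefix positions is passed through a two-layer MLP of hidden width mid_dim, which emits a key and a value vector for every layer. The exact layout (one embedding, Linear -> tanh -> Linear, biases included) is an assumption for illustration and is not read from the repository code. The GPT-2 medium sizes (24 layers, hidden size 1024) are standard model dimensions, and mid_dim=512 is a hypothetical value.

```python
from dataclasses import dataclass


@dataclass
class PrefixConfig:
    n_layer: int    # transformer layers that receive a prefix
    n_embd: int     # hidden size of the frozen LM
    preseqlen: int  # prefix length (the --preseqlen flag)
    mid_dim: int    # MLP hidden width (the --mid_dim flag)


def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus bias of a dense layer."""
    return d_in * d_out + d_out


def reparam_params(cfg: PrefixConfig) -> int:
    """Trainable parameters during training, ASSUMING the layout
    Embedding(preseqlen, n_embd) -> Linear(n_embd, mid_dim) -> tanh
    -> Linear(mid_dim, 2 * n_layer * n_embd), i.e. one key and one value per layer."""
    out_dim = 2 * cfg.n_layer * cfg.n_embd
    return (cfg.preseqlen * cfg.n_embd
            + linear_params(cfg.n_embd, cfg.mid_dim)
            + linear_params(cfg.mid_dim, out_dim))


def saved_prefix_params(cfg: PrefixConfig) -> int:
    """Size of the prefix key/value vectors alone, which is all that
    needs to be kept once the MLP is dropped after training."""
    return cfg.preseqlen * 2 * cfg.n_layer * cfg.n_embd


# GPT-2 medium dimensions with --preseqlen 5; mid_dim=512 is illustrative.
cfg = PrefixConfig(n_layer=24, n_embd=1024, preseqlen=5, mid_dim=512)
print(reparam_params(cfg))       # 25744896
print(saved_prefix_params(cfg))  # 245760
```

Most of the training-time count sits in the MLP. The paper reports that the reparametrization is there to stabilize optimization and that it can be discarded after training, leaving only the much smaller prefix.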

The code for GPT-2 style autoregressive LMs is in gpt2/. It corresponds to the table-to-text experiments in the paper.

The code for encoder-decoder architectures such as BART is in seq2seq/. It corresponds to the summarization experiments in the paper.

The two primary scripts I used to run my experiments are gpt2/train_e2e.py (for table-to-text) and seq2seq/train_bart.py (for summarization). Their defaults are set to good hyperparameters, and they can also be used for hyperparameter tuning :)
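Because both scripts take their hyperparameters as command-line flags, a small driver can generate a sweep. The sketch below only builds command lines for gpt2/train_e2e.py from flags that appear in this README (--optim_prefix, --preseqlen, --epoch, --learning_rate, --mode, --bsz, --seed). The grid values are hypothetical, and the driver itself is not part of the repository. It prints the commands instead of running them; passing an argv list to subprocess.run(argv, cwd="gpt2", check=True) would launch one.

```python
import itertools
import shlex


def build_sweep(preseqlens, learning_rates, mode="webnlg", seed=101):
    """Return one argv list per (preseqlen, learning_rate) pair for train_e2e.py.

    The fixed flags mirror the prefix-tuning example command in this README;
    the grid values passed in are hypothetical choices for illustration."""
    commands = []
    for preseqlen, lr in itertools.product(preseqlens, learning_rates):
        commands.append([
            "python", "train_e2e.py",
            "--optim_prefix", "yes",
            "--preseqlen", str(preseqlen),
            "--epoch", "5",
            "--learning_rate", str(lr),
            "--mode", mode,
            "--bsz", "5",
            "--seed", str(seed),
        ])
    return commands


if __name__ == "__main__":
    # Print each command as a shell-safe string instead of executing it.
    for argv in build_sweep([5, 10], [0.00005, 0.0001]):
        print(shlex.join(argv))
```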


Setup:

cd transformer; pip install -e .


Train via prefix-tuning:

cd gpt2;
python train_e2e.py --optim_prefix yes --preseqlen 5 --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101

cd seq2seq;
python train_bart.py --mode xsum --preseqlen 200 --do_train yes --fp16 yes --bsz 16 --epoch 30 --gradient_accumulation_step 3 --learning_rate 0.00005 --mid_dim 800
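In the summarization command, --bsz 16 together with --gradient_accumulation_step 3 means that gradients from three mini-batches are accumulated before each optimizer update, so on a single GPU each update sees 16 × 3 = 48 examples. The helper below makes that arithmetic explicit. Its n_devices factor is an assumption about data-parallel training on several GPUs, which this README does not discuss.

```python
def effective_batch_size(bsz: int, grad_accum_steps: int, n_devices: int = 1) -> int:
    """Number of examples contributing to one optimizer update.

    Assumes standard gradient accumulation and, for n_devices > 1,
    data-parallel training where every device processes its own batch."""
    return bsz * grad_accum_steps * n_devices


# Flags from the seq2seq prefix-tuning command: --bsz 16 --gradient_accumulation_step 3
print(effective_batch_size(16, 3))  # 48
```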

Other baseline approaches:

cd gpt2;
python train_e2e.py --tuning_mode {finetune/adaptertune} --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101

cd seq2seq;
python train_bart.py --tuning_mode finetune --epoch 5 --learning_rate 0.00005 --mode xsum --bsz 5 --seed 101

Decode:

cd gpt2;
python gen.py {data2text/webnlg/...} yes test {checkpoint_path} no

cd seq2seq;
python train_bart.py --mode xsum --do_train no --prefix_model_path {checkpoint_path} --preseqlen {same as training} --mid_dim {same as training}
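Decoding has to repeat --preseqlen and --mid_dim from training. Under the MLP reparametrization these values plausibly determine the shapes of the prefix parameters stored in the checkpoint, so mismatched values would likely fail to load. One way to avoid retyping them is to record the training values in a small file and rebuild the decode command from it. The sketch below does this with a hypothetical prefix_args.json written by your own launcher; the repository is not assumed to create such a file, and save_train_args / build_decode_argv are illustrative helper names.

```python
import json
import tempfile
from pathlib import Path


def save_train_args(args_path, preseqlen: int, mid_dim: int) -> None:
    """Record the shape-determining training flags (hypothetical JSON file)."""
    Path(args_path).write_text(json.dumps({"preseqlen": preseqlen, "mid_dim": mid_dim}))


def build_decode_argv(checkpoint_path: str, args_path, mode: str = "xsum") -> list:
    """Rebuild the seq2seq decode command using the values recorded at training time."""
    args = json.loads(Path(args_path).read_text())
    return [
        "python", "train_bart.py",
        "--mode", mode,
        "--do_train", "no",
        "--prefix_model_path", checkpoint_path,
        "--preseqlen", str(args["preseqlen"]),
        "--mid_dim", str(args["mid_dim"]),
    ]


if __name__ == "__main__":
    # Values taken from the seq2seq training command in this README.
    with tempfile.TemporaryDirectory() as tmp:
        args_file = Path(tmp) / "prefix_args.json"
        save_train_args(args_file, preseqlen=200, mid_dim=800)
        print(build_decode_argv("{checkpoint_path}", args_file))
```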

For details of the methods and results, please refer to our paper.

@misc{li2021prefixtuning,
      title={Prefix-Tuning: Optimizing Continuous Prompts for Generation}, 
      author={Xiang Lisa Li and Percy Liang},
      year={2021},
      eprint={2101.00190},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}