Code for the paper "Are Sixteen Heads Really Better than One?"

Are Sixteen Heads Really Better than One?

This repository contains code to reproduce the experiments in our paper "Are Sixteen Heads Really Better than One?".

Prerequisite

First, you will need Python >= 3.6 with PyTorch >= 1.0. Then, clone our forks of fairseq (for MT experiments) and pytorch-pretrained-BERT (for BERT):

# Fairseq
git clone https://github.com/pmichel31415/fairseq
# Pytorch pretrained BERT
git clone https://github.com/pmichel31415/pytorch-pretrained-BERT
cd pytorch-pretrained-BERT
git checkout paul
cd ..

If you run into issues with pytorch-pretrained-BERT (for instance, because you have another version installed globally), check out this workaround (thanks @insop).

You will also need sacrebleu to evaluate BLEU score (pip install sacrebleu).
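Before cloning anything, it can help to confirm that the environment meets the version requirements stated above. The following is only a minimal sketch: `version_ge` is a hypothetical helper (not part of this repository), and it assumes a `sort` that supports the `-V` version-sort option.

```shell
# version_ge is a hypothetical helper, not part of this repository.
# It succeeds (exit status 0) when version $1 is at least version $2.
# Assumption: `sort -V` (version sort) is available, as in GNU coreutils.
version_ge() {
    [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
}

# Example (assumes python3 is on PATH and torch is importable):
#   version_ge "$(python3 -c 'import platform; print(platform.python_version())')" 3.6 \
#       || echo "Python >= 3.6 is required"
#   version_ge "$(python3 -c 'import torch; print(torch.__version__)')" 1.0 \
#       || echo "PyTorch >= 1.0 is required"
```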

Ablation experiments

BERT

Running

bash experiments/BERT/heads_ablation.sh MNLI

will fine-tune a pretrained BERT on MNLI (stored in ./models/MNLI) and perform the individual head ablation experiment from Section 3.1 of the paper. Alternatively, you can run the experiment with CoLA, MRPC or SST-2 as the task in place of MNLI.
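If you want to run the ablation for several tasks in a row, a small wrapper along these lines may be convenient. This is a sketch only: `run_ablations` and the `DRY_RUN` variable are hypothetical names introduced here, not something the repository provides.

```shell
# run_ablations is a hypothetical helper, not part of this repository.
# It calls the ablation script once per task name given as an argument.
# Set DRY_RUN=1 to print the commands instead of executing them.
run_ablations() {
    local task
    for task in "$@"; do
        if [ "${DRY_RUN:-0}" = "1" ]; then
            echo "bash experiments/BERT/heads_ablation.sh $task"
        else
            # Stop at the first task whose run fails.
            bash experiments/BERT/heads_ablation.sh "$task" || return 1
        fi
    done
}
```

For example, `DRY_RUN=1 run_ablations MNLI CoLA SST-2` prints the three commands without launching any fine-tuning.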

MT

You can obtain the pretrained WMT model from this link from the fairseq repo (now this link). Use the Moses tokenizer and subword-nmt, together with the BPE codes provided with the pretrained model, to prepare any input file you want. Then run:

bash experiments/MT/wmt_ablation.sh $BPE_SEGMENTED_SRC_FILE $DETOKENIZED_REF_FILE
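The preparation step mentioned above (Moses tokenization followed by BPE segmentation) could be sketched as follows. The variable names `MOSES_DIR`, `BPE_CODES` and `SRC_LANG`, their placeholder defaults, and the `prepare_source` helper are all assumptions made for illustration; adjust them to your mosesdecoder checkout, the codes file shipped with the model, and the source language of your data.

```shell
# Placeholder locations and language; all three are assumptions to adapt.
MOSES_DIR="${MOSES_DIR:-./mosesdecoder}"      # mosesdecoder checkout
BPE_CODES="${BPE_CODES:-./path/to/bpecodes}"  # codes file from the model archive
SRC_LANG="${SRC_LANG:-en}"                    # source language of the input

# prepare_source is a hypothetical helper, not part of this repository.
# $1: raw source text file, $2: BPE-segmented output file
prepare_source() {
    # Tokenize with the Moses tokenizer, then apply the BPE codes.
    perl "$MOSES_DIR/scripts/tokenizer/tokenizer.perl" -l "$SRC_LANG" < "$1" \
        | subword-nmt apply-bpe -c "$BPE_CODES" > "$2"
}
```

After something like `prepare_source raw.src segmented.src`, the second file is what you would pass as `$BPE_SEGMENTED_SRC_FILE`. The reference file is passed detokenized, so it needs no such processing.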

Systematic Pruning Experiments

BERT

To iteratively prune 10% of the heads in order of increasing importance, run:

bash experiments/BERT/heads_pruning.sh MNLI --normalize_pruning_by_layer

This will reuse the fine-tuned BERT model if you have already run the ablation experiment (otherwise it will fine-tune one for you). The output is very verbose, but you can get the gist of the results by calling grep "strategy\|results" -A1 on the output.
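One way to apply that grep is to save the run's output to a file first and filter it afterwards. The sketch below assumes a log file name of your choosing (`pruning.log` here is only a placeholder), and `summarize_pruning_log` is a hypothetical helper rather than something shipped with the repository.

```shell
# summarize_pruning_log is a hypothetical helper, not part of this repository.
# It prints every line mentioning "strategy" or "results", plus the line
# that follows each match (-A1).
summarize_pruning_log() {
    grep -A1 'strategy\|results' "$1"
}

# Example (pruning.log is a placeholder name; 2>&1 is included on the
# assumption that some of the output may go to stderr):
#   bash experiments/BERT/heads_pruning.sh MNLI --normalize_pruning_by_layer 2>&1 | tee pruning.log
#   summarize_pruning_log pruning.log
```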

WMT

Similarly, just run:

bash experiments/MT/prune_wmt.sh $BPE_SEGMENTED_SRC_FILE $DETOKENIZED_REF_FILE

You might want to change the paths in the experiment files to point to the binarized fairseq dataset on which you want to estimate importance scores.