BERT-based NLI model
This project includes a natural language inference (NLI) model, developed by fine-tuning Transformers on the SNLI, MultiNLI and Hans datasets. This project has been used to develop our paper Adapting by Pruning: A Case Study on BERT(appendix). Please cite this paper when you use this project.
Highlighted Features
- Models based on BERT-(base, large) and ALBERT-(base,large)
- Implemented using PyTorch (1.5.0)
- Low memory requirements: Using mixed-precision (nvidia apex) and checkpoint to reduce the GPU memory consumption; training the bert/albert-large model only requires around 6GB GPU memory (with batch size 8).
- Easy inerface: A user-friendly interface is provided to use the trained models
- All source code: All source code for training and testing the models is provided
Contact person: Yang Gao, [email protected]
https://sites.google.com/site/yanggaoalex/home
Don't hesitate to send me an e-mail or report an issue, if something is broken or if you have further questions.
Use the trained NLI model
- The pretrained models are downloaded to output/ (after you run get_data.py in datasets/)
- An example is presented in example.py:
from bert_nli import BertNLIModel
model = BertNLIModel('output/bert-base.state_dict')
sent_pairs = [('The lecturer committed plagiarism.','He was promoted.')]
label, _= model(sent_pairs)
print(label)
The output of the above example is:
['contradiction']
How to set up
- Python3.7
- Install all packages in requirement.txt.
pip3 install -r requirements.txt
- Download the SNLI and MultiNLI data as well as the trained model with the commands below
cd datasets/
python get_data.py
- (Optional) Our code supports the use of the Hans dataset to train the model, in order to prevent the BERT model from exploiting spurious features to make NLI predictions. To use the Hans dataset, download heuristics_train_set.txt and heuristics_evaluation_set.txt from here, and put them to datasets/Hans/. During training/test, add argument --hans 1.
- (Optional) To use mixed precision training (nvidia apex), run the code below:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
- All code is tested on a desktop with a single nVidia RTX 2080 card (8GB memory), running Python 3.7 on Ubuntu 18.04 LTS.
Train NLI models
- Run train.py and specify what Transformer model you would like to fine tune:
python train.py --bert_type bert-large --check_point 1
Option "--check_point 1" means that we will use the checkpoint technique during training. Without using it, the RTX2080 card (8GB memory) is not able to accommodate the bert-large model. But note that, by using checkpoint, it usually takes longer time to train the model.
The trained model (that has the best performance on the dev set) will be saved to directory output/.
Test the performance of the trained models
- To test the performance of a trained model on the MNLI and SNLI dev sets, run the command below:
python test_trained_model.py --bert_type bert-large
Performance of Trained Models
BERT-base
Accuracy: 0.8608.
Contradiction | Entail | Neutral | |
---|---|---|---|
Precision | 0.8791 | 0.8955 | 0.8080 |
Recall | 0.8755 | 0.8658 | 0.8403 |
F1 | 0.8773 | 0.8804 | 0.8239 |
BERT-large
Accuracy: 0.8739
Contradiction | Entail | Neutral | |
---|---|---|---|
Precision | 0.8992 | 0.8988 | 0.8233 |
Recall | 0.8895 | 0.8802 | 0.8508 |
F1 | 0.8944 | 0.8894 | 0.8369 |
ALBERT-large
Accuracy: 0.8743
Contradiction | Entail | Neutral | |
---|---|---|---|
Precision | 0.8907 | 0.8967 | 0.8335 |
Recall | 0.9006 | 0.8812 | 0.8397 |
F1 | 0.8957 | 0.8889 | 0.8366 |
License
Apache License Version 2.0