  • Language: Python
  • License: MIT License

A PyTorch implementation of fine-tuning GPT-2 (Generative Pre-trained Transformer 2) for dialogue generation.

gpt2-dialogue-generation-pytorch

This is a multi-turn chatbot project that fine-tunes the pre-trained GPT-2[1], following the approach introduced in How to build a State-of-the-Art Conversational AI with Transfer Learning[2].

Specifically, this repository uses the GPT-2 Language Modeling Head model, which adds one linear layer on top of GPT-2 for the language modeling task. This lets the model take the dialogue context into account and generate an appropriate next response.
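One way to feed a multi-turn context to a language modeling head is to pack every turn into a single token sequence. The sketch below assumes the format popularized by [2]: a <bos> token, then each turn prefixed by a speaker token (<sp1> or <sp2>), with only the response tokens supervised. The helper name build_example, the alternating-speaker convention, and the truncation policy are illustrative assumptions, not necessarily what this repository's preprocessing does.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # positions with this label are skipped by PyTorch's cross-entropy by default


def build_example(
    history: List[List[int]],
    response: List[int],
    bos_id: int,
    sp1_id: int,
    sp2_id: int,
    max_turns: int = 5,
    max_len: int = 1024,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: pack tokenized context turns and a response into one LM example."""
    # Keep at most `max_turns` of the most recent context turns.
    context = history[-max_turns:] if max_turns > 0 else []
    turns = context + [response]

    # Assumed convention: speakers alternate and the response is tagged with sp2.
    n = len(turns)
    segments = [
        [sp2_id if (n - 1 - i) % 2 == 0 else sp1_id] + turn
        for i, turn in enumerate(turns)
    ]

    # Drop the oldest context turns until <bos> plus all segments fit in max_len.
    while len(segments) > 1 and 1 + sum(len(s) for s in segments) > max_len:
        segments.pop(0)

    input_ids, labels = [bos_id], [IGNORE_INDEX]
    for j, seg in enumerate(segments):
        input_ids += seg
        if j == len(segments) - 1:
            # Supervise only the response tokens, not the speaker token before them.
            labels += [IGNORE_INDEX] + seg[1:]
        else:
            labels += [IGNORE_INDEX] * len(seg)

    # If the response alone is longer than max_len, hard-truncate the sequence.
    return input_ids[:max_len], labels[:max_len]
```

With Hugging Face's GPT2LMHeadModel, labels can be passed aligned with input_ids because the model shifts them internally before computing the loss. New special tokens such as <bos>, <sp1>, and <sp2> also have to be registered with the tokenizer (add_special_tokens) and the embedding matrix resized (resize_token_embeddings).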

Unlike the original version, this project does not use persona information.



Arguments

Arguments for data loading

Argument | Type | Description | Default
data_dir | str | The name of the parent directory where the data files are stored. | "data"
train_prefix | str | The filename prefix of the training data files. | "train"
valid_prefix | str | The filename prefix of the validation data files. | "valid"
train_frac | float | The fraction of conversations to include in the training set. | 0.85
model_type | str | The GPT-2 model variant ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2"
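The table does not say exactly how train_frac is applied. A minimal sketch, assuming the conversations are shuffled with a fixed seed and then cut at the train_frac boundary (the helper name split_conversations and its seed parameter are hypothetical):

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_conversations(
    convs: Sequence[T], train_frac: float = 0.85, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Shuffle conversation indices reproducibly and split them into train/valid."""
    if not 0.0 < train_frac < 1.0:
        raise ValueError("train_frac must be strictly between 0 and 1")
    indices = list(range(len(convs)))
    random.Random(seed).shuffle(indices)  # a local RNG leaves the global state untouched
    cut = int(len(convs) * train_frac)
    train = [convs[i] for i in indices[:cut]]
    valid = [convs[i] for i in indices[cut:]]
    return train, valid
```

Splitting at the conversation level (rather than at the utterance level) keeps turns from the same dialogue out of both sets.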

Arguments for training

Argument | Type | Description | Default
seed | int | The random seed. | 0
data_dir | str | The name of the parent directory where the data files are stored. | "data"
train_prefix | str | The filename prefix of the training data files. | "train"
valid_prefix | str | The filename prefix of the validation data files. | "valid"
model_type | str | The GPT-2 model variant ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2"
bos_token | str | The BOS (beginning-of-sequence) token. | "<bos>"
sp1_token | str | The speaker 1 token. | "<sp1>"
sp2_token | str | The speaker 2 token. | "<sp2>"
gpu | str | The index of the GPU to use. | "0"
lr | float | The learning rate. | 2e-5
warmup_ratio | float | The ratio of warmup steps to total training steps. | 0.1
batch_size | int | The batch size. | 8
num_workers | int | The number of worker processes for data loading. | 0
num_epochs | int | The total number of training epochs. | 10
max_len | int | The maximum length of the input sequence, in tokens. | 1024
max_turns | int | The maximum number of previous dialogue turns to include as context. | 5
ckpt_dir | str | The directory where checkpoints are saved. | "saved_models"
ckpt_name | str | The name of a saved checkpoint (without extension) to resume training from. | Optional; no default
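lr and warmup_ratio only make sense together with the total number of optimization steps. Assuming a linear warmup followed by linear decay (the shape of Hugging Face's get_linear_schedule_with_warmup; the exact scheduler this repository uses is not stated here), the per-step learning rate can be sketched as follows. Both function names are hypothetical.

```python
import math


def total_training_steps(num_examples: int, batch_size: int = 8, num_epochs: int = 10) -> int:
    """Optimizer steps over all epochs, counting a final partial batch per epoch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


def linear_warmup_lr(
    step: int, total_steps: int, base_lr: float = 2e-5, warmup_ratio: float = 0.1
) -> float:
    """Learning rate at `step`: ramp 0 -> base_lr, then decay linearly back to 0."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)
```

For example, 1,000 training examples with the default batch_size and num_epochs give 1,250 steps, of which the first 125 are warmup.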

Arguments for inference

Argument | Type | Description | Default
seed | int | The random seed. | 0
data_dir | str | The name of the parent directory where the data files are stored. | "data"
model_type | str | The GPT-2 model variant ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2"
bos_token | str | The BOS (beginning-of-sequence) token. | "<bos>"
sp1_token | str | The speaker 1 token. | "<sp1>"
sp2_token | str | The speaker 2 token. | "<sp2>"
gpu | str | The index of the GPU to use. | "0"
max_len | int | The maximum length of the input sequence, in tokens. | 1024
max_turns | int | The maximum number of previous dialogue turns to include as context. | 5
top_p | float | The top-p value for nucleus sampling. | 0.8
ckpt_dir | str | The directory where checkpoints are saved. | "saved_models"
ckpt_name | str | The name of the trained checkpoint (without extension) to load. | Required; no default
end_command | str | The command that ends the conversation during inference. | "Abort!"
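top_p controls nucleus sampling: at each decoding step, only the smallest set of most-probable tokens whose cumulative probability reaches top_p is kept, and the next token is sampled from that renormalized set. A framework-free sketch over a plain probability list (the function names are hypothetical; a real decoder would apply the same filter to the model's softmax output at every step):

```python
import random
from typing import List, Optional, Sequence


def top_p_filter(probs: Sequence[float], top_p: float = 0.8) -> List[float]:
    """Zero out tokens outside the nucleus and renormalize the rest."""
    # Visit token indices from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:  # smallest prefix whose mass reaches top_p
            break
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


def sample_top_p(
    probs: Sequence[float], top_p: float = 0.8, rng: Optional[random.Random] = None
) -> int:
    """Draw one token index from the top-p filtered distribution."""
    rng = rng or random.Random()
    filtered = top_p_filter(probs, top_p)
    return rng.choices(range(len(filtered)), weights=filtered, k=1)[0]
```

With probabilities [0.6, 0.3, 0.1] and the default top_p of 0.8, the first two tokens reach a cumulative 0.9, so the third token can never be sampled.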


Datasets

By default, the repository provides code for downloading and preprocessing the datasets.

The four default datasets are:


  • DailyDialog[3]
  • EmpatheticDialogues[4]
  • Persona-Chat[5]
  • BlendedSkillTalk[6]


How to run

  1. Install all required packages.

    pip install -r requirements.txt

  2. Download and preprocess all datasets.

    sh exec_load_data.sh

    If you use the default arguments, this produces the following data directory structure:

    data
    └── gpt2
        ├── train_utters.pickle
        ├── train_ids.pickle
        ├── valid_utters.pickle
        └── valid_ids.pickle

  3. Run the following command to train the model.

    To resume training from a specific checkpoint, add the ckpt_name argument and set it to the name of that checkpoint.

    sh exec_train.sh

  4. Run the command below to run inference with the trained model.

    This time, ckpt_name is required.

    sh exec_infer.sh
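A rough sketch of an interactive inference loop like the one described in step 4, assuming the conversation stops when the user types end_command and that max_turns counts utterances from both speakers. generate_reply is a hypothetical stand-in for decoding with the fine-tuned model (for example, with the nucleus sampling controlled by top_p); the function name chat is also illustrative.

```python
from typing import Callable, Iterable, List


def chat(
    user_inputs: Iterable[str],
    generate_reply: Callable[[List[str]], str],
    end_command: str = "Abort!",
    max_turns: int = 5,
) -> List[str]:
    """Alternate user and bot turns until end_command is entered."""
    history: List[str] = []
    replies: List[str] = []
    for utter in user_inputs:
        if utter.strip() == end_command:
            break
        history.append(utter)
        # Only the most recent max_turns utterances are passed as context.
        reply = generate_reply(history[-max_turns:])
        print(f"Bot: {reply}")
        history.append(reply)
        replies.append(reply)
    return replies
```

To chat from a terminal, user_inputs can be iter(lambda: input("You: "), None), which keeps prompting until end_command is typed.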


References

[1] Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9. (http://www.persagen.com/files/misc/radford2019language.pdf)

[2] How to build a State-of-the-Art Conversational AI with Transfer Learning. (2019, May 9). (https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313)

[3] Li, Y., Su, H., Shen, X., Li, W., Cao, Z., & Niu, S. (2017). DailyDialog: A manually labelled multi-turn dialogue dataset. arXiv preprint arXiv:1710.03957. (https://arxiv.org/abs/1710.03957)

[4] Rashkin, H., Smith, E. M., Li, M., & Boureau, Y. L. (2018). Towards empathetic open-domain conversation models: A new benchmark and dataset. arXiv preprint arXiv:1811.00207. (https://arxiv.org/abs/1811.00207)

[5] Zhang, S., Dinan, E., Urbanek, J., Szlam, A., Kiela, D., & Weston, J. (2018). Personalizing dialogue agents: I have a dog, do you have pets too? arXiv preprint arXiv:1801.07243. (https://arxiv.org/abs/1801.07243)

[6] Smith, E. M., Williamson, M., Shuster, K., Weston, J., & Boureau, Y. L. (2020). Can You Put it All Together: Evaluating Conversational Agents' Ability to Blend Skills. arXiv preprint arXiv:2004.08449. (https://arxiv.org/abs/2004.08449)
