gpt2-dialogue-generation-pytorch
This is a multi-turn chatbot project that uses the pre-trained GPT-2 [1], following the approach introduced in How to build a State-of-the-Art Conversational AI with Transfer Learning [2].
Specifically, this repository uses the GPT-2 Language Modeling Head model, which adds one linear layer on top of GPT-2 for the language modeling task, so that the model takes the dialogue context into account and generates an appropriate next response.
Unlike the original version, I did not include the persona information.
Arguments
Arguments for data loading
| Argument | Type | Description | Default |
|---|---|---|---|
| `data_dir` | `str` | The name of the parent directory where data files are stored. | `"data"` |
| `train_prefix` | `str` | The prefix of the training data file names. | `"train"` |
| `valid_prefix` | `str` | The prefix of the validation data file names. | `"valid"` |
| `train_frac` | `float` | The fraction of conversations to include in the training set. | `0.85` |
| `model_type` | `str` | The model type of GPT-2. (`"gpt2"`, `"gpt2-medium"`, `"gpt2-large"`, or `"gpt2-xl"`) | `"gpt2"` |
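
To illustrate what `train_frac` controls, here is a minimal sketch of a conversation-level split. The function name `split_conversations` and the choice to split in the original order without shuffling are assumptions made for this example; the repository's preprocessing script may do it differently.

```python
def split_conversations(conversations, train_frac=0.85):
    """Split a list of conversations into train and validation parts.

    Hypothetical helper: the first `train_frac` share of the conversations
    goes to the train set and the remainder to the validation set.
    """
    if not 0.0 < train_frac < 1.0:
        raise ValueError("train_frac must be strictly between 0 and 1")
    # The split is done per conversation, so no dialogue is cut in half.
    cut = int(len(conversations) * train_frac)
    return conversations[:cut], conversations[cut:]


# Toy data: 20 two-turn conversations.
dialogues = [[f"hello {i}", f"hi {i}"] for i in range(20)]
train, valid = split_conversations(dialogues, train_frac=0.85)
print(len(train), len(valid))
```

Splitting by conversation rather than by utterance keeps every dialogue entirely in one set, which avoids leaking validation context into training.
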
Arguments for training
| Argument | Type | Description | Default |
|---|---|---|---|
| `seed` | `int` | The random seed. | `0` |
| `data_dir` | `str` | The name of the parent directory where data files are stored. | `"data"` |
| `train_prefix` | `str` | The prefix of the training data file names. | `"train"` |
| `valid_prefix` | `str` | The prefix of the validation data file names. | `"valid"` |
| `model_type` | `str` | The model type of GPT-2. (`"gpt2"`, `"gpt2-medium"`, `"gpt2-large"`, or `"gpt2-xl"`) | `"gpt2"` |
| `bos_token` | `str` | The BOS token. | `"<bos>"` |
| `sp1_token` | `str` | The speaker1 token. | `"<sp1>"` |
| `sp2_token` | `str` | The speaker2 token. | `"<sp2>"` |
| `gpu` | `str` | The index of the GPU to use. | `"0"` |
| `lr` | `float` | The learning rate. | `2e-5` |
| `warmup_ratio` | `float` | The ratio of warmup steps to the total training steps. | `0.1` |
| `batch_size` | `int` | The batch size. | `8` |
| `num_workers` | `int` | The number of workers for data loading. | `0` |
| `num_epochs` | `int` | The total number of epochs. | `10` |
| `max_len` | `int` | The maximum length of the input sequence. | `1024` |
| `max_turns` | `int` | The maximum number of previous dialogue turns to include. | `5` |
| `ckpt_dir` | `str` | The path for saved checkpoints. | `"saved_models"` |
| `ckpt_name` | `str` | The default name for the trained model (without extension). | YOU MIGHT SPECIFY |
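
As a rough illustration of how `bos_token`, `sp1_token`, `sp2_token`, and `max_turns` could work together, the sketch below builds a text context from a dialogue history. The function `build_context` and the exact layout (BOS first, then each utterance prefixed by its speaker token, speakers alternating from speaker1) are assumptions for this example, not necessarily the format used by the training code. Token-level truncation to `max_len` is left out because it depends on the tokenizer.

```python
def build_context(utterances, max_turns=5, bos_token="<bos>",
                  sp1_token="<sp1>", sp2_token="<sp2>"):
    """Join the most recent `max_turns` utterances into one context string.

    Hypothetical layout: even-indexed utterances belong to speaker1 and
    odd-indexed ones to speaker2, counted from the start of the dialogue.
    """
    # Keep only the last `max_turns` utterances of the history.
    first = max(0, len(utterances) - max_turns)
    parts = [bos_token]
    for i in range(first, len(utterances)):
        # The speaker is decided by the position in the full dialogue,
        # so truncation does not swap the speakers.
        speaker = sp1_token if i % 2 == 0 else sp2_token
        parts.append(f"{speaker} {utterances[i]}")
    return " ".join(parts)


history = ["Hi there.", "Hello, how are you?", "Fine, thanks."]
print(build_context(history, max_turns=2))
```
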
Arguments for inference
| Argument | Type | Description | Default |
|---|---|---|---|
| `seed` | `int` | The random seed. | `0` |
| `data_dir` | `str` | The name of the parent directory where data files are stored. | `"data"` |
| `model_type` | `str` | The model type of GPT-2. (`"gpt2"`, `"gpt2-medium"`, `"gpt2-large"`, or `"gpt2-xl"`) | `"gpt2"` |
| `bos_token` | `str` | The BOS token. | `"<bos>"` |
| `sp1_token` | `str` | The speaker1 token. | `"<sp1>"` |
| `sp2_token` | `str` | The speaker2 token. | `"<sp2>"` |
| `gpu` | `str` | The index of the GPU to use. | `"0"` |
| `max_len` | `int` | The maximum length of the input sequence. | `1024` |
| `max_turns` | `int` | The maximum number of previous dialogue turns to include. | `5` |
| `top_p` | `float` | The top-p value for nucleus sampling decoding. | `0.8` |
| `ckpt_dir` | `str` | The path for saved checkpoints. | `"saved_models"` |
| `ckpt_name` | `str` | The default name for the trained model (without extension). | YOU SHOULD SPECIFY |
| `end_command` | `str` | The command that stops the conversation during inference. | `"Abort!"` |
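
The `top_p` argument refers to nucleus sampling: at each decoding step, only the smallest set of most likely tokens whose cumulative probability reaches `top_p` is kept, and the next token is sampled from that set. The sketch below shows the idea in plain Python on a list of probabilities. The function names are hypothetical, and the repository's inference code presumably works on PyTorch tensors instead.

```python
import random


def top_p_filter(probs, top_p=0.8):
    """Return {token_index: renormalized probability} for the nucleus."""
    # Rank token indices from most to least probable.
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    nucleus, cumulative = [], 0.0
    for index, p in ranked:
        nucleus.append((index, p))
        cumulative += p
        # Stop as soon as the kept tokens cover at least top_p of the mass.
        if cumulative >= top_p:
            break
    total = sum(p for _, p in nucleus)
    return {index: p / total for index, p in nucleus}


def sample_top_p(probs, top_p=0.8, rng=random):
    """Sample one token index from the nucleus of `probs`."""
    nucleus = top_p_filter(probs, top_p)
    indices = list(nucleus)
    weights = [nucleus[i] for i in indices]
    return rng.choices(indices, weights=weights, k=1)[0]


next_token_probs = [0.5, 0.25, 0.15, 0.1]
print(top_p_filter(next_token_probs, top_p=0.7))
print(sample_top_p(next_token_probs, top_p=0.7))
```

A lower `top_p` makes responses more conservative, while a value close to 1 approaches plain sampling from the full distribution.
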
Datasets
By default, I provide the code for downloading and preprocessing the datasets.
There are 4 default datasets: DailyDialog [3], EmpatheticDialogues [4], Persona-Chat [5], and BlendedSkillTalk [6].
How to run
- Install all required packages.

      pip install -r requirements.txt

- Download and preprocess all datasets.

      sh exec_load_data.sh

  After running it, you will have the following data directory structure if you use the default argument settings.

      data
      └--gpt2
          └--train_utters.pickle
          └--train_ids.pickle
          └--valid_utters.pickle
          └--valid_ids.pickle

- Run the following command to train the model.
  If you want to resume training from a specific checkpoint, add the `ckpt_name` argument and make sure to specify the correct checkpoint name.

      sh exec_train.sh

- Run the following command to run inference with the trained model.
  This time, you are required to specify `ckpt_name`.

      sh exec_infer.sh
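
To inspect the preprocessed files produced by the data loading step, a small helper like the following could be used. This is only a sketch: the function `load_split` is hypothetical, and the internal structure of the pickled objects is not documented here, so the code just loads and returns them.

```python
import pickle
from pathlib import Path


def load_split(data_dir="data/gpt2", prefix="train"):
    """Load `<prefix>_utters.pickle` and `<prefix>_ids.pickle` from `data_dir`.

    Returns a dict with the keys "utters" and "ids". The default directory
    follows the default argument settings (`data_dir` and `model_type`).
    """
    base = Path(data_dir)
    loaded = {}
    for kind in ("utters", "ids"):
        # File names follow the directory structure created by the
        # preprocessing script.
        with open(base / f"{prefix}_{kind}.pickle", "rb") as f:
            loaded[kind] = pickle.load(f)
    return loaded
```

For example, `load_split("data/gpt2", "valid")` would return the validation utterances and their token ids once `sh exec_load_data.sh` has been run. Only unpickle files that you generated yourself, since `pickle.load` can execute arbitrary code from untrusted files.
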
References
[1] Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9. (http://www.persagen.com/files/misc/radford2019language.pdf)
[2] How to build a State-of-the-Art Conversational AI with Transfer Learning. (2019, May 9). (https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313)
[3] Li, Y., Su, H., Shen, X., Li, W., Cao, Z., & Niu, S. (2017). Dailydialog: A manually labelled multi-turn dialogue dataset. arXiv preprint arXiv:1710.03957. (https://arxiv.org/abs/1710.03957)
[4] Rashkin, H., Smith, E. M., Li, M., & Boureau, Y. L. (2018). Towards empathetic open-domain conversation models: A new benchmark and dataset. arXiv preprint arXiv:1811.00207. (https://arxiv.org/abs/1811.00207)
[5] Zhang, S., Dinan, E., Urbanek, J., Szlam, A., Kiela, D., & Weston, J. (2018). Personalizing dialogue agents: I have a dog, do you have pets too?. arXiv preprint arXiv:1801.07243. (https://arxiv.org/abs/1801.07243)
[6] Smith, E. M., Williamson, M., Shuster, K., Weston, J., & Boureau, Y. L. (2020). Can You Put it All Together: Evaluating Conversational Agents' Ability to Blend Skills. arXiv preprint arXiv:2004.08449. (https://arxiv.org/abs/2004.08449)