# Transformer Chatbot with TensorFlow 2
Build an end-to-end chatbot with a Transformer in TensorFlow 2. Check out my tutorial on blog.tensorflow.org.
## Updates
- 16 June 2022:
  - Updated the `setup.sh` script to install the Apple Silicon version of TensorFlow 2.9 (only use this if you're feeling adventurous).
  - Updated the two custom layers, `PositionalEncoding` and `MultiHeadAttentionLayer`, to allow model saving via `model.save()` or `tf.keras.models.save_model()`; `train.py` showcases how to call `model.save()` and `tf.keras.models.load_model()`. A loading sketch follows this list.
- 8 Dec 2020: Updated support to TensorFlow 2.3.1 and TensorFlow Datasets 4.1.0
- 18 Jan 2020: Added notebook with Google Colab TPU support in TensorFlow 2.1.
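As a quick illustration of the 16 June 2022 update, here is a minimal sketch of saving and restoring the model. It assumes `model` is a trained Transformer built by model.py, and that the two custom layers are importable from model.py under the names shown (the import path is an assumption about the module layout).

```python
import tensorflow as tf

# Assumption: model.py exposes the two custom layers under these names.
from model import PositionalEncoding, MultiHeadAttentionLayer

# `model` is assumed to be a trained Transformer built by model.py.
model.save("runs/save_model")  # equivalent to tf.keras.models.save_model(model, ...)

# Keras needs the custom layers registered via `custom_objects` to rebuild them.
loaded_model = tf.keras.models.load_model(
    "runs/save_model",
    custom_objects={
        "PositionalEncoding": PositionalEncoding,
        "MultiHeadAttentionLayer": MultiHeadAttentionLayer,
    },
)
```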
## Packages
- TensorFlow 2.9.1
- TensorFlow Datasets
## Setup
- Create a new Anaconda environment named `chatbot` and activate it:

```sh
conda create -n chatbot python=3.8
conda activate chatbot
```
- Run the installation script:

```sh
sh setup.sh
```
- Note: the script installs CUDA and cuDNN via conda on Linux systems, or `tensorflow-metal` on devices with Apple Silicon (note that TensorFlow on Apple Silicon GPUs still has tons of bugs, e.g. the Adam optimizer does not work).
## Dataset
- We will use the conversations in movies and TV shows provided by the Cornell Movie-Dialogs Corpus, which contains more than 220 thousand conversational exchanges between more than 10k pairs of movie characters, as our dataset.
- We pre-process our dataset in the following order:
  - Extract `max_samples` conversation pairs into lists of `questions` and `answers`.
  - Pre-process each sentence by removing special characters.
  - Build a tokenizer (map text to ID and ID to text) using TensorFlow Datasets SubwordTextEncoder.
  - Tokenize each sentence and add `start_token` and `end_token` to indicate the start and end of each sentence.
  - Filter out sentences with more than `max_length` tokens.
  - Pad tokenized sentences to `max_length`.
- Check dataset.py for the implementation; a simplified sketch of the tokenize/filter/pad steps follows.
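The snippet below is a minimal, self-contained sketch of those steps, not the exact dataset.py code. The tiny `questions`/`answers` lists are hypothetical stand-ins for the extracted corpus, and `MAX_LENGTH` is assumed to mirror the `max_length` flag.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Hypothetical stand-ins: parallel lists of pre-processed sentence strings.
questions = ["where have you been ?"]
answers = ["i am fine ."]

MAX_LENGTH = 40  # assumption: mirrors the max_length flag

# Build a subword tokenizer (text <-> ID) from the corpus.
tokenizer = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
    questions + answers, target_vocab_size=2**13)

# Reserve two extra IDs beyond the vocabulary for the start and end tokens.
START_TOKEN, END_TOKEN = [tokenizer.vocab_size], [tokenizer.vocab_size + 1]

def tokenize_and_filter(inputs, outputs):
  tokenized_inputs, tokenized_outputs = [], []
  for sentence1, sentence2 in zip(inputs, outputs):
    # Tokenize and wrap each sentence with start/end tokens.
    sentence1 = START_TOKEN + tokenizer.encode(sentence1) + END_TOKEN
    sentence2 = START_TOKEN + tokenizer.encode(sentence2) + END_TOKEN
    # Filter out pairs where either side exceeds MAX_LENGTH tokens.
    if len(sentence1) <= MAX_LENGTH and len(sentence2) <= MAX_LENGTH:
      tokenized_inputs.append(sentence1)
      tokenized_outputs.append(sentence2)
  # Pad every tokenized sentence to MAX_LENGTH.
  tokenized_inputs = tf.keras.preprocessing.sequence.pad_sequences(
      tokenized_inputs, maxlen=MAX_LENGTH, padding="post")
  tokenized_outputs = tf.keras.preprocessing.sequence.pad_sequences(
      tokenized_outputs, maxlen=MAX_LENGTH, padding="post")
  return tokenized_inputs, tokenized_outputs
```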
## Model
- Check model.py for the implementation of Multi-Headed Attention, Positional Encoding and the Transformer; a sketch of the core computations follows.
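For orientation, here is a short, self-contained sketch of the two core computations: sinusoidal positional encoding and scaled dot-product attention (the heart of multi-head attention). The actual layers in model.py wrap these ideas into Keras layers and may differ in detail.

```python
import numpy as np
import tensorflow as tf

def positional_encoding(position, d_model):
  """Sinusoidal positional encoding:
  PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(...)."""
  angle_rates = 1 / np.power(
      10000,
      (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model))
  angle_rads = np.arange(position)[:, np.newaxis] * angle_rates
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even indices: sine
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd indices: cosine
  return tf.cast(angle_rads[np.newaxis, ...], tf.float32)

def scaled_dot_product_attention(query, key, value, mask=None):
  """softmax(QK^T / sqrt(d_k)) V, the core of each attention head."""
  matmul_qk = tf.matmul(query, key, transpose_b=True)
  depth = tf.cast(tf.shape(key)[-1], tf.float32)
  logits = matmul_qk / tf.math.sqrt(depth)
  if mask is not None:
    logits += mask * -1e9  # mask out padded / future positions
  weights = tf.nn.softmax(logits, axis=-1)
  return tf.matmul(weights, value)
```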
## Run
- Check all available flags and hyper-parameters:

```sh
python main.py --help
```

- Train the model, for example:

```sh
python train.py --output_dir runs/save_model --batch_size 256 --epochs 50 --max_samples 50000
```
- The final trained model will be saved to `runs/save_model`.
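Generating a reply from a trained model is a greedy autoregressive loop. The sketch below is illustrative, not the repo's exact code: it assumes the tokenizer and `start_token`/`end_token` from the Dataset section, and that the model is called with `[sentence, partial_output]` as its two inputs, as in the tutorial.

```python
import tensorflow as tf

def predict(sentence, model, tokenizer, start_token, end_token, max_length=40):
  # Encode the input sentence with start/end tokens; batch dimension of 1.
  inputs = tf.expand_dims(
      start_token + tokenizer.encode(sentence) + end_token, axis=0)
  # The decoder output starts with only the start token.
  output = tf.expand_dims(start_token, axis=0)
  for _ in range(max_length):
    predictions = model(inputs=[inputs, output], training=False)
    # Greedily pick the most likely next token from the last position.
    predicted_id = tf.cast(
        tf.argmax(predictions[:, -1:, :], axis=-1), tf.int32)
    if tf.equal(predicted_id[0][0], end_token[0]):
      break  # stop once the end token is produced
    output = tf.concat([output, predicted_id], axis=-1)
  # Drop the batch dimension and the leading start token, then decode to text.
  return tokenizer.decode(tf.squeeze(output, axis=0).numpy()[1:])
```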
## Samples
```
input: where have you been?
output: i m not talking about that .

input: it's a trap!
output: no , it s not .
```