Repository Details

My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing concepts that otherwise seem hard. IWSLT pretrained models are currently included.

The Original Transformer (PyTorch) 💻 = 🌈

This repo contains a PyTorch implementation of the original transformer paper (🔗 Vaswani et al.).
It's aimed at making it easy to start playing and learning about transformers.

What are transformers

Transformers were originally proposed by Vaswani et al. in a seminal paper called Attention Is All You Need.

You've probably heard of transformers one way or another; GPT-3 and BERT are two well-known ones 🦄. The main idea of the paper is that you don't need recurrent or convolutional layers: a simple architecture coupled with attention is super powerful. It models long-range dependencies much better, and the architecture itself is highly parallelizable (💻💻💻), which leads to better compute efficiency!

Here is what their beautifully simple architecture looks like:

Understanding transformers

This repo is meant to be a learning resource for understanding transformers, as the original transformer by itself is no longer SOTA.

For that purpose the code is (hopefully) well commented, and I've included playground.py, where I've visualized a couple of concepts that are hard to explain in words but super simple once visualized. So here we go!

Positional Encodings

Can you parse this one in the blink of an eye?

Neither can I. Running the visualize_positional_encodings() function from playground.py we get this:

Depending on the position of your source/target token you "pick one row of this image" and add it to its embedding vector, that's it. They could also be learned, but it's just more fancy to do it like this, obviously! 🤓
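To make the "pick one row" idea concrete, here is a minimal framework-free sketch of the sinusoidal encodings from the paper. The function names positional_encodings and add_position are made up for illustration; they are not the repo's API (the repo itself uses PyTorch).

```python
import math


def positional_encodings(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal positional encodings."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(0, d_model, 2):
            # Dimensions i and i + 1 share one frequency: 1 / 10000^(i / d_model).
            angle = pos / (10000 ** (i / d_model))
            row.append(math.sin(angle))  # even dimension
            row.append(math.cos(angle))  # odd dimension
        table.append(row[:d_model])  # trim in case d_model is odd
    return table


def add_position(embedding, position, table):
    """'Pick one row' of the table and add it to a token's embedding vector."""
    return [e + p for e, p in zip(embedding, table[position])]


# Hypothetical sizes, chosen small so the table is easy to inspect.
table = positional_encodings(max_len=50, d_model=8)
print(add_position([0.0] * 8, position=3, table=table))
```

Each row of table corresponds to one row of the image, so a token at position 3 simply gets table[3] added to its embedding.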

Custom Learning Rate Schedule

Similarly, can you parse this one in O(1)?

Nope? That's what I thought. Here it is visualized:

It's super easy to understand now. Was this part crucial for the success of the transformer? I doubt it. But it's cool and makes things more complicated. 🤓 (.set_sarcasm(True))

Note: the model dimension is basically the size of the embedding vector. The baseline transformer used 512, the big one 1024.
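For reference, the schedule from the paper is lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). Here is a small sketch of it; the function name transformer_lr is an assumption for illustration, and the warmup of 4000 steps is the value used in the paper, not necessarily what this repo uses.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate from 'Attention Is All You Need'.

    Grows linearly for warmup_steps steps, then decays as 1 / sqrt(step).
    """
    step = max(step, 1)  # guard against 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(step, transformer_lr(step))
```

The two terms inside min() are equal exactly at step == warmup_steps, which is where the curve peaks.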

Label Smoothing

The first time you hear of label smoothing it sounds tough, but it's not. You usually set your target vocabulary distribution to a one-hot, meaning 1 position out of 30k (or whatever your vocab size is) is set to probability 1. and everything else to 0.

In label smoothing, instead of placing 1. on that particular position you place, say, 0.9, and you evenly distribute the rest of the "probability mass" over the other positions (that's visualized as a different shade of purple on the image above in a fictional vocab of size 4, hence 4 columns).

Note: Pad token's distribution is set to all zeros as we don't want our model to predict those!
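Here is a rough sketch of that idea in plain Python. The function name smooth_labels is hypothetical, and one detail is an assumption on my side: the leftover mass is spread over every token except the target and the pad token (so over vocab_size - 2 positions). The repo's actual code may differ.

```python
def smooth_labels(target_ids, vocab_size, pad_id, smoothing=0.1):
    """Build one smoothed target distribution per target token id."""
    confidence = 1.0 - smoothing
    # Assumption: spread the rest over all tokens except the target and pad.
    spread = smoothing / (vocab_size - 2)
    rows = []
    for target in target_ids:
        if target == pad_id:
            # Pad positions get an all-zero distribution, so they add no loss.
            rows.append([0.0] * vocab_size)
            continue
        row = [spread] * vocab_size
        row[pad_id] = 0.0         # never ask the model to predict padding
        row[target] = confidence  # most of the mass stays on the true token
        rows.append(row)
    return rows


# Fictional vocab of size 4 with the pad token at index 0.
for row in smooth_labels([2, 3, 0], vocab_size=4, pad_id=0):
    print(row)
```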

Aside from this repo (well duh) I would highly recommend you go ahead and read this amazing blog by Jay Alammar!

Machine translation

The transformer was originally trained for the NMT (neural machine translation) task on the WMT-14 dataset for:

  • English to German translation task (achieved 28.4 BLEU score)
  • English to French translation task (achieved 41.8 BLEU score)

What I did (for now) is train my models on the much smaller IWSLT dataset for the English-German language pair; I speak those languages, so it's easier to debug and play around.

I'll also train my models on WMT-14 soon, take a look at the todos section.


Anyways! Let's see what this repo can practically do for you! Well it can translate!

Some short translations from my German to English IWSLT model:

Input: Ich bin ein guter Mensch, denke ich. ("gold": I am a good person I think)
Output: ['<s>', 'I', 'think', 'I', "'m", 'a', 'good', 'person', '.', '</s>']
or in human-readable format: I think I'm a good person.

Which is actually pretty good! Maybe even better IMO than Google Translate's "gold" translation.


There are of course failure cases like this:

Input: Hey Alter, wie geht es dir? (How is it going dude?)
Output: ['<s>', 'Hey', ',', 'age', 'how', 'are', 'you', '?', '</s>']
or in human-readable format: Hey, age, how are you?

Which is actually also not completely bad! Because:

  • First of all, the model was trained on IWSLT (TED-like conversations)
  • "Alter" is a colloquial expression for old buddy/dude/mate, but its literal meaning is indeed age.

Similarly for the English to German model.

Setup

So we talked about what transformers are, and what they can do for you (among other things).
Let's get this thing running! Follow the next steps:

  1. git clone https://github.com/gordicaleksa/pytorch-original-transformer
  2. Open the Anaconda console and navigate into the project directory: cd path_to_repo
  3. Run conda env create from the project directory (this will create a brand new conda environment).
  4. Run activate pytorch-transformer (for running scripts from your console, or set the interpreter in your IDE)

That's it! It should work out of the box: conda executes the environment.yml file, which deals with the dependencies.
It may take a while as I'm automatically downloading SpaCy's statistical models for English and German.


The PyTorch pip package comes bundled with some version of CUDA/cuDNN, but it is highly recommended that you install a system-wide CUDA beforehand, mostly because of the GPU drivers. I also recommend using the Miniconda installer as a way to get conda on your system. Follow through points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.

Usage

Option 1: Jupyter Notebook

Just run jupyter notebook from your Anaconda console and it will open the session in your default browser.
Open The Annotated Transformer ++.ipynb and you're ready to play!


Note: if you get DLL load failed while importing win32api: The specified module could not be found,
just run pip uninstall pywin32; after that, either pip install pywin32 or conda install pywin32 should fix it!

Option 2: Use your IDE of choice

You just need to link the Python environment you created in the setup section.

Training

To run the training, start training_script.py. There are a couple of settings you will want to specify:

  • --batch_size - set this to the maximum value that won't give you CUDA out of memory
  • --dataset_name - Pick between IWSLT and WMT14 (WMT14 is not advisable until I add multi-GPU support)
  • --language_direction - Pick between E2G and G2E

So an example run (from the console) would look like this:
python training_script.py --batch_size 1500 --dataset_name IWSLT --language_direction G2E
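If you're curious how flags like these are typically wired up, here is a hedged argparse sketch. It is not the repo's actual parser: build_parser, the default values, and the help strings are all assumptions; only the three flag names and their allowed values come from the list above.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the three flags described above."""
    parser = argparse.ArgumentParser(description="Train the transformer (sketch)")
    parser.add_argument("--batch_size", type=int, default=1500,
                        help="largest value that does not cause CUDA out of memory")
    parser.add_argument("--dataset_name", choices=["IWSLT", "WMT14"], default="IWSLT",
                        help="which dataset to train on")
    parser.add_argument("--language_direction", choices=["E2G", "G2E"], default="E2G",
                        help="E2G = English to German, G2E = German to English")
    return parser


# Same arguments as the example run above, passed in as a list.
args = build_parser().parse_args(
    ["--batch_size", "1500", "--dataset_name", "IWSLT", "--language_direction", "G2E"]
)
print(args.batch_size, args.dataset_name, args.language_direction)
```

Using choices means a typo such as --dataset_name WMT fails immediately with a readable error instead of halfway into data loading.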

The code is well commented so you can (hopefully) understand how the training itself works.

The script will:

  • Dump checkpoint *.pth models into models/checkpoints/
  • Dump the final *.pth model into models/binaries/
  • Download IWSLT/WMT-14 the first time you run it and place it under data/
  • Dump tensorboard data into runs/, just run tensorboard --logdir=runs from your Anaconda console
  • Periodically write some training metadata to the console

Note: data loading is slow in torchtext, so I've implemented a custom wrapper which adds caching mechanisms and makes things ~30x faster! (It'll still be slow the first time you run stuff.)

Inference (Translating)

The second part is all about playing with the models and seeing how they translate!
To get some translations, start translation_script.py. There are a couple of settings you'll want to set:

  • --source_sentence - depending on the model you specify, this should be either an English or a German sentence
  • --model_name - one of the pretrained model names: iwslt_e2g, iwslt_g2e or your model(*)
  • --dataset_name - keep this in sync with the model, IWSLT if the model was trained on IWSLT
  • --language_direction - keep in sync, E2G if the model was trained to translate from English to German

(*) Note: after you train your model it'll get dumped into models/binaries. See what its name is and specify it via the --model_name parameter if you want to play with it for translation purposes. If you specify one of the pretrained models, it'll automatically get downloaded the first time you run the translation script.

I'll link the IWSLT pretrained models here as well: English to German and German to English.

That's it! You can also visualize the attention; check out this section for more info.

Evaluating NMT models

I tracked 3 curves while training:

  • training loss (KL divergence, batchmean)
  • validation loss (KL divergence, batchmean)
  • BLEU-4
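A quick aside on the "batchmean" part: in PyTorch's KL divergence loss, "batchmean" divides the summed loss by the batch size, while "mean" divides by the total number of elements (batch size times vocab size). Here is a plain-Python sketch of the difference; kl_div is a made-up helper for illustration, not the PyTorch function.

```python
import math


def kl_div(target_probs, log_probs, reduction="batchmean"):
    """KL divergence between target distributions and predicted log-probabilities.

    Both arguments are lists of rows, one row per example in the batch.
    """
    total = 0.0
    for t_row, lq_row in zip(target_probs, log_probs):
        for t, lq in zip(t_row, lq_row):
            if t > 0:  # a zero target probability contributes nothing
                total += t * (math.log(t) - lq)
    batch_size = len(target_probs)
    if reduction == "batchmean":
        return total / batch_size
    if reduction == "mean":
        return total / (batch_size * len(target_probs[0]))
    raise ValueError(f"unknown reduction: {reduction}")


# Toy batch: smoothed targets over a vocab of 3, and a model that predicts uniformly.
targets = [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05]]
uniform_log_probs = [[math.log(1 / 3)] * 3 for _ in targets]
print(kl_div(targets, uniform_log_probs, "batchmean"))
print(kl_div(targets, uniform_log_probs, "mean"))
```

The two reductions differ only by a constant factor (the vocab size here), but that factor scales the gradients, so it interacts with the learning rate.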

BLEU is an n-gram based metric for quantitatively evaluating the quality of machine translation models.
I used the BLEU-4 metric provided by the awesome nltk Python module.
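To show what "n-gram based" means in practice, here is a simplified BLEU-4 sketch. It is a stand-in and not nltk's implementation: it assumes exactly one reference per sentence, applies no smoothing, and returns a score in [0, 1] (the table below uses the usual 0-100 scale). The name corpus_bleu4 is hypothetical.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count every contiguous n-gram in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu4(references, hypotheses):
    """Corpus-level BLEU-4 with one reference per hypothesis and no smoothing."""
    matches, totals = [0] * 4, [0] * 4
    ref_len = hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for n in range(1, 5):
            hyp_counts = ngram_counts(hyp, n)
            ref_counts = ngram_counts(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0  # without smoothing, one empty n-gram order zeroes the score
    # Geometric mean of the four modified n-gram precisions.
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / 4
    # Brevity penalty punishes hypotheses that are shorter than the references.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)


reference = [["I", "think", "I", "'m", "a", "good", "person", "."]]
hypothesis = [["I", "think", "I", "am", "a", "good", "person", "."]]
print(corpus_bleu4(reference, hypothesis))
```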

Current results; models were trained for 20 epochs (DE stands for Deutsch, i.e. German in German 🤓):

| Model                        | BLEU score | Dataset    |
|------------------------------|------------|------------|
| Baseline transformer (EN-DE) | 27.8       | IWSLT val  |
| Baseline transformer (DE-EN) | 33.2       | IWSLT val  |
| Baseline transformer (EN-DE) | x          | WMT-14 val |
| Baseline transformer (DE-EN) | x          | WMT-14 val |

I got these using greedy decoding, so it's a pessimistic estimate. I'll add beam decoding soon.

Important note: Initialization matters a lot for the transformer! I initially thought that the Xavier initialization used by other implementations was again one of those arbitrary heuristics and that the PyTorch default init would do. I was wrong:
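In case greedy decoding is new to you: at every step you keep only the single most likely next token and never revisit that choice. Here is a minimal sketch; greedy_decode and toy_model are illustrative names, and toy_model is a fake stand-in for a trained transformer, not the repo's code.

```python
def greedy_decode(next_token_scores, bos="<s>", eos="</s>", max_new_tokens=20):
    """Repeatedly append the single highest-scoring next token.

    next_token_scores is a hypothetical callable standing in for the model:
    it maps the tokens generated so far to a dict of {candidate: score}.
    """
    tokens = [bos]
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)  # greedy: keep only the top candidate
        tokens.append(best)
        if best == eos:
            break
    return tokens


# Toy stand-in for a trained model: it always continues one fixed sentence.
SENTENCE = ["I", "think", "I", "'m", "a", "good", "person", ".", "</s>"]


def toy_model(tokens):
    target = SENTENCE[len(tokens) - 1]
    return {word: (1.0 if word == target else 0.0) for word in set(SENTENCE)}


print(greedy_decode(toy_model))
```

Beam decoding would instead keep several candidate sequences alive per step, which is why greedy numbers tend to be a lower bound.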

You can see here 3 runs, the 2 lower ones used PyTorch default initialization (one used mean for KL divergence loss and the better one used batchmean), whereas the upper one used Xavier uniform initialization!
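For the curious, Xavier (Glorot) uniform initialization samples each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). In PyTorch you'd typically reach for torch.nn.init.xavier_uniform_; the sketch below just spells out the math in plain Python, and the helper names xavier_bound and xavier_uniform are made up for illustration.

```python
import math
import random


def xavier_bound(fan_in, fan_out):
    """Half-width of the uniform range used by Xavier uniform initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform(fan_in, fan_out, seed=0):
    """Return a fan_out x fan_in weight matrix sampled from U(-bound, bound)."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    bound = xavier_bound(fan_in, fan_out)
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


# A 512 x 512 projection, like the ones in the baseline model's attention layers.
weights = xavier_uniform(512, 512)
print(xavier_bound(512, 512), max(max(row) for row in weights))
```

The bound shrinks as layers get wider, which keeps the variance of activations roughly constant from layer to layer.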


Idea: you could potentially also periodically dump translations for a reference batch of source sentences.
That would give you some qualitative insight into how the transformer is doing, although I didn't do that.
A similar thing is done when you have a hard time quantitatively evaluating your model, like in the GAN and NST fields.

Tracking using Tensorboard

The above plot is a snippet from my Azure ML run but when I run stuff locally I use Tensorboard.

Just run tensorboard --logdir=runs from your Anaconda console and you can track your metrics during the training.

Visualizing attention

You can use translation_script.py and set --visualize_attention to True to additionally understand what your model was "paying attention to" in the source and target sentences.

Here are the attentions I get for the input sentence Ich bin ein guter Mensch, denke ich.

These belong to layer 6 of the encoder. You can see all of the 8 multi-head attention heads.

And this one belongs to the self-attention MHA (multi-head attention) module of decoder layer 6.
You can notice an interesting triangular pattern which comes from the fact that target tokens can't look ahead!
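That triangular pattern is just the causal mask at work. Here is a tiny sketch of such a mask, where True means "this target position may attend to that position". The function name causal_mask is hypothetical, not something from this repo.

```python
def causal_mask(size):
    """Row i may attend to column j only when j <= i (no looking ahead)."""
    return [[col <= row for col in range(size)] for row in range(size)]


for row in causal_mask(5):
    print(" ".join("1" if allowed else "0" for allowed in row))
```

Positions marked 0 have their attention scores masked out before the softmax, so they end up with zero weight in the plot.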

The 3rd type of MHA module is the source attending one and it looks similar to the plot you saw for the encoder.
Feel free to play with it at your own pace!

Note: there are obviously some bias problems with this model, but I won't get into that analysis here.

Hardware requirements

You really need decent hardware if you wish to train the transformer on the WMT-14 dataset.

The authors took:

  • 12h on 8 P100 GPUs to train the baseline model and 3.5 days to train the big one.

If my calculations are right that amounts to ~19 epochs (100k steps, each step had ~25000 tokens and WMT-14 has ~130M src/trg tokens) for the baseline and 3x that for the big one (300k steps).
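The back-of-the-envelope calculation, spelled out (the helper name epochs_trained is made up; the numbers are the approximate ones quoted above):

```python
def epochs_trained(steps, tokens_per_step, dataset_tokens):
    """Number of passes over the dataset implied by a step count."""
    return steps * tokens_per_step / dataset_tokens


TOKENS_PER_STEP = 25_000      # approximate tokens per training step
WMT14_TOKENS = 130_000_000    # approximate src/trg tokens in WMT-14

baseline = epochs_trained(100_000, TOKENS_PER_STEP, WMT14_TOKENS)
big = epochs_trained(300_000, TOKENS_PER_STEP, WMT14_TOKENS)
print(f"baseline: ~{baseline:.1f} epochs, big: ~{big:.1f} epochs")
```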

On the other hand it's much more feasible to train the model on the IWSLT dataset. It took me:

  • 13.2 min/epoch (1500 token batch) on my RTX 2080 machine (8 GBs of VRAM)
  • ~34 min/epoch (1500 token batch) on Azure ML's K80s (24 GBs of VRAM)

I could have pushed K80s to 3500+ tokens/batch but had some CUDA out of memory problems.

Todos:

Finally there are a couple more todos which I'll hopefully add really soon:

  • Multi-GPU/multi-node training support (so that you can train a model on WMT-14 for 19 epochs)
  • Beam decoding (turns out it's not that easy to implement this one!)
  • BPE and shared source-target vocab (I'm using SpaCy now)

The repo already has everything it needs, these are just the bonus points. I've tested everything from environment setup, to automatic model download, etc.

Video learning material

If you're having difficulties understanding the code I did an in-depth overview of the paper in this video:

A deep dive into the attention is all you need paper

I have some more videos which could further help you understand transformers:

Acknowledgements

I found these resources useful (while developing this one):

I found some inspiration for the model design in The Annotated Transformer, but I found it hard to understand, and it had some bugs. It was mainly written with researchers in mind. Hopefully this repo opens up the understanding of transformers to the common folk as well! 🤓

Citation

If you find this code useful, please cite the following:

@misc{Gordić2020PyTorchOriginalTransformer,
  author = {Gordić, Aleksa},
  title = {pytorch-original-transformer},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-original-transformer}},
}

Connect with me

If you'd love to have some more AI-related content in your life 🤓, consider:

Licence

License: MIT
