• Stars: 2,707
• Rank: 16,154 (Top 0.4 %)
• Language: Python
• License: MIT License
• Created over 3 years ago
• Updated about 1 year ago

Repository Details

A simple but complete full-attention transformer with a set of promising experimental features from various papers

x-transformers

A concise but fully-featured transformer, complete with a set of promising experimental features from various papers.

Install

$ pip install x-transformers

Usage

Full encoder / decoder

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask) # scalar training loss
loss.backward()

Decoder-only (GPT-like)

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()

model(x) # (1, 1024, 20000)

GPT3 would be approximately the following (but you wouldn't be able to run it anyway)

gpt3 = TransformerWrapper(
    num_tokens = 50000,
    max_seq_len = 2048,
    attn_layers = Decoder(
        dim = 12288,
        depth = 96,
        heads = 96,
        attn_dim_head = 128
    )
).cuda()

Encoder-only (BERT-like)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 1024, 20000)

State of the art image classification (SimpleViT)

import torch
from x_transformers import ViTransformerWrapper, Encoder

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

img = torch.randn(1, 3, 256, 256)
model(img) # (1, 1000)

Image -> caption

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)

PaLI, state of the art language-vision model

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # will prepend image embeddings to encoder text embeddings before attention
)

loss.backward()

# do the above for many steps on a 17B parameter model
# attention is all you need

Dropouts

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dropout = 0.1,         # dropout after embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        layer_dropout = 0.1,   # stochastic depth - dropout entire layer
        attn_dropout = 0.1,    # dropout post-attention
        ff_dropout = 0.1       # feedforward dropout
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

Features

Flash Attention

What originally started off as a short paper from Markus Rabe culminated in a practical fused attention CUDA kernel, named Flash Attention, by Tri Dao.

The technique processes the attention matrix in tiles, only keeping track of the running softmax and exponentiated weighted sums. By recomputing on the backwards pass in a tiled fashion, one is able to keep the memory linear with respect to sequence length. This allows a lot of recent models to be able to reach for longer context lengths without worrying about the memory bottleneck.

Other engineering decisions made by Tri Dao led to its enormous success, namely minimizing HBM accesses so that both the forwards and backwards outperform naive attention. In other words, flash attention is not only more memory efficient, but faster as well, making it a necessity for training transformers.

MetaAI has recently added the ability to use Tri Dao's CUDA kernel through the scaled_dot_product_attention function in PyTorch 2.0. (They also have a mem_efficient attention, which is identical in design to Flash Attention, except that the tiles are traversed differently.)

Llama was trained using Flash Attention. The only reason to avoid it is if you require operating on the attention matrix (dynamic positional bias, talking heads, residual attention).

You can use it in this repository by setting attn_flash to True and enjoy the immediate memory savings and increase in speed.

ex.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True # just set this to True if you have pytorch 2.0 installed
    )
)
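
The tiling idea described above can be illustrated without any GPU code. The following is a minimal, dependency-free sketch for a single query row, assuming scalar values per key for brevity; the function names are hypothetical and this is not the actual CUDA kernel or anything from this repository. It only shows that a running max, a running sum of exponentials, and a running weighted sum are enough to reproduce the full softmax result.

```python
import math

def naive_attention_row(scores, values):
    """Softmax-weighted average of values, using the whole row of scores at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def tiled_attention_row(scores, values, tile_size=2):
    """Same result, visiting the scores one tile at a time.

    Only three running quantities are stored: the max seen so far, the sum
    of exponentials, and the exponentially weighted sum of the values.
    """
    running_max = float("-inf")
    running_sum = 0.0
    running_out = 0.0
    for start in range(0, len(scores), tile_size):
        tile_scores = scores[start:start + tile_size]
        tile_values = values[start:start + tile_size]
        new_max = max(running_max, max(tile_scores))
        # rescale everything accumulated under the previous max
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        running_out *= correction
        for s, v in zip(tile_scores, tile_values):
            w = math.exp(s - new_max)
            running_sum += w
            running_out += w * v
        running_max = new_max
    return running_out / running_sum

scores = [0.5, 2.0, -1.0, 3.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(naive_attention_row(scores, values), tiled_attention_row(scores, values))
```

Because the full row of attention weights is never stored, memory per query stays constant regardless of how many keys are visited, which is where the linear memory in sequence length comes from.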

Augmenting Self-attention with Persistent Memory

https://arxiv.org/abs/1907.01470

Proposes adding learned memory key / values prior to attention. They were able to remove feedforwards altogether and attain similar performance to the original transformers. I have found that keeping the feedforwards and adding the memory key / values leads to even better performance.

from x_transformers import Decoder, Encoder

enc = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_num_mem_kv = 16 # 16 memory key / values
)

Memory Transformers

https://arxiv.org/abs/2006.11527

Proposes adding learned tokens, akin to CLS tokens, named memory tokens, that are passed through the attention layers alongside the input tokens.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20, # 20 memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Transformers Without Tears

https://arxiv.org/abs/1910.05895

They experiment with alternatives to layer normalization and find one that is both effective and simpler. Researchers have shared with me that this leads to faster convergence.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_scalenorm = True # set to True to use for all layers
    )
)

You can also use the l2 normalized embeddings proposed as part of fixnorm. I have found it leads to improved convergence when paired with small initialization (proposed by BlinkDL). The small initialization will be taken care of as long as l2norm_embed is set to True.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    l2norm_embed = True,    # set this to True for l2 normalized embedding + small init
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Along the same lines of l2 normalized embeddings, Huggingface's 175B parameter BLOOM also places a layernorm right after the embeddings and just before the tokens enter the attention layers. This was corroborated by Yandex's 100B parameter YaLM to stabilize training.

It is recommended that you set either l2norm_embed or post_emb_norm to True, but not both, as they probably serve the same purpose.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    post_emb_norm = True,    # set this to True to layernorm summed token + pos embeddings
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Root Mean Square Layer Normalization

https://arxiv.org/abs/1910.07467

The authors propose to replace layer normalization with a simpler alternative, without mean centering and the learned bias. An investigative paper found this to be the best performing normalization variant. It was also used in Deepmind's latest large language models, Retro and Gopher.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_rmsnorm = True # set to true to use for all layers
    )
)

July 2023 A linear attention paper has experiments to show that removing the learned multiplicative gamma led to no performance degradation. This simplifies the RMS normalization to a satisfying l2norm(x) * sqrt(dim).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_simple_rmsnorm = True # set to true to use for all layers
    )
)
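
To make the l2norm(x) * sqrt(dim) simplification concrete, here is a small dependency-free sketch, assuming no epsilon term and plain Python lists in place of tensors. The function names are hypothetical and are not the library's implementation; the point is only that dividing by the root mean square is the same operation as l2-normalizing and multiplying by sqrt(dim).

```python
import math

def rmsnorm(x, gamma=None):
    """RMS normalization: divide by the root mean square, then apply an optional learned gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    if gamma is None:
        gamma = [1.0] * len(x)
    return [g * v / rms for g, v in zip(gamma, x)]

def simple_rmsnorm(x):
    """The gamma-free form: l2norm(x) * sqrt(dim)."""
    l2 = math.sqrt(sum(v * v for v in x))
    return [v / l2 * math.sqrt(len(x)) for v in x]

x = [1.0, -2.0, 3.0, 0.5]
print(rmsnorm(x))
print(simple_rmsnorm(x))
```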

GLU Variants Improve Transformer

https://arxiv.org/abs/2002.05202

Noam Shazeer paper that explores gating in the feedforward, finding that simple gating with GELU leads to significant improvements. This variant also showed up in the latest mT5 architecture. You should always turn this on (I may eventually turn it on by default).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_glu = True # set to true to use for all feedforwards
    )
)

The PaLM language model also chose to use the Swish GLU variant. You can turn this on by setting two flags

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_swish = True, # set this to True
        ff_glu = True    # set to true to use for all feedforwards
    )
)

No Bias in Feedforward

Starting with PaLM, there began a trend to remove biases from the transformer altogether. Boris Dayma has run a number of experiments that showed removing biases from feedforwards led to increased throughput without any loss of accuracy. This was corroborated by yet another paper investigating transformer architecture variants.

You can turn off the feedforward bias as follows

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_no_bias = True  # set this to True
    )
)

ReLU²

https://arxiv.org/abs/2109.08668

This paper used neural architecture search and found an activation, Relu Squared, that is both simpler and performs better than GELU, in the autoregressive language model setting. I have confirmed this in my independent experiments. However, if one were using the GLU variant from above, GELU still performs better. Pending further corroboration.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_relu_squared = True
    )
)
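
For reference, the activation itself is a one-liner. This is a scalar sketch with a hypothetical function name, not the library's code:

```python
def relu_squared(x):
    """ReLU followed by squaring: zero for negative inputs, x**2 otherwise."""
    return max(x, 0.0) ** 2

print([relu_squared(v) for v in (-1.5, 0.0, 0.5, 3.0)])
```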

Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection

https://arxiv.org/abs/1912.11637

This paper proposes an efficient way to sparsify attention by zeroing all dot-product query/key values not within the top k values. They show that this cheap method is as effective as other more expensive operations like sparsemax or entmax15. This technique comes with the cost of an extra hyperparameter (the top k values to keep). The paper recommends a value of k = 8.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8 # keep only the top 8 values before attention (softmax)
    )
)
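
The selection step can be sketched for a single row of attention scores as follows. This is a minimal illustration with a hypothetical function name, assuming that scores tied with the k-th largest are all kept; it is not the library's implementation.

```python
import math

def sparse_topk_softmax(scores, k):
    """Softmax over only the k largest scores; all other positions get zero weight."""
    if k >= len(scores):
        threshold = min(scores)
    else:
        threshold = sorted(scores, reverse=True)[k - 1]
    # everything below the k-th largest score is masked out before the softmax
    masked = [s if s >= threshold else float("-inf") for s in scores]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

print(sparse_topk_softmax([1.0, 3.0, 0.5, 2.0], k=2))
```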

Talking-Heads Attention

https://arxiv.org/abs/2003.02436

A Noam Shazeer paper that proposes mixing information between heads pre and post attention (softmax). This comes with the cost of extra memory and compute.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_talking_heads = True  # turn on information exchange between attention heads
    )
)

One Write-Head Is All You Need

https://arxiv.org/abs/1911.02150

Yet another Noam Shazeer paper (he's a legend) that proposes to only have one head for the key / values, but multi-headed queries. This paper was largely ignored for a while, but recently validated at scale in AlphaCode as well as PaLM. It has the property of being memory efficient when decoding extremely large language models. You can use it with one keyword argument as shown below.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_one_kv_head = True
    )
)
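
The memory benefit during decoding comes from the key / value cache shrinking by a factor equal to the number of heads. A back-of-the-envelope sketch follows, assuming a head dimension of 64 and 2 bytes per element (fp16); both values and the function name are assumptions chosen for illustration.

```python
def kv_cache_bytes(batch, seq_len, depth, kv_heads, dim_head, bytes_per_elem=2):
    """Size of the cached keys and values across all layers."""
    # factor of 2: one tensor for keys and one for values per layer
    return 2 * batch * seq_len * depth * kv_heads * dim_head * bytes_per_elem

multi_head = kv_cache_bytes(batch=1, seq_len=1024, depth=6, kv_heads=8, dim_head=64)
one_kv_head = kv_cache_bytes(batch=1, seq_len=1024, depth=6, kv_heads=1, dim_head=64)
print(multi_head, one_kv_head, multi_head // one_kv_head)
```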

Attention on Attention for Image Captioning

https://arxiv.org/abs/1908.06954

This paper proposes to add a gated linear unit at the end of the attention layer, further gated by the original queries. Although this is not widely used outside of visual question / answering, I suspect it should lead to improvements after seeing the success of the feedforward GLU variant.

Update: After some experimentation, I found this variant actually performs worse, but if it were to be modified to not concatenate the queries before gating, it performs much better. That is what we will be using in this repository.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_on_attn = True  # gate output of attention layer, by queries
    )
)

Intra-attention Gating on Values

Alphafold2 had a peculiar variant of attention where they gate the aggregated values with the input, presumably to give the block more control over the update.

A quick test shows a small but noticeable improvement, on about the same order as attention on attention.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_gate_values = True  # gate aggregated values with the input
    )
)

Improving Transformer Models by Reordering their Sublayers

https://arxiv.org/abs/1911.03864

This paper proposes to break from the normal fixed pattern of alternating attention and feedforwards, and instead to have blocks of only attention at the beginning followed by blocks of feedforwards at the end. This was further corroborated by a paper by Nvidia that reduces the number of attention layers to be 1/3rd of the feedforwards without loss in performance.

The amount of interleaving is controlled by a "sandwich coefficient", which they found to be optimal at a value of 6.

You can experiment with this feature as shown below

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_coef = 6  # interleave attention and feedforwards with sandwich coefficient of 6
    )
)
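
The reordering in the paper can be written as k attention sublayers, then (depth - k) interleaved attention / feedforward pairs, then k feedforward sublayers. The sketch below builds that ordering as a tuple of 'a' and 'f' markers. The function name is hypothetical, and whether the library constructs its layers in exactly this way is an assumption.

```python
def sandwich_layer_order(depth, sandwich_coef):
    """Return the sublayer order: a^k, then (a f)^(depth - k), then f^k."""
    assert 0 < sandwich_coef <= depth
    return (
        ("a",) * sandwich_coef
        + ("a", "f") * (depth - sandwich_coef)
        + ("f",) * sandwich_coef
    )

print("".join(sandwich_layer_order(depth=6, sandwich_coef=2)))
print("".join(sandwich_layer_order(depth=6, sandwich_coef=6)))
```

Note that the total number of attention and feedforward sublayers is unchanged; only their order differs. With the coefficient equal to the depth, all attention sublayers come first and all feedforwards last.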

Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View

https://arxiv.org/abs/1906.02762

The authors propose to view the success of transformers from a dynamical systems point of view, and then propose an improvement based on the mathematics of that POV. Specifically, they propose to place the attention layer in between two feedforward layers. This was adopted by a paper using transformers for speech recognition, the Conformer.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        macaron = True  # use macaron configuration
    )
)

T5's Simplified Relative Positional Encoding

https://arxiv.org/abs/1910.10683

T5 is one of the most successful encoder / decoder transformer architectures trained to date. They invented a new simplified relative positional encoding based on learned bias values that are added to the attention matrix pre-softmax. This bias is shared and injected into each attention layer. I have decided to include this because it offers a cheap way to have relative positional encoding (superior to absolute positional), and I have read papers that suggest having positional encoding added to each layer (vs only before the first) is beneficial.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True  # adds relative positional bias to all attention layers, a la T5
    )
)

Residual Attention

https://arxiv.org/abs/2012.11747

This paper from Google proposes residualizing the pre-attention scores across all layers. At the cost of no extra parameters, they show improvement on top of regular attention networks. If you turn on this setting, be aware that the best results in the paper used post-normalization, in which case a learning rate warmup will be needed. The authors also reported that they could use a higher learning rate and get even better gains in the same amount of steps. (In the paper they use 2e-4 vs 1e-4 for vanilla transformer.)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        pre_norm = False,       # in the paper, residual attention had best results with post-layernorm
        residual_attn = True    # add residual attention
    )
)

I also tried residualizing cross attention and may have noticed an improvement in convergence. You can try it by setting the cross_residual_attn keyword to True

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    dec_cross_residual_attn = True     # residualize cross attention
)

Deepnorm

It is well known that post-normalization transformers have trouble with stability, prompting the move to pre-normalization in recent years, even though the latter sacrifices performance.

This paper out of Microsoft Research proposes a way to fix post-normalization stability. They achieve this by simply scaling the residual and using proper initialization. They show they can train a one-thousand-layer transformer without stability issues, and achieve better results than pre-normalization.

This was recently validated in a 130B GLM model out of Tsinghua.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        deepnorm = True,     # set this to True to use deepnorm post-normalization configuration
        dim = 512,
        depth = 6,
        heads = 8
    )
)

Transformer-XL recurrence

You can also do Transformer-XL recurrence, by simply passing in a max_mem_len in the TransformerWrapper class, and then making sure your Decoder has rel_pos_bias set to True.

Then, you can retrieve the memories at each step with the return_mems keyword and pass it to the next iteration.

import torch
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 2048,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True
    )
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))
seg3 = torch.randint(0, 20000, (1, 512))

logits1, mems1  = model_xl(seg1, return_mems = True)
logits2, mems2  = model_xl(seg2, mems = mems1, return_mems = True)
logits3, mems3  = model_xl(seg3, mems = mems2, return_mems = True)

Setting up the logic for training and sampling from transformer xl can be a bit overwhelming. This repository offers a simple wrapper that should make this easy, with the XLAutoregressiveWrapper.

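from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# pass in the above model_xl

xl_wrapper = XLAutoregressiveWrapper(model_xl)

seg = torch.randint(0, 20000, (1, 4096))  # sequence exceeding max length, automatically segmented and memory managed (kept on the same device as model_xl)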

loss = xl_wrapper(seg)
loss.backward()

# then, after much training

prime = seg[:, :1024]   # if prime exceeds max length, memory will be caught up before generating

generated = xl_wrapper.generate(prime, 4096)  # (1, 4096)

Enhanced recurrence

This paper proposes a simple technique to enhance the range of Transformer-XL. They simply route the memory segment of a layer to the layer below it, for the next recurrent step. You can enable this by setting shift_mem_down = 1. You can also shift down an arbitrary number of layers by setting this value to > 1.

import torch
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 2048,
    shift_mem_down = 1,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True
    )
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))
seg3 = torch.randint(0, 20000, (1, 512))

logits1, mems1  = model_xl(seg1, return_mems = True)
logits2, mems2  = model_xl(seg2, mems = mems1, return_mems = True) # mems1 of layer N are automatically routed to the layer N-1

Gated residual

https://arxiv.org/abs/1910.06764

The authors propose gating the residual connections in the transformer network and demonstrate increased stability and performance for Transformer-XL in a variety of reinforcement learning tasks.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    max_mem_len = 2048,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 16,
        gate_residual = True
    )
)

Rotary Positional Embeddings

Developed in Beijing, this new technique quickly gained interest in the NLP circles. In short, it allows you to endow the transformer with relative positional embeddings at the cost of no learned parameters. You apply a rotary operation to the queries and keys prior to their dot product in attention. The big idea is injecting positions through rotations.

Highly recommend that you have this turned on whenever you are working on an ordered sequence.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True  # turns on rotary positional embeddings
    )
)
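
The claim that rotations inject relative positions can be checked numerically. The sketch below is dependency-free and uses hypothetical function names; it assumes consecutive feature pairs are rotated together with the frequency base of 10000 from the paper (implementations often pair features differently). It rotates a query and a key by their positions and shows that the dot product depends only on the offset between them.

```python
import math

def rotate(vec, position, base=10000.0):
    """Rotate each consecutive pair of features by an angle proportional to position."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # lower frequencies for later pairs
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.6, 0.9]

# same offset of 2 between query and key, at different absolute positions
print(dot(rotate(q, 5), rotate(k, 3)))
print(dot(rotate(q, 12), rotate(k, 10)))
```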

Update (12/2022): Rotary embedding has since been hugely successful, widely adopted in many large language models, including the largest in the world, PaLM. However, it has been uncovered in the ALiBi paper that rotary embeddings do not extrapolate well to longer sequence lengths. This was recently addressed in a Microsoft Research paper. They propose a way to unobtrusively add the same decay as in ALiBi, and found that this resolves the extrapolation problem. You can use it in this repository by setting rotary_xpos = True. Like ALiBi, it would enforce the attention to be local. You can set the receptive field with the rotary_xpos_scale_base value, which defaults to 512.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_xpos = True   # modified rotary to extrapolate well beyond length at which it was trained
    )
)

Dynamic Positional Bias

This technique has its roots in the field of vision transformers, where researchers are trying to have relative positions generalize to larger resolutions (without having to retrain the entire network). It was used in two recent papers, CrossFormer and SwinV2.

Charles Foster first tried this for a language model, and found that it works. Later on Eric Engelhart produced experimental results that show the same type of extrapolation holds, even for 1d sequences.

Eric trained at sequence lengths of 128, and showed that it generalized well to 1024. In addition, he showed that linear positions were better than log (used in SwinV2) for language.

[Figures: linear distances, log distances, and negative control (sinusoidal)]

More of Eric's experimental results can be found here

You can use this type of relative position if you wish to train at smaller sequence lengths and have it generalize to longer ones, for both autoregressive and bidirectional models.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        dynamic_pos_bias = True,                # set this to True
        dynamic_pos_bias_log_distance = False   # whether to use log distance, as in SwinV2
    )
)

ALiBi Positional Embedding

This paper proposes to simply apply a static linear bias to the attention matrix. The authors show this is not only effective as a relative positional encoding, but also allows the attention net to extrapolate to greater sequence lengths than it was trained on, for autoregressive language models.

This repository also offers a bidirectional variant (nonsymmetric), proposed by the authors here. However, this is untested. If you need bidirectional length extrapolation, the safest option would be Dynamic Positional Bias.

Update: It may be that ALiBi enforces a strong local attention across the heads, and may hinder it from attending at distances greater than 1k. To avoid any issues with global message passing, I've decided to introduce another hyperparameter alibi_num_heads, so one can specify fewer heads for the ALiBi bias.

Update: There are reports that ALiBi outperforms Rotary embeddings for pretraining and downstream fine-tuning.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True, # turns on ALiBi positional embedding
        alibi_num_heads = 4    # only use ALiBi for 4 out of the 8 heads, so other 4 heads can still attend far distances
    )
)
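
As a rough illustration of the static linear bias, the sketch below computes the per-head slopes using the geometric sequence from the ALiBi paper (for head counts that are powers of two) and builds the distance penalty for one head. The function names are hypothetical, causal masking is assumed to be applied separately, and this is not the library's code.

```python
def alibi_slopes(num_heads):
    """Geometric sequence of slopes: 2^(-8/n), 2^(-16/n), ..., 2^(-8)."""
    start = 2 ** (-8 / num_heads)
    return [start ** (h + 1) for h in range(num_heads)]

def alibi_bias(slope, seq_len):
    """Bias added to the attention scores: a penalty growing linearly with distance."""
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]

slopes = alibi_slopes(8)
print(slopes)
print(alibi_bias(slopes[0], 4))
```

Heads with small slopes are penalized only weakly at long distances. Leaving some heads without any bias, as alibi_num_heads allows, removes the penalty for those heads entirely.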

Shifted Tokens

An independent researcher has found that shifting a subset of the feature dimension along the sequence dimension by 1 token helps with convergence (Time-mixing). I have tested this for the autoregressive case and can confirm that it leads to greatly improved convergence. This also lines up with the results of some papers in the vision domain.

To use it, simply set shift_tokens = 1 (or to whatever number of shifts you desire). The feature dimension will be divided by shift_tokens + 1 and then each chunk will be shifted [0, shift_tokens] respectively

Update: new experiments by @sdtblck suggest this may only work for character-level training.

Update: after more experiments, it seems that in the context of BPE encoding, with rotary turned on, there is no benefit to shifting. For character-level training, shifting may still improve a tiny bit.

Update: When doing BPE encoded tokens, it seems that shift of 2 will bottleneck the dimensions (divided by 5). It is recommended you always do a shift of 1, unless you are working with character level.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        shift_tokens = 1
    )
)
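
The chunk-and-shift operation described above can be sketched on plain lists. This is a minimal illustration for the causal case only, with a hypothetical function name: it assumes zero padding at the start of the sequence and that any leftover features (when the dimension is not divisible) stay unshifted.

```python
def shift_tokens(seq, shifts):
    """seq is a list of feature vectors, one per position.

    The feature dimension is split into (shifts + 1) chunks; chunk c is
    delayed by c positions, so each token sees part of its predecessors.
    """
    dim = len(seq[0])
    segments = shifts + 1
    chunk = dim // segments
    out = [list(token) for token in seq]
    for c in range(1, segments):
        lo, hi = c * chunk, (c + 1) * chunk
        for t in range(len(seq)):
            src = t - c
            for f in range(lo, hi):
                out[t][f] = seq[src][f] if src >= 0 else 0.0
    return out

seq = [[1, 10], [2, 20], [3, 30]]
print(shift_tokens(seq, shifts=1))
```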

If you want finer control over how much is shifted per block (whether attention or feedforward), simply pass in a tuple whose size is equal to the number of blocks (attention and feedforward counted separately).

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        shift_tokens = (1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0) # 12 blocks, attention and feedforward alternating, with progressively less shifting
    )
)

Sandwich Norm

This technique first made an appearance in the CogView paper, a Chinese version of the famous text-to-image transformer DALL-E. They propose, when using pre-layernorm, to add an extra layernorm to all the branch outputs. I have found this to be very effective for a number of projects, when facing instability during training.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_norm = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
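
To see why the extra layernorm helps with instability, here is a minimal sketch of a single sandwich-normed block in plain Python. It illustrates the idea only and is not the library's code: `layer_norm` has no learned gain or bias, and `sandwich_block` and `loud_branch` are hypothetical names.

```python
import math


def layer_norm(v, eps=1e-5):
    """Layernorm over one feature vector, without learned gain or bias."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def sandwich_block(x, branch):
    normed_in = layer_norm(x)             # the usual pre-layernorm
    branch_out = branch(normed_in)        # attention or feedforward
    normed_out = layer_norm(branch_out)   # extra layernorm on the branch output
    return [a + b for a, b in zip(x, normed_out)]


x = [1.0, 2.0, 3.0, 4.0]
loud_branch = lambda v: [1000.0 * a for a in v]  # a branch with exploding outputs

y = sandwich_block(x, loud_branch)
update = [b - a for a, b in zip(x, y)]  # what the block added to the residual
```

However large the branch output becomes, the update added to the residual stream stays at roughly zero mean and unit variance.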

ResiDual

This Microsoft paper proposes yet another normalization configuration, combining both pre and post layernorm. They claim this hybridization reduces representation collapse (known to be an issue with pre-layernorm with increasing depth), while maintaining stability and reducing vanishing gradients (issues with post-layernorm). Initial experiments on my end show it to work no worse than pre-layernorm or sandwich norm. More study needed by the public to see if this is actually a winning technique.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        resi_dual = True,               # set this to True
        resi_dual_scale = 0.1           # in appendix, they said on fp16 the prenorm residual is prone to overflow. they claim by scaling it at each layer by a factor, it would prevent the overflow, and keep results the same (as layernorms are invariant to scaling of the input)
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
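
The scaling argument in the comment above can be checked with a small sketch. It follows one reading of the ResiDual formulation: a post-layernorm stream plus an unnormalized residual that is accumulated across layers and normalized once at the end. The library's internals may differ, and `resi_dual_forward` and the toy branches are hypothetical.

```python
import math


def layer_norm(v, eps=1e-5):
    """Layernorm over one feature vector, without learned gain or bias."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def resi_dual_forward(x, branches, scale=1.0):
    post_stream = list(x)                  # post-layernorm stream
    dual_stream = [scale * a for a in x]   # unnormalized residual, scaled
    for branch in branches:
        out = branch(post_stream)
        post_stream = layer_norm([a + b for a, b in zip(post_stream, out)])
        dual_stream = [a + scale * b for a, b in zip(dual_stream, out)]
    # the accumulated residual is normalized once, at the very end
    return [a + b for a, b in zip(post_stream, layer_norm(dual_stream))]


x = [1.0, -2.0, 0.5, 3.0]
branches = [lambda v: [2.0 * a for a in v], lambda v: [a * a for a in v]]

y_unscaled = resi_dual_forward(x, branches, scale=1.0)
y_scaled = resi_dual_forward(x, branches, scale=0.1)
```

Because the final layernorm is invariant to the scale of its input (up to the epsilon term), the two outputs agree closely, while the scaled residual stays ten times smaller in magnitude.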

Normformer

This paper uncovers an issue with pre-norm transformers where gradients are mismatched between the early and later layers. They propose 4 changes, of which I will be offering 3.

The first change is to offer per head scaling after aggregating the values in attention. My experiments show a slight improvement in convergence.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_head_scale = True  # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

The second change is an extra layernorm right after the activation in the feedforward. I have also verified a slight improvement, at the cost of extra compute.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_post_act_ln = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

For the residual scaling, you simply have to set scale_residual = True. I have noticed slight improvements, but occasional instability as well, so use with caution.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        scale_residual = True # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

The last change is a layernorm right after the outwards projection in attention. This is actually identical to the sandwich norm proposed by the CogView paper, so you can use this by simply setting sandwich_norm = True, although it would also add it to the feedforward layer.

Cosine Sim Attention

This paper proposes to l2 normalize the queries and keys along the head dimension before the dot product (cosine similarity), with the additional change of the scale being learned rather than static. The normalization prevents the attention operation from overflowing, and removes any need for numerical stability measures prior to softmax. Both are perennial problems when training transformers.

This was validated at scale recently by the training of a 3B parameter vision transformer. The SwinV2 paper also proposes to change the pre-layernorm to a post-layernorm for further stability.

I have validated that this works just as well as dot product attention in an autoregressive setting, if one were to initialize the temperature as proposed in the QK-norm paper (as a function of the sequence length).

This flavor of attention also has a connection to sparse distributed memory. [youtube talk]

Update: I have discovered a way to remove the learned temperature altogether, by grouping the feature dimension and doing l2-normalization on each group. This allows the queries and keys to have a similarity that is upper bounded by the number of groups. A group size of 8 or 16 was sufficient in my tests. Decided to name this technique "Grouped QK Normalization". The drawback is that I believe an attention head dimension 32 is too small to use this tactic (a dimension often used in vision)

You can use it as follows

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,       # set this to True
        attn_qk_norm_groups = 8    # number of groups in the feature dimension for l2norm, similarity scores will be bounded between [-group, group]. determines how sharp the attention can be
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
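
The bound of [-group, group] mentioned in the comment above can be verified with a short sketch. This is an illustration in plain Python, not the library's implementation: `l2_normalize` and `grouped_qk_similarity` are hypothetical helpers, and the head dimension of 64 is an arbitrary choice.

```python
import math


def l2_normalize(v, eps=1e-12):
    norm = math.sqrt(sum(a * a for a in v))
    return [a / max(norm, eps) for a in v]


def grouped_qk_similarity(q, k, groups=8):
    """Sum of per-group cosine similarities between one query and one key."""
    dim = len(q)
    assert dim % groups == 0, "sketch assumes dim divides evenly into groups"
    size = dim // groups
    sim = 0.0
    for g in range(groups):
        q_group = l2_normalize(q[g * size:(g + 1) * size])
        k_group = l2_normalize(k[g * size:(g + 1) * size])
        sim += sum(a * b for a, b in zip(q_group, k_group))
    return sim


q = [float(i + 1) for i in range(64)]  # one query with head dimension 64
k_opposite = [-a for a in q]           # a key pointing the opposite way

best = grouped_qk_similarity(q, q, groups=8)
worst = grouped_qk_similarity(q, k_opposite, groups=8)
```

Each group contributes a cosine similarity in [-1, 1], so more groups widen the range of logits going into the softmax and allow sharper attention without a learned temperature.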

Another update: Simply scaling the cosine similarity (group of 1) with a fixed constant (10) may work too

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,       # set to True
        attn_qk_norm_scale = 10    # new scale on the similarity, with groups of 1
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

QK RMSNorm

Update: Google Brain has proven out something similar to cosine sim attention in a 22B parameter model. In their papers, they have analysis showing that the normalization resulted in not only extra stability, but also better results in the end (due to less need to adjust learning rate when increasing parameter count).

We are nearing the point of wiping out a source of transformer training instability with one simple intervention, in my opinion. The only slight difference in the paper is that they still have a learned scale across the feature dimension (per use of rmsnorm). Not sure how critical this is, but just to make sure we don't miss anything, I will include this here. You can use this by setting attn_qk_norm_dim_scale = True on the attention layers, in addition to attn_qk_norm = True.

Update: Counterpoint from Tim Dettmers

Update 2: Counter to Tim's assertion that outliers are needed, and potentially even some solutions

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_qk_norm = True,
        attn_qk_norm_dim_scale = True # set this to True, in addition to `attn_qk_norm = True`
    )
)

x = torch.randint(0, 256, (1, 1024))
model(x)

Turning off absolute positional embedding

A number of papers have hinted that causal transformers (Decoder) can learn absolute positions in the absence of added embeddings of any sort. This was recently investigated thoroughly by Haviv et al. (2022), listed in the citations. You can turn off the absolute positional embedding by setting use_abs_pos_emb = False in the TransformerWrapper.

Given PaLM, the trend going forward may be to forgo absolute positional embedding (again, for causal transformers only), and add relative positional embeddings with RoPE, ALiBi, etc.

Update: Kazemnejad et al. (2023), listed in the citations, show that in the absence of any engineered absolute or relative positional embeddings, decoders can generate implicit positions, and even length generalize better than solutions of the past. They were unaware of dynamic positional bias, however.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_abs_pos_emb = False,   # set this to False
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)

Forgetful Causal Mask

This paper shows convincing results that one can combine masking (from masked language modeling) with autoregressive training, leading to significantly better results.

You can use this by setting the mask_prob on the AutoregressiveWrapper class

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

model = AutoregressiveWrapper(
    model,
    mask_prob = 0.15  # in paper, they use 15%, same as BERT
).cuda()

# mock data

x = torch.randint(0, 20000, (1, 1024)).cuda()

# derive cross entropy loss, masking all taken care of

loss = model(x)
loss.backward()
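
For intuition about what the wrapper is doing, here is a minimal sketch of one simple way to build such a mask in plain Python. It is an assumption-laden illustration, not the paper's or the library's exact procedure (which may select masked positions differently, for example per batch element): `forgetful_causal_mask` is a hypothetical helper, and keeping position 0 always visible is a choice made here so every query has at least one key to attend to.

```python
import random


def forgetful_causal_mask(seq_len, mask_prob=0.15, seed=None):
    """Return a seq_len x seq_len boolean mask, True where attention is allowed."""
    rng = random.Random(seed)
    # pick past key positions to "forget"; position 0 is always kept
    forgotten = {j for j in range(1, seq_len) if rng.random() < mask_prob}
    return [
        [(j <= i) and (j not in forgotten) for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = forgetful_causal_mask(8, mask_prob=0.5, seed=0)
plain_causal = forgetful_causal_mask(4, mask_prob=0.0)
```

With mask_prob = 0 the result is the ordinary lower-triangular causal mask. Raising it removes a random subset of key positions on top of the causal constraint.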

Miscellaneous

Cross Attention

import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim = 512, depth = 6)
model = CrossAttender(dim = 512, depth = 6)

nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask = neighbor_masks)
model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neighbor_masks) # (1, 1, 512)

Pass in continuous values

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randn((1, 1024, 32))

model(x) # (1, 1024, 100)

You can also train a transformer that accepts continuous values autoregressively easily, in the same scheme as done successfully in this paper

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder
from x_transformers import ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 777,
    dim_out = 777,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# wrap it with the continuous autoregressive wrapper

model = ContinuousAutoregressiveWrapper(model)

# mock data

x = torch.randn((1, 1024, 777))
mask = torch.ones(1, 1024).bool()

# train on a lot of data above

loss = model(x, mask = mask)
loss.backward()

# then generate

start_emb = torch.randn(1, 777)
generated = model.generate(start_emb, 17) # (17, 777)

Citations

@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@misc{burtsev2020memory,
    title   = {Memory Transformer}, 
    author  = {Mikhail S. Burtsev and Grigory V. Sapunov},
    year    = {2020},
    eprint  = {2006.11527},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{zhao2019explicit,
    title   = {Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection}, 
    author  = {Guangxiang Zhao and Junyang Lin and Zhiyuan Zhang and Xuancheng Ren and Qi Su and Xu Sun},
    year    = {2019},
    eprint  = {1912.11637},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{correia2019adaptively,
    title   = {Adaptively Sparse Transformers},
    author  = {Gonçalo M. Correia and Vlad Niculae and André F. T. Martins},
    year    = {2019},
    eprint  = {1909.00015},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{press2020improving,
    title   = {Improving Transformer Models by Reordering their Sublayers}, 
    author  = {Ofir Press and Noah A. Smith and Omer Levy},
    year    = {2020},
    eprint  = {1911.03864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{lu2019understanding,
    title   = {Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View}, 
    author  = {Yiping Lu and Zhuohan Li and Di He and Zhiqing Sun and Bin Dong and Tao Qin and Liwei Wang and Tie-Yan Liu},
    year    = {2019},
    eprint  = {1906.02762},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{ke2020rethinking,
    title     = {Rethinking Positional Encoding in Language Pre-training},
    author    = {Guolin Ke and Di He and Tie-Yan Liu},
    year      = {2020},
    eprint    = {2006.15595},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{huang2019attention,
    title   = {Attention on Attention for Image Captioning},
    author  = {Lun Huang and Wenmin Wang and Jie Chen and Xiao-Yong Wei},
    year    = {2019},
    eprint  = {1908.06954},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{raffel2020exploring,
    title   = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}, 
    author  = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
    year    = {2020},
    eprint  = {1910.10683},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{martins-etal-2020-sparse,
    title   = "Sparse Text Generation",
    author  = "Martins, Pedro Henrique  and
        Marinho, Zita  and
        Martins, Andr{\'e} F. T.",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    month   = nov,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.emnlp-main.348"
}
@misc{he2020realformer,
    title   = {RealFormer: Transformer Likes Residual Attention},
    author  = {Ruining He and Anirudh Ravula and Bhargav Kanagal and Joshua Ainslie},
    year    = {2020},
    eprint  = {2012.11747},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{carion2020endtoend,
    title   = {End-to-End Object Detection with Transformers},
    author  = {Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
    year    = {2020},
    eprint  = {2005.12872},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
@misc{parisotto2019stabilizing,
    title     = {Stabilizing Transformers for Reinforcement Learning},
    author    = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year      = {2019},
    eprint    = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{narang2021transformer,
    title       = {Do Transformer Modifications Transfer Across Implementations and Applications?},
    author      = {Sharan Narang and Hyung Won Chung and Yi Tay and William Fedus and Thibault Fevry and Michael Matena and Karishma Malkan and Noah Fiedel and Noam Shazeer and Zhenzhong Lan and Yanqi Zhou and Wei Li and Nan Ding and Jake Marcus and Adam Roberts and Colin Raffel},
    year        = {2021},
    eprint      = {2102.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{zhang2019root,
    title   = {Root Mean Square Layer Normalization},
    author  = {Biao Zhang and Rico Sennrich},
    year    = {2019},
    eprint  = {1910.07467},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Qin2023ScalingTT,
    title   = {Scaling TransNormer to 175 Billion Parameters},
    author  = {Zhen Qin and Dong Li and Weigao Sun and Weixuan Sun and Xuyang Shen and Xiaodong Han and Yunshen Wei and Baohong Lv and Fei Yuan and Xiao Luo and Y. Qiao and Yiran Zhong},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:260203124}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Chen2023ExtendingCW,
    title   = {Extending Context Window of Large Language Models via Positional Interpolation},
    author  = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian},
    year    = {2023}
}
@inproceedings{Sun2022ALT,
  title     = {A Length-Extrapolatable Transformer},
  author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
  year      = {2022}
}
@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{csordás2021devil,
    title   = {The Devil is in the Detail: Simple Tricks Improve Systematic Generalization of Transformers},
    author  = {Róbert Csordás and Kazuki Irie and Jürgen Schmidhuber},
    year    = {2021},
    eprint  = {2108.12284},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{so2021primer,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling}, 
    author  = {David R. So and Wojciech Mańke and Hanxiao Liu and Zihang Dai and Noam Shazeer and Quoc V. Le},
    year    = {2021},
    eprint  = {2109.08668},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer}, 
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{henry2020querykey,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
    year    = {2020},
    eprint  = {2010.04245},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Haviv2022TransformerLM,
    title   = {Transformer Language Models without Positional Encodings Still Learn Positional Information},
    author  = {Adi Haviv and Ori Ram and Ofir Press and Peter Izsak and Omer Levy},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.16634}
}
@article{chowdhery2022PaLM,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Chowdhery, Aakanksha et al},
    year    = {2022}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam M. Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@article{Wang2022DeepNetST,
    title   = {DeepNet: Scaling Transformers to 1,000 Layers},
    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.00555}
}
@misc{schlag2020enhancing,
    title   = {Enhancing the Transformer with explicit relational encoding for math problem solving},
    author  = {Imanol Schlag and Paul Smolensky and Roland Fernandez and Nebojsa Jojic and J{\"u}rgen Schmidhuber and Jianfeng Gao},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1xfElrKPr}
}
@article{Liu2022FCMFC,
    title   = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13432}
}
@inproceedings{Huang2016DeepNW,
    title   = {Deep Networks with Stochastic Depth},
    author  = {Gao Huang and Yu Sun and Zhuang Liu and Daniel Sedra and Kilian Q. Weinberger},
    booktitle = {European Conference on Computer Vision},
    year    = {2016}
}
@inproceedings{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    booktitle = {International Conference on Machine Learning},
    year    = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Xie2023ResiDualTW,
  title     = {ResiDual: Transformer with Dual Residual Connections},
  author    = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
  journal   = {ArXiv},
  year      = {2023},
  volume    = {abs/2304.14802}
}
@inproceedings{Dehghani2023ScalingVT,
    title   = {Scaling Vision Transformers to 22 Billion Parameters},
    author  = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Pavetić and Dustin Tran and Thomas Kipf and Mario Lučić and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
    year    = {2023}
}
@article{Beyer2022BetterPV,
    title   = {Better plain ViT baselines for ImageNet-1k},
    author  = {Lucas Beyer and Xiaohua Zhai and Alexander Kolesnikov},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01580}
}
@article{Liu2023EfficientViTME,
    title   = {EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention},
    author  = {Xinyu Liu and Houwen Peng and Ningxin Zheng and Yuqing Yang and Han Hu and Yixuan Yuan},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.07027}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@misc{bloc97-2023,
    title   = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.},
    author  = {/u/bloc97},
    url     = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/}
}

solve intelligence... then use that to solve everything else. - Demis Hassabis

More Repositories

1

vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Python
13,633
star
2

DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
Python
10,770
star
3

imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
Python
7,675
star
4

PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
Python
7,559
star
5

DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
Python
5,132
star
6

deep-daze

Simple command line tool for text to image generation using OpenAI's CLIP and Siren (Implicit neural representation network). Technique was originally created by https://twitter.com/advadnoun
Python
4,387
star
7

denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
Python
3,959
star
8

stylegan2-pytorch

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
Python
3,433
star
9

musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch
Python
2,934
star
10

big-sleep

A simple command line tool for text to image generation, using OpenAI's CLIP and a BigGAN. Technique was originally created by https://twitter.com/advadnoun
Python
2,446
star
11

audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
Python
2,179
star
12

reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
Python
1,870
star
13

lion-pytorch

๐Ÿฆ Lion, new optimizer discovered by Google Brain using genetic algorithms that is purportedly better than Adam(w), in Pytorch
Python
1,859
star
14

toolformer-pytorch

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
Python
1,846
star
15

make-a-video-pytorch

Implementation of Make-A-Video, new SOTA text to video generator from Meta AI, in Pytorch
Python
1,807
star
16

gigagan-pytorch

Implementation of GigaGAN, new SOTA GAN out of Adobe. Culmination of nearly a decade of research into GANs
Python
1,542
star
17

lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two
Python
1,526
star
18

lambda-networks

Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute
Python
1,516
star
19

byol-pytorch

Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch
Python
1,497
star
20

alphafold2

To eventually become an unofficial Pytorch implementation / replication of Alphafold2, as details of the architecture get released
Python
1,476
star
21

self-rewarding-lm-pytorch

Implementation of the training framework proposed in Self-Rewarding Language Model, from MetaAI
Python
1,154
star
22

naturalspeech2-pytorch

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch
Python
1,141
star
23

flamingo-pytorch

Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch
Python
1,108
star
24

soundstorm-pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch
Python
1,091
star
25

video-diffusion-pytorch

Implementation of Video Diffusion Models, Jonathan Ho's new paper extending DDPMs to Video Generation - in Pytorch
Python
1,072
star
26

CoCa-pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch
Python
945
star
27

performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
Python
937
star
28

perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
Python
935
star
29

mlp-mixer-pytorch

An All-MLP solution for Vision, from Google AI
Python
833
star
30

RETRO-pytorch

Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch
Python
813
star
31

vector-quantize-pytorch

Vector Quantization, in Pytorch
Python
810
star
32

PaLM-pytorch

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways
Python
804
star
33

muse-maskgit-pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
Python
781
star
34

phenaki-pytorch

Implementation of Phenaki Video, which uses Mask GIT to produce text guided videos of up to 2 minutes in length, in Pytorch
Python
694
star
35

x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers
Python
635
star
36

bottleneck-transformer-pytorch

Implementation of Bottleneck Transformer in Pytorch
Python
632
star
37

TimeSformer-pytorch

Implementation of TimeSformer from Facebook AI, a pure attention-based solution for video classification
Python
613
star
38

memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
Python
596
star
39

MEGABYTE-pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
Python
573
star
40

nuwa-pytorch

Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch
Python
531
star
41

point-transformer-pytorch

Implementation of the Point Transformer layer, in Pytorch
Python
518
star
42

parti-pytorch

Implementation of Parti, Google's pure attention-based text-to-image neural network, in Pytorch
Python
502
star
43

tab-transformer-pytorch

Implementation of TabTransformer, attention network for tabular data, in Pytorch
Python
485
star
44

voicebox-pytorch

Implementation of Voicebox, new SOTA Text-to-speech network from MetaAI, in Pytorch
Python
470
star
45

linear-attention-transformer

Transformer based on a variant of attention that has linear complexity with respect to sequence length
Python
468
star
46

meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
Python
430
star
47

g-mlp-pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Python
391
star
48

siren-pytorch

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions
Python
377
star
49

recurrent-memory-transformer-pytorch

Implementation of Recurrent Memory Transformer, Neurips 2022 paper, in Pytorch
Python
371
star
50

egnn-pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch
Python
367
star
51

ema-pytorch

A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
Python
356
star
52

enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
Python
352
star
53

magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
Python
346
star
54

memory-efficient-attention-pytorch

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"
Python
328
star
55

FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
Python
323
star
56

robotic-transformer-pytorch

Implementation of RT1 (Robotic Transformer) in Pytorch
Python
320
star
57

medical-chatgpt

Implementation of ChatGPT, but tailored towards primary care medicine, with the reward being able to collect patient histories in a thorough and efficient manner and come up with a reasonable differential diagnosis
Python
309
star
58

bit-diffusion

Implementation of Bit Diffusion, Hinton's group's attempt at discrete denoising diffusion, in Pytorch
Python
308
star
59

slot-attention

Implementation of Slot Attention from GoogleAI
Python
303
star
60

iTransformer

Unofficial implementation of iTransformer - SOTA Time Series Forecasting using Attention networks, out of Tsinghua / Ant group
Python
300
star
61

transformer-in-transformer

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Python
277
star
62

axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
Python
273
star
63

conformer

Implementation of the convolutional module from the Conformer paper, for use in Transformers
Python
272
star
64

q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
Python
266
star
65

mixture-of-experts

A Pytorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models
Python
264
star
66

magic3d-pytorch

Implementation of Magic3D, Text to 3D content synthesis, in Pytorch
Python
258
star
67

routing-transformer

Fully featured implementation of Routing Transformer
Python
251
star
68

classifier-free-guidance-pytorch

Implementation of Classifier Free Guidance in Pytorch, with emphasis on text conditioning, and flexibility to include multiple text embedding models
Python
248
star
69

Adan-pytorch

Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch
Python
241
star
70

x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings
Python
241
star
71

deformable-attention

Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"
Python
237
star
72

segformer-pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch
Python
227
star
73

perfusion-pytorch

Implementation of Key-Locked Rank One Editing, from Nvidia AI
Python
224
star
74

sinkhorn-transformer

Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention
Python
222
star
75

equiformer-pytorch

Implementation of the Equiformer, SE3/E3 equivariant attention network that reaches new SOTA, and adopted for use by EquiFold for protein folding
Python
220
star
76

pixel-level-contrastive-learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in Pytorch
Python
220
star
77

spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
Python
220
star
78

ring-attention-pytorch

Explorations into Ring Attention, from Liu et al. at Berkeley AI
Python
218
star
79

local-attention

An implementation of local windowed attention for language modeling
Python
216
star
80

natural-speech-pytorch

Implementation of the neural network proposed in Natural Speech, a text-to-speech generator that is indistinguishable from human recordings for the first time, from Microsoft Research
Python
215
star
81

BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
Python
213
star
82

CoLT5-attention

Implementation of the conditionally routed attention in the CoLT5 architecture, in Pytorch
Python
212
star
83

se3-transformer-pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.
Python
211
star
84

block-recurrent-transformer-pytorch

Implementation of Block Recurrent Transformer - Pytorch
Python
198
star
85

Mega-pytorch

Implementation of Mega, the Single-head Attention with Multi-headed EMA architecture that currently holds SOTA on Long Range Arena
Python
198
star
86

triton-transformer

Implementation of a Transformer, but completely in Triton
Python
195
star
87

jax2torch

Use Jax functions in Pytorch
Python
194
star
88

halonet-pytorch

Implementation of the 😇 Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones
Python
193
star
89

st-moe-pytorch

Implementation of ST-Moe, the latest incarnation of MoE after years of research at Brain, in Pytorch
Python
190
star
90

flash-cosine-sim-attention

Implementation of fused cosine similarity attention in the same style as Flash Attention
Cuda
190
star
91

attention

This repository will house a visualization that will attempt to convey instant enlightenment of how Attention works to someone not working in artificial intelligence, with 3Blue1Brown as inspiration
HTML
189
star
92

simple-hierarchical-transformer

Experiments around a simple idea for inducing multiple hierarchical predictive models within a GPT
Python
189
star
93

med-seg-diff-pytorch

Implementation of MedSegDiff in Pytorch - SOTA medical segmentation using DDPM and filtering of features in fourier space
Python
187
star
94

electra-pytorch

A simple and working implementation of Electra, the fastest way to pretrain language models from scratch, in Pytorch
Python
186
star
95

recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch
Python
185
star
96

unet-stylegan2

A Pytorch implementation of Stylegan2 with UNet Discriminator
Python
182
star
97

res-mlp-pytorch

Implementation of ResMLP, an all MLP solution to image classification, in Pytorch
Python
181
star
98

PaLM-jax

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax (Equinox framework)
Python
180
star
99

glom-pytorch

An attempt at the implementation of Glom, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns), for emergent part-whole hierarchies from data
Python
178
star
100

soft-moe-pytorch

Implementation of Soft MoE, proposed by Brain's Vision team, in Pytorch
Python
174
star