Flash Cosine Similarity Attention

Implementation of fused cosine similarity attention in the same style as Flash Attention. The observation is that by adopting l2 normalized queries and keys, you no longer need to keep track of the row maximums for numerical stability. This greatly simplifies the flash attention algorithm, assuming cosine similarity attention comes at no generalization cost.

In other words, stable, fast, memory efficient, and longer context attention with no downsides.
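
For intuition, here is a minimal, non-fused PyTorch sketch of the math described above, assuming the fixed scale of 10 discussed further down (the repository itself implements this as a fused CUDA kernel):

import torch
import torch.nn.functional as F

# non-fused sketch of cosine similarity attention. because q and k are
# l2-normalized, every logit lies in [-scale, scale], so the softmax can be
# computed without tracking running row maximums for numerical stability
def naive_cosine_sim_attention(q, k, v, scale = 10):
    q, k = map(lambda t: F.normalize(t, p = 2, dim = -1), (q, k))
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out = naive_cosine_sim_attention(q, k, v)  # (1, 8, 1024, 64)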

Update: Unfortunately, Robin's experiments showed much worse evaluation FID scores that were not reflected in the loss. More experiments are pending. Use this library with caution.

Update 2: The only saving grace would be to use grouped l2norm, which could potentially allow for more expressivity. If anyone can evaluate this technique on their generative work and obtain some FID scores, it would be much appreciated.
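
To make the idea concrete, below is a hypothetical sketch of what a grouped l2norm could look like in plain PyTorch (the function name and group count are illustrative, not this repository's API): the head dimension is split into groups and each group is l2-normalized separately, which leaves the network more expressivity than a single full-dimension l2norm.

import torch
import torch.nn.functional as F

# hypothetical grouped l2norm - split the head dimension into groups and
# l2-normalize each group independently
def grouped_l2norm(t, groups = 4):
    *lead_dims, dim = t.shape
    t = t.reshape(*lead_dims, groups, dim // groups)
    t = F.normalize(t, p = 2, dim = -1)
    return t.reshape(*lead_dims, dim)

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

q, k = grouped_l2norm(q), grouped_l2norm(k)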

Update 3: An approach similar to cosine sim attention has been proven at scale, with a 22B parameter vision model from Brain.

Status (wip)

At the moment, autoregressive and variable-length sequences should be faster across all architectures. For sequences longer than 2048, it will also be memory efficient where regular attention would not be.

However, for the non-autoregressive case without masking, the kernel is still slower on the A100 for F16. The aim is to get it to perform faster on the A100, forwards and backwards, for both F32 and F16, as shared memory is not fully exploited yet.

On older graphics cards without enough shared memory, one will have to gauge the tradeoff between memory efficiency and speed, depending on the sequence length being trained at.

Appreciation

  • Arthur Hennequin for coaching me through my first CUDA kernel, and for coding up a simple reference implementation, which helped me to bootstrap the first kernel that comes within reasonable performance of the baseline. This work would not have been possible without his expertise.

  • Boris Dayma and Robin Rombach for running experiments with the simplified cosine sim attention with fixed scaling on some significant text-to-image models, and for verifying that it indeed performs just as well as regular attention.

  • Markus Rabe for penning the paper that showed attention does not require O(n²) memory, and Tri Dao for putting it all together in a CUDA kernel implementation for regular attention, demonstrating superiority in speed using the tiled approach that minimizes HBM accesses (and for figuring out dO * O == dP * P for the backwards pass; a quick numerical check of this identity is sketched after this list). I would not have been able to complete my pilgrimage looking for the ultimate attention formulation without their discoveries.

  • Stability.ai for the generous sponsorship to work on cutting-edge artificial intelligence research.
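
As an aside, here is a small plain-PyTorch numerical check (a sketch for illustration, not code from this repository) of the identity mentioned above: the rowsums of dO * O and dP * P agree, where P is the attention matrix and O = P @ V.

import torch

# P = softmax(Q K^T), O = P V. check that rowsum(dO * O) == rowsum(dP * P)
q, k, v = (torch.randn(16, 64) for _ in range(3))

p = (q @ k.t()).softmax(dim = -1)   # attention matrix P
o = p @ v                           # output O

do = torch.randn_like(o)            # upstream gradient dO
dp = do @ v.t()                     # gradient w.r.t. P

lhs = (do * o).sum(dim = -1)
rhs = (dp * p).sum(dim = -1)

assert torch.allclose(lhs, rhs, atol = 1e-4)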

Install

$ pip install flash-cosine-sim-attention

Usage

Self Attention

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 1024, 64).cuda()
v = torch.randn(1, 8, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v)  # (1, 8, 1024, 64)

Cross attention

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v) # (1, 8, 1024, 64)

With key / value masking

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

mask = torch.ones(1, 2048).bool().cuda()

out = flash_cosine_sim_attention(q, k, v, mask = mask) # (1, 8, 1024, 64)

Autoregressive

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 8, 1024, 64).cuda()
v = torch.randn(4, 8, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True)  # (4, 8, 1024, 64)

Miscellaneous

Single-headed key / values (Shazeer et al., as used in PaLM)

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True)  # (4, 8, 1024, 64)
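
For illustration, here is a non-fused sketch of what single-headed key / values amount to, assuming the same fixed scale of 10 (hypothetical code, not this repository's kernel): the single key / value head is simply broadcast across all query heads, multi-query style.

import torch
import torch.nn.functional as F

# one shared key / value head broadcast across all query heads
q = torch.randn(4, 8, 1024, 64)   # per-head queries
k = torch.randn(4, 1024, 64)      # single shared key head
v = torch.randn(4, 1024, 64)      # single shared value head

q, k = map(lambda t: F.normalize(t, p = 2, dim = -1), (q, k))

sim = torch.einsum('b h i d, b j d -> b h i j', q, k) * 10
attn = sim.softmax(dim = -1)
out = torch.einsum('b h i j, b j d -> b h i d', attn, v)  # (4, 8, 1024, 64)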

If you need to do operations on the queries and keys in between the l2norm and the actual attention step, just set l2norm_qk = False. For example:

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()

q, k = l2norm_tensors(q, k)

# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch

out = flash_cosine_sim_attention(q, k, v, l2norm_qk = False)  # (4, 8, 1024, 64)
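
As a concrete example of such an in-between operation, below is a sketch using rotary embeddings via the rotary-embedding-torch package mentioned in the comment above (assuming its RotaryEmbedding / rotate_queries_or_keys interface). Rotations preserve l2 norms, so applying them after the l2norm is safe.

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors
from rotary_embedding_torch import RotaryEmbedding

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 8, 1024, 64).cuda()
v = torch.randn(4, 8, 1024, 64).cuda()

q, k = l2norm_tensors(q, k)

# rotate queries and keys after the l2norm - rotations do not change norms
rotary_emb = RotaryEmbedding(dim = 32).cuda()
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

out = flash_cosine_sim_attention(q, k, v, l2norm_qk = False)  # (4, 8, 1024, 64)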

Cross attention with causal masking works as expected (for caching of keys and values during autoregressive inference, or Transformer-XL-like training)

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True) # (1, 8, 1024, 64)
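
A hypothetical sketch of how key / value caching during autoregressive inference might look with this function (assuming, as described above, that the causal mask aligns the end of the queries with the end of the keys):

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

# cached keys / values from all previously processed tokens
cached_k = torch.randn(1, 8, 1024, 64).cuda()
cached_v = torch.randn(1, 8, 1024, 64).cuda()

# projections for the single new token
q_new = torch.randn(1, 8, 1, 64).cuda()
k_new = torch.randn(1, 8, 1, 64).cuda()
v_new = torch.randn(1, 8, 1, 64).cuda()

# append to the cache, then attend with only the new token's query
k = torch.cat((cached_k, k_new), dim = -2)
v = torch.cat((cached_v, v_new), dim = -2)

out = flash_cosine_sim_attention(q_new, k, v, causal = True)  # (1, 8, 1, 64)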

If you have the batch and head dimensions merged, that is ok.

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(32, 1024, 64).cuda()
k = torch.randn(32, 2048, 64).cuda()
v = torch.randn(32, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True) # (32, 1024, 64)
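
If you prefer to keep the head dimension around, merging and splitting on the fly is straightforward; a small sketch, assuming einops is installed:

import torch
from einops import rearrange
from flash_cosine_sim_attention import flash_cosine_sim_attention

b, h, n, d = 2, 8, 1024, 64

q = torch.randn(b, h, n, d).cuda()
k = torch.randn(b, h, n, d).cuda()
v = torch.randn(b, h, n, d).cuda()

# fold heads into the batch dimension before the call, unfold afterwards
q, k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (q, k, v))

out = flash_cosine_sim_attention(q, k, v, causal = True)  # (16, 1024, 64)
out = rearrange(out, '(b h) n d -> b h n d', h = h)       # (2, 8, 1024, 64)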

Supported head dimensions

  • 16 - f32

  • 32

  • 64

  • 96

  • 128

  • 16 - f16

  • 80 - in progress

Todo

  • bfloat16 support, use sfinae as recommended by Arthur

  • stream from qk_mma to shared memory in chunks to calculate out mma, see if freed smem can be used for caching more

  • support O(n) 1d dynamic positional bias

  • figure out why smem fragment caching would lead to performance degradation, it does not make sense

  • think about use of logsumexp - works, but the extra log leads to degraded perf

  • prepare a smem fragment caching mechanism, to allow for as much caching as allowed on A100 (or f16)

  • make attention tile size processing customizable for backwards pass

  • move atomic add to overloaded function inside mma

  • make it flexible which type is used for accumulation

  • test out 64x96 tiles on f16

  • bring in a CPU memory efficient version (only for inference, as training does not make sense) using just plain pytorch code

  • figure out how to dispatch differently for architectures (say A100), in case backwards can make use of the increase in shared memory differently

  • decouple row and column sizes for attention tiles

  • dk and dv are now in f16 when possible (non-single-headed kv)

  • support more standard head dimensions (wip)

  • debug and fix bias backwards gradients yet again for head size of 32

  • fix attention bias gradients

  • allow for single-headed key / values, as in PaLM

  • fix atomic add for f16

  • attention bias should be able to accept an extra batch dimension, for Alphafold2-like attention biasing

  • automate cache-busting of kernel using version as suffix to package name

  • resolve f16 causal numerical issues

  • adopt all learnings from forward kernel to backwards kernel and make sure it outperforms at least on A100

Description

Cosine similarity attention is not widely used in industry; the only large model trained with it so far is SwinV2. If anyone can invalidate the approach, please open an issue or send me an email. You can run experiments against regular attention using the x-transformers repository.

Update: Boris Dayma has graciously kicked off an experiment (blue with red as baseline) to validate cosine similarity attention with a fixed scale of 10 in a real-world model setting. 🙏

Update 2: Cosine similarity attention has been proven out in a real-world text-to-image attention network, using a constant scale of 10. No worse than regular attention. Credit goes to Boris Dayma for investing the time to run the experiment and removing doubts surrounding the technique.

Update 3: Robin Rombach has tested out the kernel in this repository with head size of 64 and fixed scale of 10 in a text-to-image model, observing no difference from regular attention. More evaluations pending.

Update 4: The improvement in performance seen in Boris' experiments is likely due to the fact that cosine sim attention allows one to switch from a pre-layernorm to a post-layernorm configuration in the transformer (the l2norm effectively takes the place of the pre-layernorm). Cosine sim attention will likely yield results on par with regular attention, without any other changes to the transformer.
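
To illustrate the point about layernorm placement, here is a hypothetical post-layernorm attention block built around this repository's function (a sketch under the assumptions above, not a prescribed architecture). Note that no pre-layernorm is applied to the attention input, since the q / k l2norm inside the attention op plays that role.

import torch
from torch import nn
from flash_cosine_sim_attention import flash_cosine_sim_attention

class PostNormCosineSimAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        self.norm = nn.LayerNorm(dim)  # applied after the residual (post-layernorm)

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, -1).transpose(1, 2).contiguous(), (q, k, v))

        out = flash_cosine_sim_attention(q, k, v, causal = True)

        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.norm(x + self.to_out(out))  # residual, then layernorm

block = PostNormCosineSimAttention(dim = 512).cuda()
x = torch.randn(1, 1024, 512).cuda()
out = block(x)  # (1, 1024, 512)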

Testing

To test that the output and gradients are equal for non-autoregressive and autoregressive scenarios:

$ python setup.py test

Benchmarking

Make sure to first install the CUDA kernel:

$ python setup.py install

Then

$ python benchmark.py

To benchmark only forwards or backwards, append either the --only-forwards or --only-backwards flag to the above. To benchmark the autoregressive case, append --causal.

Benchmarks - wip

RTX 2080 Ti

Forward

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 1.05x   kernel: 0.24ms  baseline: 0.23ms
seq_len: 256    slower: 1.27x   kernel: 0.38ms  baseline: 0.30ms
seq_len: 512    slower: 1.28x   kernel: 0.87ms  baseline: 0.68ms
seq_len: 1024   slower: 1.15x   kernel: 2.63ms  baseline: 2.28ms
seq_len: 2048   slower: 0.99x   kernel: 7.99ms  baseline: 8.10ms
seq_len: 4096   slower: 0.88x   kernel: 30.82ms baseline: 34.84ms
seq_len: 8192   slower: 0.00x   kernel: 121.96ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.85x   kernel: 0.20ms  baseline: 0.24ms
seq_len: 256    slower: 0.97x   kernel: 0.24ms  baseline: 0.25ms
seq_len: 512    slower: 1.22x   kernel: 0.43ms  baseline: 0.35ms
seq_len: 1024   slower: 0.95x   kernel: 0.93ms  baseline: 0.98ms
seq_len: 2048   slower: 0.90x   kernel: 3.16ms  baseline: 3.50ms
seq_len: 4096   slower: 0.85x   kernel: 11.06ms baseline: 13.07ms
seq_len: 8192   slower: 0.00x   kernel: 42.61ms baseline: oom

Backwards - still needs work

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 1.07x   kernel: 0.61ms  baseline: 0.57ms
seq_len: 256    slower: 1.40x   kernel: 0.91ms  baseline: 0.65ms
seq_len: 512    slower: 1.70x   kernel: 2.34ms  baseline: 1.38ms
seq_len: 1024   slower: 1.26x   kernel: 5.67ms  baseline: 4.50ms
seq_len: 2048   slower: 1.29x   kernel: 20.60ms baseline: 15.91ms
seq_len: 4096   slower: 1.30x   kernel: 78.93ms baseline: 60.81ms
seq_len: 8192   slower: 0.00x   kernel: 314.51ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.91x   kernel: 0.50ms  baseline: 0.55ms
seq_len: 256    slower: 1.06x   kernel: 0.58ms  baseline: 0.55ms
seq_len: 512    slower: 1.13x   kernel: 0.81ms  baseline: 0.72ms
seq_len: 1024   slower: 0.97x   kernel: 2.09ms  baseline: 2.16ms
seq_len: 2048   slower: 0.96x   kernel: 7.06ms  baseline: 7.35ms
seq_len: 4096   slower: 0.97x   kernel: 26.08ms baseline: 26.84ms
seq_len: 8192   slower: 0.00x   kernel: 101.02ms    baseline: oom

Forward & Backwards - F32 is definitely slower

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 1.05x   kernel: 0.83ms  baseline: 0.79ms
seq_len: 256    slower: 1.34x   kernel: 1.26ms  baseline: 0.95ms
seq_len: 512    slower: 1.44x   kernel: 3.14ms  baseline: 2.18ms
seq_len: 1024   slower: 1.15x   kernel: 7.83ms  baseline: 6.81ms
seq_len: 2048   slower: 1.20x   kernel: 28.83ms baseline: 24.03ms
seq_len: 4096   slower: 1.20x   kernel: 111.13ms    baseline: 92.51ms
seq_len: 8192   slower: 0.00x   kernel: 441.70ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.89x   kernel: 0.68ms  baseline: 0.77ms
seq_len: 256    slower: 1.03x   kernel: 0.80ms  baseline: 0.77ms
seq_len: 512    slower: 1.06x   kernel: 1.16ms  baseline: 1.10ms
seq_len: 1024   slower: 0.93x   kernel: 2.94ms  baseline: 3.16ms
seq_len: 2048   slower: 0.93x   kernel: 10.06ms baseline: 10.87ms
seq_len: 4096   slower: 0.93x   kernel: 37.09ms baseline: 39.96ms
seq_len: 8192   slower: 0.00x   kernel: 143.13ms    baseline: oom

For autoregressive, a clear win. Run python benchmark.py --causal

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.97x   kernel: 0.81ms  baseline: 0.84ms
seq_len: 256    slower: 1.07x   kernel: 1.12ms  baseline: 1.05ms
seq_len: 512    slower: 0.83x   kernel: 2.23ms  baseline: 2.68ms
seq_len: 1024   slower: 0.55x   kernel: 4.83ms  baseline: 8.82ms
seq_len: 2048   slower: 0.49x   kernel: 15.89ms baseline: 32.68ms
seq_len: 4096   slower: 0.46x   kernel: 57.50ms baseline: 126.00ms
seq_len: 8192   slower: 0.00x   kernel: 224.76ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.69ms  baseline: 0.84ms
seq_len: 256    slower: 0.95x   kernel: 0.79ms  baseline: 0.83ms
seq_len: 512    slower: 0.78x   kernel: 1.06ms  baseline: 1.37ms
seq_len: 1024   slower: 0.50x   kernel: 2.10ms  baseline: 4.24ms
seq_len: 2048   slower: 0.37x   kernel: 5.85ms  baseline: 15.92ms
seq_len: 4096   slower: 0.31x   kernel: 19.80ms baseline: 64.42ms
seq_len: 8192   slower: 0.00x   kernel: 75.25ms baseline: oom

For variable-length sequences with masking, also a clear win (assuming on average 25% of tokens masked out). Run python benchmark.py --mask-prob 0.25

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.95x   kernel: 0.84ms  baseline: 0.89ms
seq_len: 256    slower: 1.19x   kernel: 1.28ms  baseline: 1.08ms
seq_len: 512    slower: 1.23x   kernel: 3.19ms  baseline: 2.59ms
seq_len: 1024   slower: 0.92x   kernel: 8.19ms  baseline: 8.88ms
seq_len: 2048   slower: 0.92x   kernel: 30.08ms baseline: 32.57ms
seq_len: 4096   slower: 0.94x   kernel: 123.20ms    baseline: 131.22ms
seq_len: 8192   slower: 0.00x   kernel: 461.77ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.85x   kernel: 0.77ms  baseline: 0.90ms
seq_len: 256    slower: 0.93x   kernel: 0.86ms  baseline: 0.93ms
seq_len: 512    slower: 0.93x   kernel: 1.31ms  baseline: 1.40ms
seq_len: 1024   slower: 0.76x   kernel: 3.31ms  baseline: 4.35ms
seq_len: 2048   slower: 0.71x   kernel: 11.19ms baseline: 15.65ms
seq_len: 4096   slower: 0.70x   kernel: 41.27ms baseline: 59.01ms
seq_len: 8192   slower: 0.00x   kernel: 158.60ms    baseline: oom

A100 40GB (wip)

Thanks go out to Stability for providing access to A100s for testing, and to Enrico for taking the time to run some benchmarks when I had no access yet.

The A100 is still a work in progress. Shared memory is not fully exploited yet. Strangely enough, F32 seems to be doing better than F16.

Forwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.98x   kernel: 0.29ms  baseline: 0.30ms
seq_len: 256    slower: 1.19x   kernel: 0.35ms  baseline: 0.29ms
seq_len: 512    slower: 0.94x   kernel: 0.52ms  baseline: 0.55ms
seq_len: 1024   slower: 0.75x   kernel: 1.23ms  baseline: 1.65ms
seq_len: 2048   slower: 0.88x   kernel: 4.17ms  baseline: 4.73ms
seq_len: 4096   slower: 0.79x   kernel: 14.53ms baseline: 18.36ms
seq_len: 8192   slower: 0.64x   kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.84x   kernel: 0.24ms  baseline: 0.29ms
seq_len: 256    slower: 1.02x   kernel: 0.29ms  baseline: 0.29ms
seq_len: 512    slower: 1.24x   kernel: 0.36ms  baseline: 0.29ms
seq_len: 1024   slower: 1.48x   kernel: 0.79ms  baseline: 0.54ms
seq_len: 2048   slower: 1.31x   kernel: 2.08ms  baseline: 1.59ms
seq_len: 4096   slower: 1.21x   kernel: 6.89ms  baseline: 5.70ms
seq_len: 8192   slower: 1.07x   kernel: 24.80ms baseline: 23.15ms

Backwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.94x   kernel: 0.57ms  baseline: 0.60ms
seq_len: 256    slower: 1.29x   kernel: 0.75ms  baseline: 0.58ms
seq_len: 512    slower: 1.16x   kernel: 1.30ms  baseline: 1.12ms
seq_len: 1024   slower: 0.98x   kernel: 3.14ms  baseline: 3.19ms
seq_len: 2048   slower: 1.05x   kernel: 11.13ms baseline: 10.63ms
seq_len: 4096   slower: 0.98x   kernel: 40.11ms baseline: 40.79ms
seq_len: 8192   slower: 0.97x   kernel: 154.96ms    baseline: 159.70ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.91x   kernel: 0.55ms  baseline: 0.60ms
seq_len: 256    slower: 1.03x   kernel: 0.62ms  baseline: 0.60ms
seq_len: 512    slower: 1.36x   kernel: 0.82ms  baseline: 0.60ms
seq_len: 1024   slower: 1.52x   kernel: 1.52ms  baseline: 1.01ms
seq_len: 2048   slower: 1.37x   kernel: 4.14ms  baseline: 3.03ms
seq_len: 4096   slower: 1.33x   kernel: 14.23ms baseline: 10.71ms
seq_len: 8192   slower: 1.34x   kernel: 53.90ms baseline: 40.28ms

Forwards & Backwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.92x   kernel: 0.80ms  baseline: 0.87ms
seq_len: 256    slower: 1.23x   kernel: 1.07ms  baseline: 0.87ms
seq_len: 512    slower: 1.08x   kernel: 1.80ms  baseline: 1.66ms
seq_len: 1024   slower: 0.94x   kernel: 4.33ms  baseline: 4.62ms
seq_len: 2048   slower: 0.99x   kernel: 15.26ms baseline: 15.44ms
seq_len: 4096   slower: 0.93x   kernel: 54.78ms baseline: 59.21ms
seq_len: 8192   slower: 0.91x   kernel: 210.38ms    baseline: 230.97ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.90x   kernel: 0.78ms  baseline: 0.86ms
seq_len: 256    slower: 1.00x   kernel: 0.87ms  baseline: 0.87ms
seq_len: 512    slower: 1.36x   kernel: 1.18ms  baseline: 0.86ms
seq_len: 1024   slower: 1.49x   kernel: 2.31ms  baseline: 1.55ms
seq_len: 2048   slower: 1.33x   kernel: 6.17ms  baseline: 4.63ms
seq_len: 4096   slower: 1.28x   kernel: 21.08ms baseline: 16.44ms
seq_len: 8192   slower: 1.24x   kernel: 78.75ms baseline: 63.45ms

Autoregressive

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.82ms  baseline: 1.01ms
seq_len: 256    slower: 1.02x   kernel: 1.00ms  baseline: 0.98ms
seq_len: 512    slower: 0.82x   kernel: 1.55ms  baseline: 1.89ms
seq_len: 1024   slower: 0.51x   kernel: 2.79ms  baseline: 5.44ms
seq_len: 2048   slower: 0.45x   kernel: 8.37ms  baseline: 18.67ms
seq_len: 4096   slower: 0.40x   kernel: 29.16ms baseline: 72.97ms
seq_len: 8192   slower: 0.38x   kernel: 108.68ms    baseline: 285.47ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.81ms  baseline: 0.98ms
seq_len: 256    slower: 0.90x   kernel: 0.88ms  baseline: 0.98ms
seq_len: 512    slower: 1.16x   kernel: 1.13ms  baseline: 0.97ms
seq_len: 1024   slower: 0.80x   kernel: 1.68ms  baseline: 2.10ms
seq_len: 2048   slower: 0.54x   kernel: 3.66ms  baseline: 6.81ms
seq_len: 4096   slower: 0.45x   kernel: 11.43ms baseline: 25.32ms
seq_len: 8192   slower: 0.41x   kernel: 40.58ms baseline: 99.14ms

Variable-length sequences (up to 25% of tokens masked out)

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.80x   kernel: 0.85ms  baseline: 1.07ms
seq_len: 256    slower: 1.07x   kernel: 1.15ms  baseline: 1.08ms
seq_len: 512    slower: 1.00x   kernel: 1.94ms  baseline: 1.94ms
seq_len: 1024   slower: 0.84x   kernel: 4.64ms  baseline: 5.55ms
seq_len: 2048   slower: 0.84x   kernel: 15.86ms baseline: 18.86ms
seq_len: 4096   slower: 0.76x   kernel: 55.19ms baseline: 72.47ms
seq_len: 8192   slower: 0.75x   kernel: 212.48ms    baseline: 282.71ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.80x   kernel: 0.83ms  baseline: 1.04ms
seq_len: 256    slower: 0.90x   kernel: 0.93ms  baseline: 1.03ms
seq_len: 512    slower: 1.18x   kernel: 1.22ms  baseline: 1.04ms
seq_len: 1024   slower: 1.10x   kernel: 2.40ms  baseline: 2.17ms
seq_len: 2048   slower: 0.89x   kernel: 6.27ms  baseline: 7.06ms
seq_len: 4096   slower: 0.82x   kernel: 21.19ms baseline: 25.95ms
seq_len: 8192   slower: 0.78x   kernel: 79.45ms baseline: 101.83ms

Training a small GPT on Enwik8

$ make train

Try a sequence length of 8192. It'll be slow, but it will work (regular attention will break at > 2048; you'll see this if you remove the --use-cuda-kernel flag)

$ python train.py --seq-len 8192 --use-cuda-kernel

Citations

@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Henry2020QueryKeyNF,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
    booktitle = {FINDINGS},
    year    = {2020}
}
@article{Wang2022DeepNetST,
    title   = {DeepNet: Scaling Transformers to 1,000 Layers},
    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.00555}
}
