  • Language
    Python
  • License
    MIT License

Repository Details

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton

flash-linear-attention

This repo aims to provide efficient Triton-based implementations of state-of-the-art linear attention models.

Join the Discord if you are interested in this project or have any questions!

Models

| Date | Title | Paper | Code | FLA impl |
|---|---|---|---|---|
| 2023-07 | [RetNet] Retentive Network: A Successor to Transformer for Large Language Models (@MSRA@THU) | [arxiv] | [official] [RetNet] | code |
| 2023-12 | [GLA] Gated Linear Attention Transformers with Hardware-Efficient Training (@MIT@IBM) | [arxiv] | [official] | code |
| 2023-12 | [Based] An Educational and Effective Sequence Mixer (@Stanford HazyResearch) | [blog] | [official] | code |
| 2023-09 | [Hedgehog] The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry (@HazyResearch) | [openreview] | | code |
| 2023-10 | [PolySketchFormer] Fast Transformers via Sketching Polynomial Kernels (@CMU@Google) | [arxiv] | | TODO |
| 2023-07 | [TransnormerLLM] A Faster and Better Large Language Model with Improved TransNormer (@Shanghai AI Lab) | [openreview] [arxiv] | [official] [Lightning2] | TODO |
| 2023-05 | [RWKV-v6] Reinventing RNNs for the Transformer Era (@BlinkDL) | [arxiv] | [official] | TODO |
| 2023-10 | [GateLoop] Fully Data-Controlled Linear Recurrence for Sequence Modeling | [openreview] [arxiv] | [official] [jax] | TODO |
| 2021-10 | [ABC] Attention with Bounded-memory Control (@UW) | [arxiv] | | TODO |
| 2023-09 | [VQ-transformer] Linear-Time Transformers via Vector Quantization | [arxiv] | [official] | TODO |

Installation

The following requirements should be satisfied

As fla is under active development, no released packages are provided at this time. If you need to use fla ops/modules or want to explore further, you can install the package from source:

pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

or manage fla with submodules:

git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla

⚠️ Caveats on numerical stability!!

If you're not working with Triton v2.2 or its nightly release, it's important to be aware of potential issues with the FusedChunk implementation, detailed in this issue. You can run the test python tests/test_fused_chunk.py to check if your version is affected by similar compiler problems. While we offer some fixes for Triton<=2.1, be aware that these may result in reduced performance.

For both Triton 2.2 and earlier versions (up to 2.1), you can reliably use the Chunk version (with hidden states materialized into HBM). After careful optimization, this version generally delivers high performance in most scenarios.
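The version rule above can be sketched as a small helper. This is only an illustration under stated assumptions: the names `parse_version`, `pick_mode`, and `installed_triton_version` are hypothetical, the returned strings are plain labels rather than confirmed fla arguments, and nightly builds may be published under a different package name than `triton`. Running tests/test_fused_chunk.py remains the actual check.

```python
import re
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Extract the leading (major, minor) pair from a version string."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def pick_mode(triton_version: str) -> str:
    """Return an illustrative label (not an fla argument) for a Triton version."""
    # Triton >= 2.2: the caveat does not flag FusedChunk.
    # Older Triton: fall back to Chunk, which the caveat describes as reliable.
    return "fused_chunk" if parse_version(triton_version) >= (2, 2) else "chunk"


def installed_triton_version():
    """Return the installed `triton` version string, or None if it is absent."""
    try:
        return metadata.version("triton")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_triton_version()
    if found is None:
        print("triton is not installed under the package name 'triton'")
    else:
        print(f"triton {found}: suggested mode is {pick_mode(found)}")
```

Tuples are compared rather than strings so that a version such as 2.10 sorts after 2.2.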

Usage

We provide "token mixing" linear attention layers in fla.layers. You can replace the standard multi-head attention layer in your transformer with any of these linear attention layers. Example usage is as follows:

import torch
from fla.layers import MultiScaleRetention, GatedLinearAttention, BasedLinearAttention

d_model = 1024
num_head = 4
device = "cuda:0"
dtype = torch.bfloat16

# Build one layer of each type and move it to the GPU in bfloat16
retnet = MultiScaleRetention(d_model=d_model, num_heads=num_head).to(device).to(dtype)
gla = GatedLinearAttention(d_model=d_model, num_heads=num_head).to(device).to(dtype)
based = BasedLinearAttention(d_model=d_model, num_heads=num_head).to(device).to(dtype)

# Input of shape (batch, seq_len, d_model)
bsz, seq_len = 32, 2048
x = torch.randn(bsz, seq_len, d_model).to(device).to(dtype)
y1 = retnet(x)
y2 = gla(x)
y3 = based(x)

# Every layer preserves the input shape
assert y1.shape == y2.shape == y3.shape == x.shape

Benchmarks

We compared our Triton-based RetNet implementation with CUDA-based FlashAttention2, using a batch size of 8, 32 heads, and a head dimension of 128, across different sequence lengths. These tests were conducted on a single A100 80GB GPU; the benchmark output is shown below.

# you might have to first install `fla` to enable its import via `pip install -e .`
$ python benchmark_retention.py
Performance:
   seq_len  fused_chunk_fwd  chunk_fwd  parallel_fwd  fused_chunk_fwdbwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwd  flash_fwdbwd
0    128.0         0.093184   0.185344      0.067584            1.009664      1.591296         1.044480   0.041984      0.282624
1    256.0         0.165888   0.219136      0.126976            1.024000      1.596928         1.073152   0.074752      0.413696
2    512.0         0.308224   0.397312      0.265216            1.550336      1.603584         1.301504   0.156672      0.883712
3   1024.0         0.603136   0.747520      0.706560            3.044864      3.089408         3.529728   0.467968      2.342912
4   2048.0         1.191424   1.403904      2.141184            6.010880      6.059008        11.009024   1.612800      7.135232
5   4096.0         2.377728   2.755072      7.392256           11.932672     11.938816        37.792770   5.997568     24.435200
6   8192.0         4.750336   5.491712     26.402817           23.759359     23.952385       141.014023  22.682114     90.619904
7  16384.0         9.591296  10.870784    101.262337           47.666176     48.745472       539.853821  91.346947    346.318848
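To make the trend in these numbers easier to read, the short script below computes the ratio of FlashAttention2 to FusedChunk forward+backward timings and finds the first sequence length at which FusedChunk is faster. The values are copied from the `fused_chunk_fwdbwd` and `flash_fwdbwd` columns above; the helper names `speedups` and `crossover` are introduced here for illustration only.

```python
# (seq_len, fused_chunk_fwdbwd, flash_fwdbwd), copied from the benchmark output
ROWS = [
    (128, 1.009664, 0.282624),
    (256, 1.024000, 0.413696),
    (512, 1.550336, 0.883712),
    (1024, 3.044864, 2.342912),
    (2048, 6.010880, 7.135232),
    (4096, 11.932672, 24.435200),
    (8192, 23.759359, 90.619904),
    (16384, 47.666176, 346.318848),
]


def speedups(rows):
    """Return (seq_len, flash / fused_chunk) pairs; > 1 means fused_chunk wins."""
    return [(n, flash / fused) for n, fused, flash in rows]


def crossover(rows):
    """Return the first seq_len at which fused_chunk is faster, or None."""
    for n, fused, flash in rows:
        if fused < flash:
            return n
    return None


if __name__ == "__main__":
    for n, ratio in speedups(ROWS):
        print(f"seq_len={n:>6}: flash_fwdbwd / fused_chunk_fwdbwd = {ratio:.2f}")
    print("fused_chunk is first faster at seq_len =", crossover(ROWS))
```

In this run the crossover falls at a sequence length of 2048. Beyond that point, doubling the sequence length roughly doubles the FusedChunk time but nearly quadruples the FlashAttention2 time, consistent with linear versus quadratic scaling.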

Different forms of linear attention

Please refer to Section 2.3 of the GLA paper for hardware considerations of different forms of linear attention.

  • Parallel: Self-attention-style computation in $O(L^2)$ time with sequence parallelism.
  • FusedRecurrent: Recurrent computation in $O(L)$ time. Hidden states are computed on the fly in shared memory without being materialized to global memory (see Algorithm 1 of this paper for more details). This saves a lot of I/O cost and should be a strong baseline for speed comparisons.
  • FusedChunk: Chunkwise computation in $O(LC)$ time, where $C$ is the chunk size. As in FusedRecurrent, hidden states are computed on the fly without being materialized to global memory. This version is usually better than FusedRecurrent because tensor cores can be used for the sequence-level "reduction", whereas FusedRecurrent cannot use tensor cores at all. Note that there is no sequence-level parallelism in this implementation, so it is not suitable for very small batch sizes. It should be more memory-efficient than ParallelChunk.
  • ParallelChunk: Chunkwise computation with sequence parallelism. Hidden states must be materialized to global memory for each chunk. $C$ must be set properly to achieve good performance: when $C$ is too small, there are too many hidden states to load from and store to global memory; when $C$ is too large, the FLOPs are high. Recommended values of $C$ are 64, 128, and 256.
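The forms listed above compute the same result and differ only in how the work is organized. The following is a minimal pure-Python sketch under simplifying assumptions: plain causal linear attention $o_t = \sum_{s \le t} (q_t \cdot k_s) v_s$ with no decay, gating, or normalization (which models such as RetNet and GLA add), and no modeling of shared versus global memory. All function names are hypothetical and are not part of fla.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def parallel_form(q, k, v):
    """O(L^2): every query attends directly to all earlier keys."""
    out = []
    for t in range(len(q)):
        o = [0.0] * len(v[0])
        for s in range(t + 1):  # causal mask: only s <= t contributes
            w = dot(q[t], k[s])
            o = [oi + w * vi for oi, vi in zip(o, v[s])]
        out.append(o)
    return out


def update_state(state, k_t, v_t):
    """In place: state += outer(k_t, v_t)."""
    for i, ki in enumerate(k_t):
        for j, vj in enumerate(v_t):
            state[i][j] += ki * vj


def read_state(state, q_t):
    """Return q_t @ state."""
    return [sum(q_t[i] * state[i][j] for i in range(len(q_t)))
            for j in range(len(state[0]))]


def recurrent_form(q, k, v):
    """O(L): one hidden state of shape (d_k, d_v), updated token by token."""
    state = [[0.0] * len(v[0]) for _ in range(len(k[0]))]
    out = []
    for q_t, k_t, v_t in zip(q, k, v):
        update_state(state, k_t, v_t)
        out.append(read_state(state, q_t))
    return out


def chunkwise_form(q, k, v, chunk_size):
    """O(LC): parallel form inside each chunk, recurrent state across chunks."""
    state = [[0.0] * len(v[0]) for _ in range(len(k[0]))]
    out = []
    for start in range(0, len(q), chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        inter = [read_state(state, q_t) for q_t in qc]  # earlier chunks
        intra = parallel_form(qc, kc, vc)               # within this chunk
        out.extend([a + b for a, b in zip(x, y)] for x, y in zip(inter, intra))
        for k_t, v_t in zip(kc, vc):                    # fold chunk into state
            update_state(state, k_t, v_t)
    return out


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


if __name__ == "__main__":
    rng = random.Random(0)
    L, d_k, d_v = 16, 4, 3
    q = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(L)]
    k = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(L)]
    v = [[rng.gauss(0, 1) for _ in range(d_v)] for _ in range(L)]
    ref = parallel_form(q, k, v)
    print("recurrent vs parallel:", max_abs_diff(ref, recurrent_form(q, k, v)))
    print("chunkwise vs parallel:", max_abs_diff(ref, chunkwise_form(q, k, v, 4)))
```

In the chunkwise form, the state is updated once per chunk, so a smaller `chunk_size` means more state updates (the load/store cost noted for ParallelChunk), while a larger one means more quadratic work inside each chunk.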

Citation

If you find this repo useful, please consider citing our works:

@article{yang2023gated,
  title   = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author  = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  journal = {arXiv preprint arXiv:2312.06635},
  year    = {2023}
}

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}