  • Stars: 276
  • Rank: 149,319 (Top 3%)
  • Language: C++
  • License: Apache License 2.0
  • Created about 1 year ago
  • Updated about 1 month ago

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

This repository contains the official code for FlashFFTConv, a fast algorithm for computing long depthwise convolutions using the FFT algorithm.

FlashFFTConv computes convolutions up to 7.93 times faster than PyTorch FFT convolutions, with up to 8.21 times less memory usage. FlashFFTConv supports convolution kernel lengths up to 4,194,304.

We also provide a fast kernel for short 1D depthwise convolutions (e.g., kernel lengths of 3 or 5), which runs up to 7 times faster than PyTorch Conv1D. This module gives additional speedup for language models like Monarch Mixer, H3, and Hyena.

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Daniel Y. Fu*, Hermann Kumbong*, Eric Nguyen, Christopher Ré
Paper: https://arxiv.org/abs/2311.05908
Blog: https://hazyresearch.stanford.edu/blog/2023-11-13-flashfftconv

Examples and Usage

We've been happy to use FlashFFTConv to support various projects, including Monarch Mixer, Hyena/H3, HyenaDNA, and more. FlashFFTConv is also being used to train various new models that haven't been released yet - we'll be updating this README with pointers as they're publicly announced.

Check out the examples folder for end-to-end examples of how to use FlashFFTConv in your models.

You can also run a standalone CIFAR example to see usage as soon as the package is installed:

python standalone_cifar.py

Installation

Requirements: We recommend using the Nvidia PyTorch docker container. We've tested and developed this library on version 23.05.

  • PyTorch 2.0 Required
  • We have tested with CUDA version 12.1 and toolkit version 12.1
  • We have tested this on A100 and H100, but it should work on any Ampere/Hopper architecture (3090, 4090, etc)

To check your CUDA version:

  • Run nvcc --version and check the version number of your CUDA toolkit. Our Docker ships with version 12.1.
  • Run nvidia-smi to check the version of your CUDA drivers. Our Docker ships with version 12.1.
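
The same information can also be read from inside Python. The sketch below is only a convenience check, not part of this package; the helper name describe_cuda_setup is hypothetical. It uses standard PyTorch calls and reports the CUDA version PyTorch was built against plus the GPU's compute capability (Ampere GPUs report 8.x, Hopper reports 9.0).

```python
import torch


def describe_cuda_setup() -> dict:
    """Collect the versions PyTorch sees; GPU fields stay None without CUDA."""
    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # CUDA version PyTorch was built against
        "gpu": None,
        "compute_capability": None,
    }
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
        info["compute_capability"] = torch.cuda.get_device_capability(0)
    return info


info = describe_cuda_setup()
print(info)
```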

You can install via pip:

pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
pip install git+https://github.com/HazyResearch/flash-fft-conv.git

Or from source:

git clone https://github.com/HazyResearch/flash-fft-conv.git

cd flash-fft-conv

cd csrc/flashfftconv
python setup.py install

cd ../..

python setup.py install

Once it's installed, you should be able to run the test suite:

pytest -s -q tests/test_flashfftconv.py

This test should run on machines with 40GB of GPU memory.

Short Depthwise Kernel

The short depthwise kernel is also installed with these commands. You can run its tests like this:

pytest -s -q tests/test_conv1d.py

How to Use FlashFFTConv

The flashfftconv package contains a PyTorch interface called FlashFFTConv, which is initialized with a particular FFT size.

This module computes an FFT convolution (iFFT(FFT(u) * FFT(k))).

from flashfftconv import FlashFFTConv

Usage:

import torch
from flashfftconv import FlashFFTConv

# size of the FFT
my_flashfftconv = FlashFFTConv(32768, dtype=torch.bfloat16) # generally more stable!

# B is batch size, H is model dimension, L is sequence length
B = 16
H = 768
# input can be smaller than FFT size, but needs to be divisible by 2
L = 16384

# the input, B H L (inputs must be contiguous in GPU memory)
x = torch.randn(B, H, L, dtype=torch.bfloat16, device="cuda") # same dtype as the FlashFFTConv module
k = torch.randn(H, L, dtype=torch.float32, device="cuda") # kernel needs to be fp32 for now

out = my_flashfftconv(x, k)
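
For comparison, the computation iFFT(FFT(u) * FFT(k)) can be sketched in plain PyTorch on CPU. This is a minimal reference, not the fused kernel: the names fft_conv_reference and direct_causal_conv are hypothetical, and the sketch assumes a causal convolution that zero-pads to twice the sequence length so the circular FFT convolution does not wrap around.

```python
import torch


def fft_conv_reference(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal long convolution computed as iFFT(FFT(u) * FFT(k)).

    u has shape (B, H, L); k has shape (H, L). Hypothetical reference helper.
    """
    L = u.shape[-1]
    fft_size = 2 * L  # zero-padding to 2L avoids circular wrap-around
    u_f = torch.fft.rfft(u.float(), n=fft_size)
    k_f = torch.fft.rfft(k.float(), n=fft_size)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :L]


def direct_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """O(L^2) definition: y[t] = sum over s <= t of k[s] * u[t - s]."""
    B, H, L = u.shape
    y = torch.zeros(B, H, L)
    for t in range(L):
        for s in range(t + 1):
            y[..., t] += k[:, s] * u[..., t - s]
    return y


torch.manual_seed(0)
u = torch.randn(2, 4, 16)
k = torch.randn(4, 16)
max_err = (fft_conv_reference(u, k) - direct_causal_conv(u, k)).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

The two results should agree up to floating-point error, which makes a helper like this handy as a correctness baseline when testing on small shapes.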

Example Model

We recommend creating one FlashFFTConv object per model, and reusing it between layers.

For example:

import torch
from flashfftconv import FlashFFTConv

class MyModel(torch.nn.Module):
    def __init__(self, H, seqlen, num_layers):
        super().__init__()

        self.H = H
        self.seqlen = seqlen
        self.num_layers = num_layers
        self.flashfftconv = FlashFFTConv(seqlen, dtype=torch.bfloat16)

        # create your conv layers
        self.long_conv_layers = torch.nn.ModuleList([
            ConvLayer(H, seqlen)
            for i in range(num_layers)
        ])

        # add a pointer to the flashfft object in each layer
        for layer in self.long_conv_layers:
            layer.flashfftconv = self.flashfftconv

        ...
    
    def forward(self, x):
        for layer in self.long_conv_layers:
            x = layer(x)

        return x

class ConvLayer(torch.nn.Module):
    def __init__(self, H, seqlen):
        super().__init__()  # required before registering parameters
        self.k = torch.nn.Parameter(torch.randn(H, seqlen, dtype=torch.float32))
        ...

    def forward(self, x):
        return self.flashfftconv(x, self.k) # self.flashfftconv comes from the wrapper model!

Gating and Implicit Padding

A common use case for long FFT convolutions is language modeling. These architectures often use gated convolutions and pad the inputs with zeros to ensure causality.

For example, a gated causal convolution might look like this in PyTorch:

def gated_conv(u, k, in_gate, out_gate):
    # u, in_gate, and out_gate have shape B, H, L
    # k has shape H, L
    B, H, L = u.shape
    fft_size = 2 * L
    padding = fft_size - L
    u = torch.nn.functional.pad(u, (0, padding))
    k = torch.nn.functional.pad(k, (0, padding))
    in_gate = torch.nn.functional.pad(in_gate, (0, padding))
    out_gate = torch.nn.functional.pad(out_gate, (0, padding))

    # compute the gated convolution
    u_f = torch.fft.fft(u * in_gate, dim=-1)
    k_f = torch.fft.fft(k, dim=-1)
    y_f = u_f * k_f
    y = torch.fft.ifft(y_f, dim=-1).real * out_gate
    return y

The padding and gating operations each incur expensive memory I/O, which slows down the model.

FlashFFTConv supports implicit padding and gating without the need for extra I/O:

L = ... # get L from somewhere
flashfftconv = FlashFFTConv(2 * L, dtype=torch.bfloat16) # bf16 is usually necessary for gating
y = flashfftconv(u, k, in_gate, out_gate)
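
To illustrate what implicit padding means mathematically, here is a plain PyTorch sketch that runs on CPU. It is not how the fused kernel is implemented. Both function names are hypothetical, and the sketch assumes the quantity of interest is the first L outputs of the gated convolution from the PyTorch example above. The second variant never materializes padded tensors: the FFT call zero-pads through its n argument and the gates stay at length L.

```python
import torch
import torch.nn.functional as F


def gated_conv_explicit(u, k, in_gate, out_gate):
    """Explicit-padding version, following the PyTorch example above."""
    L = u.shape[-1]
    pad = (0, L)  # pad the last dimension from L to 2 * L
    u, k = F.pad(u, pad), F.pad(k, pad)
    in_gate, out_gate = F.pad(in_gate, pad), F.pad(out_gate, pad)
    y = torch.fft.ifft(torch.fft.fft(u * in_gate) * torch.fft.fft(k)).real
    return (y * out_gate)[..., :L]


def gated_conv_implicit(u, k, in_gate, out_gate):
    """Same math; rfft zero-pads via n= and the gates stay at length L."""
    L = u.shape[-1]
    u_f = torch.fft.rfft(u * in_gate, n=2 * L)
    k_f = torch.fft.rfft(k, n=2 * L)
    return torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L] * out_gate


torch.manual_seed(0)
u, in_gate, out_gate = (torch.randn(2, 4, 32) for _ in range(3))
k = torch.randn(4, 32)
diff = (gated_conv_explicit(u, k, in_gate, out_gate)
        - gated_conv_implicit(u, k, in_gate, out_gate)).abs().max().item()
print(f"max abs difference: {diff:.2e}")
```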

Short Depthwise Convolutions

For short, depthwise convolutions (groups = dimension in PyTorch Conv1D), you can run them like this:

import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d

# set up PyTorch equivalent to get the weights
# in_channels = out_channels, and kernel size must be odd
conv1d_torch = nn.Conv1d(
    in_channels = d,
    out_channels = d,
    kernel_size = k,
    groups = d,
    padding = padding,
    dtype = dtype
)

flash_conv1d = FlashDepthWiseConv1d(
    channels = d,
    kernel_size=k,
    padding=padding,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    dtype = dtype # this should be the dtype of the weights
)

out_torch = conv1d_torch(x) # x is B, d, L
out_flash = flash_conv1d(x) # x can be a different dtype than weights

# out_torch and out_flash should be the same!

To support mixed precision training, FlashDepthWiseConv1d supports using fp32 weights with fp16 inputs (or fp32 inputs). Currently the bf16 backward pass has a bug, but the forward pass is supported.
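
As a reminder of what "depthwise" (groups = d) means, the following CPU-only sketch checks nn.Conv1d against a per-channel loop. It uses only standard PyTorch calls; the shapes and the choice padding = k // 2 are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, d, L, k = 2, 8, 32, 3
padding = k // 2  # assumed: odd kernel with "same" padding keeps the length at L

conv1d_torch = nn.Conv1d(d, d, kernel_size=k, groups=d, padding=padding)
x = torch.randn(B, d, L)

# Depthwise: channel c is filtered only by its own kernel; weight is (d, 1, k).
per_channel = [
    F.conv1d(
        x[:, c:c + 1],
        conv1d_torch.weight[c:c + 1],
        conv1d_torch.bias[c:c + 1],
        padding=padding,
    )
    for c in range(d)
]
out_manual = torch.cat(per_channel, dim=1)
out_torch = conv1d_torch(x)
print(torch.allclose(out_manual, out_torch, atol=1e-5))
```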

Benchmarking

FlashFFTConv Benchmarks

To run FlashFFTConv benchmarks, install the module and run python benchmarks/benchmark_flashfftconv.py.

These are the runtimes we see for a gated convolution for various sequence lengths, on one H100-SXM. All results scaled to batch size 64, hidden dimension 768.

Sequence Length   256    1K     4K     8K     16K    32K    1M        2M        4M
PyTorch           0.62   2.30   9.49   19.4   29.9   84.8   3,071.4   6,342.6   13,031.2
FlashFFTConv      0.11   0.29   1.43   3.58   12.2   26.3   1,768.9   4,623.5   10,049.4
Speedup           5.64×  7.93×  6.64×  5.42×  2.45×  3.22×  1.74×     1.37×     1.30×
Memory Savings    6.65×  6.40×  6.35×  6.34×  6.17×  5.87×  2.82×     2.81×     2.81×

Please see our paper for more benchmarks!

Short Depthwise Convolution Benchmarks

To benchmark short depthwise convolutions, install the module and run python benchmarks/benchmark_conv1d.py.

Here are some results for BLH input on H100:

B    L      D      K   torch time (ms)   cudatime (ms)   speedup
16   1024   768    5   0.19              0.03            5.50
16   1024   1024   5   0.25              0.04            6.00
16   1024   2048   5   0.50              0.08            6.50
16   1024   8192   5   2.08              0.29            7.21
16   2048   768    5   0.37              0.06            5.91
16   2048   1024   5   0.50              0.08            6.33
16   2048   2048   5   1.00              0.15            6.77
16   2048   8192   5   4.17              0.57            7.31
16   4096   768    5   0.74              0.12            6.17
16   4096   1024   5   0.99              0.15            6.56
16   4096   2048   5   2.03              0.29            7.04
16   4096   8192   5   8.25              1.14            7.27
16   8192   768    5   1.49              0.23            6.36
16   8192   1024   5   2.01              0.30            6.80
16   8192   2048   5   4.10              0.57            7.18
16   8192   8192   5   16.42             2.26            7.26

We also support BHL input, but it's a bit slower (still optimizing!).

Input Requirements and Notes

Currently, we have a few requirements on the inputs to the interface to FlashFFTConv:

  • We assume that the input u has shape (B, H, L), and the kernel k has shape (H, L).
  • These inputs must be contiguous in GPU memory (u.is_contiguous() should be True).
  • The FFT size (seqlen that FlashFFTConv is initialized with) must be a power of two between 256 and 4,194,304.
  • For FFT sizes larger than 32,768, H must be a multiple of 16.
  • L can be smaller than FFT size but must be divisible by 2. For FFT sizes 512 and 2048, L must be divisible by 4.
  • We only support FP16 and BF16 for now. FP16 is faster, but we generally find BF16 more stable during training.
  • For short depthwise convs, we only support FP16 for now, and the kernel size has to be odd.
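
The shape rules in this list can be encoded as a small pre-flight check. The helper below is a hypothetical sketch and not part of the flashfftconv package. It covers only the FFT size, H, and L constraints stated above, not dtype or memory layout.

```python
def check_flashfftconv_shapes(fft_size: int, H: int, L: int) -> list:
    """Return the violated shape constraints (empty list if none)."""
    problems = []
    is_pow2 = fft_size > 0 and (fft_size & (fft_size - 1)) == 0
    if not (is_pow2 and 256 <= fft_size <= 4_194_304):
        problems.append("FFT size must be a power of two between 256 and 4,194,304")
    if fft_size > 32_768 and H % 16 != 0:
        problems.append("H must be a multiple of 16 for FFT sizes larger than 32,768")
    if L > fft_size:
        problems.append("L must not exceed the FFT size")
    # FFT sizes 512 and 2048 need L divisible by 4; all others need 2.
    required_multiple = 4 if fft_size in (512, 2048) else 2
    if L % required_multiple != 0:
        problems.append(f"L must be divisible by {required_multiple}")
    return problems


print(check_flashfftconv_shapes(32768, 768, 16384))  # shapes from the usage example
print(check_flashfftconv_shapes(512, 768, 6))
```

Running a check like this before constructing the module turns a shape mistake into a readable message rather than a failure inside the kernel.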

Citation

This work builds on a line of work studying how to make FFT convolutions efficient on GPUs, including the papers cited below.

If you use this codebase or otherwise found the ideas useful, please reach out to let us know - we love hearing about how our work is being used! You can reach Dan at [email protected].

You can also cite our work:

@article{fu2023flashfftconv,
  title={Flash{FFTC}onv: Efficient Convolutions for Long Sequences with Tensor Cores},
  author={Fu, Daniel Y. and Kumbong, Hermann and Nguyen, Eric and R{\'e}, Christopher},
  journal={arXiv preprint arXiv:2311.05908},
  year={2023}
}

@inproceedings{fu2023monarch,
  title={Monarch {M}ixer: A Simple Sub-Quadratic GEMM-Based Architecture},
  author={Fu, Daniel Y. and Arora, Simran and Grogan, Jessica and Johnson, Isys and Eyuboglu, Sabri and Thomas, Armin W and Spector, Benjamin and Poli, Michael and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}

@inproceedings{fu2023simple,
  title={Simple Hardware-Efficient Long Convolutions for Sequence Modeling},
  author={Fu, Daniel Y. and Epstein, Elliot L. and Nguyen, Eric and Thomas, Armin W. and Zhang, Michael and Dao, Tri and Rudra, Atri and R{\'e}, Christopher},
  booktitle={International Conference on Machine Learning},
  year={2023}
}

@inproceedings{fu2023hungry,
  title={Hungry {H}ungry {H}ippos: Towards Language Modeling with State Space Models},
  author={Fu, Daniel Y. and Dao, Tri and Saab, Khaled K. and Thomas, Armin W.
  and Rudra, Atri and R{\'e}, Christopher},
  booktitle={International Conference on Learning Representations},
  year={2023}
}
