  • Language: Python
  • License: MIT License
  • Created almost 4 years ago
  • Updated over 1 year ago


SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Example of equivariance (Open In Colab)

If you have been using any version of SE3 Transformers prior to 0.6.0, please update. @MattMcPartlon uncovered a major bug that affects you if you were not using the adjacency sparse neighbors settings and were relying on the nearest neighbors functionality.

Update: It is recommended that you use Equiformer instead.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, for example when they are atom types

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To provide edge information to SE3 Transformers (say, bond types between atoms), you just have to pass in two more keyword arguments on initialization, num_edge_tokens and edge_dim.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

If you would like to pass in continuous values for your edges, you can leave num_edge_tokens unset, encode your discrete bond types yourself, and then concatenate them with the Fourier features of the continuous values

import torch
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import fourier_encode

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    edge_dim = 34           # edge dimension must match the final dimension of the edges being passed in
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

pairwise_continuous_values = torch.rand(1, 32, 32, 2) * 4  # say there are 2 continuous values per pair

edges = fourier_encode(
    pairwise_continuous_values,
    num_encodings = 8,
    include_self = True
) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}

out = model(feats, coors, mask, edges = edges, return_type = 1)
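The edge dimension of 34 follows from simple arithmetic: each continuous value expands into num_encodings sine/cosine pairs, plus the raw value itself when include_self = True. The plain-Python sketch below reproduces that count. It assumes the encoding divides each value by successive powers of two before taking sin and cos; fourier_encode_value and encode_pair are hypothetical helpers for illustration, not the library's fourier_encode.

```python
import math

def fourier_encode_value(x, num_encodings=8, include_self=True):
    # hypothetical helper: scale x by 1 / 2**k for k = 0 .. num_encodings - 1 (assumed scales)
    scaled = [x / (2 ** k) for k in range(num_encodings)]
    feats = [math.sin(s) for s in scaled] + [math.cos(s) for s in scaled]
    if include_self:
        feats.append(x)  # keep the raw value alongside its encoding
    return feats

def encode_pair(values, num_encodings=8, include_self=True):
    # concatenate the encodings of every continuous value attached to one edge
    out = []
    for v in values:
        out.extend(fourier_encode_value(v, num_encodings, include_self))
    return out

edge = encode_pair([0.5, 3.0])
print(len(edge))  # 2 * (2 * 8 + 1) = 34
```

Whatever the exact frequencies, the output width is num_values * (2 * num_encodings + 1) when the raw value is included, which is the number edge_dim has to match.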

Sparse Neighbors

If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where True indicates connectivity).

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    heads = 8,
    depth = 1,
    dim_head = 64,
    num_degrees = 2,
    valid_radius = 10,
    attend_sparse_neighbors = True,  # this must be set to True, in which case it will assert that you pass in the adjacency matrix
    num_neighbors = 0,               # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
    max_sparse_neighbors = 8         # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)

feats = torch.randn(1, 128, 32)
coors = torch.randn(1, 128, 3)
mask  = torch.ones(1, 128).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)

i = torch.arange(128)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 32)
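The comments on num_neighbors and max_sparse_neighbors describe a two-stage neighbor selection: first the nodes connected in the adjacency matrix (optionally capped), then, if num_neighbors > 0, the closest remaining points. The sketch below mimics that description for a single node in plain Python. It is a hypothetical illustration of the documented behavior, not the library's code; in particular it truncates instead of sampling when capping, and it excludes the node itself.

```python
import math

def chain_adjacency(n):
    # True where |i - j| <= 1, like the placeholder adj_mat above (diagonal included)
    return [[abs(i - j) <= 1 for j in range(n)] for i in range(n)]

def select_neighbors(i, coors, adj, num_neighbors, max_sparse_neighbors=None):
    # hypothetical illustration of the documented semantics, not the library's code
    n = len(coors)
    # 1. neighbors given by the adjacency matrix (self excluded for simplicity)
    sparse = [j for j in range(n) if j != i and adj[i][j]]
    if max_sparse_neighbors is not None:
        # the library samples when capping; this sketch truncates to stay deterministic
        sparse = sparse[:max_sparse_neighbors]
    # 2. up to num_neighbors extra points, closest first, skipping adjacency neighbors
    others = [j for j in range(n) if j != i and not adj[i][j]]
    others.sort(key=lambda j: math.dist(coors[i], coors[j]))
    return sparse + others[:num_neighbors]

coors = [(float(k), 0.0, 0.0) for k in range(6)]  # six points on a line
adj = chain_adjacency(6)
print(select_neighbors(0, coors, adj, num_neighbors=0))  # [1]
print(select_neighbors(2, coors, adj, num_neighbors=2))  # [1, 3, 0, 4]
```

With num_neighbors = 0 only the chain neighbors survive, which matches the "only consider the connected neighbors" case in the comment above.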

You can also have the network automatically derive the Nth-degree neighbors for you with one extra keyword, num_adj_degrees. If you would like the system to differentiate between neighbor degrees as edge information, further pass in a non-zero adj_dim.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,    # automatically derive 2nd degree neighbors
    adj_dim = 4             # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (32, 32)

i = torch.arange(32)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)
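Deriving Nth-degree neighbors amounts to labeling each pair of nodes with its hop distance in the adjacency graph, up to num_adj_degrees, and 0 for pairs that are not reached. Those labels (null, 1st, 2nd, ...) are the categories that, per the comment above, get embedded with edge embeddings of size adj_dim. The breadth-first search below is a sketch of that labeling under that assumption; adjacency_degrees is a hypothetical helper, and the library may compute the same thing with matrix operations instead.

```python
from collections import deque

def adjacency_degrees(adj, num_adj_degrees):
    # label[i][j] = hop distance 1..num_adj_degrees between i and j, or 0 if not reached
    n = len(adj)
    labels = [[0] * n for _ in range(n)]
    for src in range(n):
        dist = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            if dist[u] == num_adj_degrees:
                continue  # do not expand past the requested degree
            for v in range(n):
                if adj[u][v] and v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        for v, d in dist.items():
            if v != src:
                labels[src][v] = d
    return labels

n = 5
adj = [[abs(i - j) == 1 for j in range(n)] for i in range(n)]  # a simple 5-node chain
labels = adjacency_degrees(adj, num_adj_degrees=2)
print(labels[0])  # [0, 1, 2, 0, 0]
print(labels[2])  # [2, 1, 0, 1, 2]
```

An embedding table with num_adj_degrees + 1 rows (here 3: null, 1st, 2nd) could then turn each label into an adj_dim-sized edge feature.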

To have fine control over the dimensionality of each type, you can use the hidden_fiber_dict and out_fiber_dict keywords to pass in dictionaries that map each degree (key) to its dimension (value).

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,
    edge_dim = 16,
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
    out_fiber_dict = {0: 16, 1: 1},
    reduce_dim_out = False
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds)

pred['0'] # (2, 32, 16)
pred['1'] # (2, 32, 1, 3)
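A degree-l feature has 2l + 1 components per channel, so a fiber dictionary fixes both the output shapes and the number of scalars carried per node. The short sketch below computes both for the dictionaries used above. The helper names are hypothetical, and dropping the trailing axis for degree 0 simply mirrors the pred['0'] shape shown in the example.

```python
def fiber_components(fiber):
    # a degree-l feature has 2l + 1 components per channel
    return {deg: dim * (2 * deg + 1) for deg, dim in fiber.items()}

def expected_output_shapes(batch, n, fiber):
    # degree 0 is shown without its trailing size-1 axis, matching pred['0'] above
    return {
        str(deg): (batch, n, dim) if deg == 0 else (batch, n, dim, 2 * deg + 1)
        for deg, dim in fiber.items()
    }

hidden = fiber_components({0: 16, 1: 8, 2: 4})
print(hidden, sum(hidden.values()))  # {0: 16, 1: 24, 2: 20} 60
print(expected_output_shapes(2, 32, {0: 16, 1: 1}))  # {'0': (2, 32, 16), '1': (2, 32, 1, 3)}
```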

Neighbors

You can further control which nodes can be considered by passing in a neighbor mask. All False values will be masked out of consideration.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 16,
    dim_head = 16,
    attend_self = True,
    num_degrees = 4,
    output_degrees = 2,
    num_edge_tokens = 4,
    num_neighbors = 8,      # make sure this equals the maximum number of neighbors allowed by your neighbor_mask, or it will throw a warning
    edge_dim = 2,
    depth = 3
)

feats = torch.randn(1, 32, 16)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()
bonds = torch.randint(0, 4, (1, 32, 32))

neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False

out = model(
    feats,
    coors,
    mask,
    edges = bonds,
    neighbor_mask = neighbor_mask,
    return_type = 1
)
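"Masked out of consideration" typically means that masked positions receive zero attention weight: their scores are treated as negative infinity before the softmax. The generic masked-softmax sketch below illustrates that idea for one row of a neighbor mask. It is not the library's code, and the library may also drop masked nodes earlier, during neighbor selection; returning all zeros for a fully masked row is an assumption of this sketch.

```python
import math

def masked_softmax(scores, mask):
    # positions where mask is False get zero weight (their score is treated as -inf)
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        return [0.0] * len(scores)  # nothing allowed: no attention mass anywhere (assumption)
    top = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - top) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 0.5, 1.0, 3.0]
mask = [True, True, False, False]  # e.g. one row of neighbor_mask
weights = masked_softmax(scores, mask)
print([round(w, 3) for w in weights])  # [0.818, 0.182, 0.0, 0.0]
```

Note that the highest raw score (3.0) receives no weight at all, because its position is masked.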

Global Nodes

This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation.

import torch
from torch import nn
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 2,
    num_neighbors = 4,
    valid_radius = 10,
    global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# naively derive global features
# by pooling features and projecting
global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32)

out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
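The key point of global nodes is that their keys and values are appended to every node's attention set, so they are visible no matter which neighbors were selected. The scalar-only sketch below shows that mechanism with plain dot-product attention; real SE3 attention operates on equivariant features of several degrees, and the global keys and values here are given directly rather than projected from global_feats. The function names are hypothetical.

```python
import math

def softmax(scores):
    # numerically stable softmax over a list of floats
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend_with_globals(query, neighbor_keys, neighbor_values, global_keys, global_values):
    # global keys / values are appended to every node's neighbor set,
    # independent of nearest-neighbor or adjacency selection
    keys = neighbor_keys + global_keys
    values = neighbor_values + global_values
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([scale * sum(q * k for q, k in zip(query, key)) for key in keys])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, weights

query = [1.0, 0.0]
neighbor_keys = [[1.0, 0.0], [0.0, 1.0]]
neighbor_values = [[1.0, 0.0], [0.0, 1.0]]
global_keys = [[2.0, 0.0]]    # stands in for a key projected from pooled features
global_values = [[5.0, 5.0]]
out, weights = attend_with_globals(query, neighbor_keys, neighbor_values, global_keys, global_values)
print(len(weights))  # 3: two neighbors plus one global node

# even a node with no selected neighbors still receives global information
isolated_out, _ = attend_with_globals(query, [], [], global_keys, global_values)
print(isolated_out)  # [5.0, 5.0]
```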

Todo:

  • allow global nodes to attend to all other nodes, to give the network a global conduit for information (similar to BigBird, ETC, Longformer, etc.)

Autoregressive

You can use SE3 Transformers autoregressively with just one extra flag, causal

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10,
    causal = True          # set this to True
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)
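Autoregressive attention is usually implemented by letting node i see only nodes j <= i in sequence order. The sketch below combines such a causal mask with a neighbor mask. It assumes causal = True orders nodes by their index along the sequence dimension and that a node may attend to itself; the library's exact handling may differ.

```python
def causal_neighbor_mask(neighbor_mask):
    # node i may only attend to nodes j <= i that the neighbor mask also allows
    n = len(neighbor_mask)
    return [[neighbor_mask[i][j] and j <= i for j in range(n)] for i in range(n)]

n = 4
full = [[True] * n for _ in range(n)]
for row in causal_neighbor_mask(full):
    print([int(x) for x in row])
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```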

Experimental Features

Non-pairwise convolved keys

I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task. This leads to 25% memory savings. You can try this feature by setting linear_proj_keys = True

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    linear_proj_keys = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Shared key / values across all heads

There is a relatively unknown technique for transformers where one key / value head is shared across all the query heads. In my experience in NLP, this usually leads to worse performance, but if you really need to trade off memory for more depth or a higher number of degrees, this may be a good option.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    one_headed_key_values = True  # one head of key / values shared across all heads of the queries
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)
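Sharing one key / value head across all query heads is the multi-query attention idea from the Shazeer (2019) paper cited below: each query head computes its own attention weights, but over the same keys and values, so those are computed and stored once instead of once per head. The scalar-only sketch below illustrates the mechanism; it ignores the equivariant, per-degree structure of the actual layers, and the function names are hypothetical.

```python
import math

def softmax(scores):
    # numerically stable softmax over a list of floats
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def multi_query_attention(head_queries, keys, values):
    # every query head attends over the SAME single set of keys / values
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs = []
    for q in head_queries:
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        outputs.append([sum(w * v[d] for w, v in zip(weights, values))
                        for d in range(len(values[0]))])
    return outputs

keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # one shared key head (3 neighbors)
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # one shared value head
head_queries = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 0.0]]  # 4 query heads
out = multi_query_attention(head_queries, keys, values)
print(len(out), len(out[0]))  # 4 2
```

The heads still differ through their queries; what is lost is the ability of each head to learn its own key / value projections, which is the usual explanation for the quality drop mentioned above.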

Tied key / values

You can also tie the keys and values (have them be the same), to roughly halve the memory used by the keys / values

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    tie_key_values = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)
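Where the "half" comes from can be seen with a back-of-the-envelope count. Assuming keys and values are stored per edge (node times neighbor), tying them means one tensor serves as both. The sizes below (8 heads, dim_head = 64) are hypothetical and not taken from the example above, and the count ignores all other activations, so the saving for the whole model will be smaller than 50%.

```python
def kv_elements(heads, nodes, neighbors, dim_head, tie_key_values=False):
    # rough count of scalars holding per-edge keys and values for one degree in one layer;
    # with tied key / values a single tensor serves as both
    num_tensors = 1 if tie_key_values else 2
    return num_tensors * heads * nodes * neighbors * dim_head

untied = kv_elements(heads=8, nodes=32, neighbors=8, dim_head=64)
tied = kv_elements(heads=8, nodes=32, neighbors=8, dim_head=64, tie_key_values=True)
print(untied, tied, tied / untied)  # 262144 131072 0.5
```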

Using EGNN

This is an experimental version of EGNN that works for higher types, and for dimensionality greater than 1 (for the coordinates). The class name is still SE3Transformer since it reuses some preexisting logic, so just ignore that for now until I clean it up later.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    num_edge_tokens = 4,
    edge_dim = 4,
    num_degrees = 4,       # number of higher order types - will use basis on a TCN to project to these dimensions
    use_egnn = True,       # set this to true to use EGNN instead of equivariant attention layers
    egnn_hidden_dim = 64,  # egnn hidden dimension
    depth = 4,             # depth of EGNN
    reduce_dim_out = True  # will project the dimension of the higher types to 1
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement
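For intuition on why the refinement above is rotation equivariant, here is the basic coordinate update from the EGNN paper cited below, reduced to coordinates only: each point moves along the differences to the other points, weighted by a function of their squared distances. Because the weights depend only on distances, rotating the input and then updating gives the same result as updating and then rotating. This is a type-1-only sketch; weight_fn is a hypothetical stand-in for the learned MLP, node features are omitted, and it does not reproduce the library's higher-type extension.

```python
import math

def egnn_coordinate_update(coors, weight_fn=lambda d2: 1.0 / (1.0 + d2)):
    # x_i' = x_i + mean_j (x_i - x_j) * phi(||x_i - x_j||^2); phi only sees distances
    n = len(coors)
    updated = []
    for i in range(n):
        delta = [0.0, 0.0, 0.0]
        for j in range(n):
            if i == j:
                continue
            diff = [a - b for a, b in zip(coors[i], coors[j])]
            w = weight_fn(sum(c * c for c in diff))
            delta = [d + w * c for d, c in zip(delta, diff)]
        updated.append([x + d / (n - 1) for x, d in zip(coors[i], delta)])
    return updated

def rotate_z(points, theta):
    # rotate every point about the z axis by theta radians
    c, s = math.cos(theta), math.sin(theta)
    return [[c * x - s * y, s * x + c * y, z] for x, y, z in points]

coors = [[0.0, 0.0, 0.0], [1.0, 0.5, -0.2], [-0.3, 2.0, 1.0]]
a = egnn_coordinate_update(rotate_z(coors, 0.7))   # rotate, then update
b = rotate_z(egnn_coordinate_update(coors), 0.7)   # update, then rotate
print(max(abs(p - q) for u, v in zip(a, b) for p, q in zip(u, v)) < 1e-9)  # True
```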

If you would like to specify individual dimensions for each of the higher types, pass in hidden_fiber_dict, a dictionary in the format {<degree>:<dim>}, instead of num_degrees

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
    use_egnn = True,
    depth = 4,
    egnn_hidden_dim = 64,
    egnn_weights_clamp_value = 2, 
    reduce_dim_out = True
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement

Scaling (wip)

This section will list ongoing efforts to make SE3 Transformer scale a little better.

Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 20,
    dim = 32,
    dim_head = 32,
    heads = 4,
    depth = 12,             # 12 layers
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True,
    reversible = True       # set reversible to True
).cuda()

atoms = torch.randint(0, 4, (2, 32)).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

pred = model(atoms, coors, mask = mask, return_type = 0)

loss = pred.sum()
loss.backward()
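Reversible networks (Gomez et al., cited below) save memory because each block's inputs can be recomputed exactly from its outputs, so intermediate activations need not be stored for the backward pass. The sketch below shows the additive coupling on two halves of a feature vector, with f and g as hypothetical stand-ins for the block's sublayers; which sublayers the library actually uses as F and G is an assumption here.

```python
def reversible_forward(x1, x2, f, g):
    # RevNet-style coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    # recover the inputs from the outputs alone: x2 = y2 - G(y1), x1 = y1 - F(x2)
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2

f = lambda v: [2.0 * t + 1.0 for t in v]  # stand-in for one sublayer (e.g. attention)
g = lambda v: [t * t for t in v]          # stand-in for another sublayer (e.g. feedforward)
x1, x2 = [0.5, -1.0], [2.0, 3.0]
y1, y2 = reversible_forward(x1, x2, f, g)
print(reversible_inverse(y1, y2, f, g))  # ([0.5, -1.0], [2.0, 3.0])
```

Neither f nor g has to be invertible; only the coupling structure is, which is what lets arbitrary sublayers be stacked deeper without storing their activations.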

Examples

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

Caching

By default, the basis vectors are cached. If you ever need to clear the cache, set the environment variable CLEAR_CACHE to any value when launching the script

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention

You can also designate your own directory for the caches, for example if the default directory has permission issues

$ CACHE_PATH=./path/to/my/cache python train.py
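The two environment variables behave like ordinary configuration overrides. The sketch below shows how a script could interpret them according to the description above. It is a hypothetical illustration, not the library's code: resolve_cache_dir and should_clear_cache are made-up names, and treating an empty CLEAR_CACHE as unset is an assumption.

```python
import os

DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache.equivariant_attention")

def resolve_cache_dir(env=None):
    # hypothetical helper: CACHE_PATH, if set, overrides the default location
    env = os.environ if env is None else env
    return env.get("CACHE_PATH", DEFAULT_CACHE_DIR)

def should_clear_cache(env=None):
    # hypothetical helper: any non-empty CLEAR_CACHE value requests clearing
    env = os.environ if env is None else env
    return bool(env.get("CLEAR_CACHE"))

print(resolve_cache_dir({"CACHE_PATH": "./path/to/my/cache"}))  # ./path/to/my/cache
print(should_clear_cache({"CLEAR_CACHE": "1"}))  # True
```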

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{gomez2017reversible,
    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations},
    author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
    year      = {2017},
    eprint    = {1707.04585},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{shazeer2019fast,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    year    = {2019},
    eprint  = {1911.02150},
    archivePrefix = {arXiv},
    primaryClass = {cs.NE}
}
