Language: Python · License: MIT License

NÜWA - Pytorch


Implementation of NÜWA, a state-of-the-art attention network for text-to-video synthesis, in Pytorch. It also contains an extension to joint video and audio generation, using a dual decoder approach.


Status

  • March 2022 - seeing signs of life with a difficult version of Moving MNIST

  • April 2022 - It seems as though a diffusion-based method has taken the throne for SOTA. However, I will continue on with NUWA, extending it to use multi-headed codes + a hierarchical causal transformer. I think that direction is untapped for improving on this line of work.

Install

$ pip install nuwa-pytorch

Usage

First, train the VAE

import torch
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 512,
    channels = 3,               # default is 3; set it to the number of segmentation classes when training a VAE on segmentation masks (sketches)
    image_size = 256,           # image size
    num_layers = 4,             # number of downsampling layers
    num_resnet_blocks = 2,      # number of resnet blocks
    vq_codebook_size = 8192,    # codebook size
    vq_decay = 0.8              # codebook exponential decay
)

imgs = torch.randn(10, 3, 256, 256)

# alternate learning for autoencoder ...

loss = vae(imgs, return_loss = True)
loss.backward()

# and the discriminator ...

discr_loss = vae(imgs, return_discr_loss = True)
discr_loss.backward()

# do above for many steps

# return reconstructed images and make sure they look ok

recon_imgs = vae(imgs)
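The vq_codebook_size and vq_decay arguments above configure the vector quantizer inside the VAE. To make "codebook exponential decay" concrete, here is a minimal pure-Python sketch of nearest-code lookup followed by an exponential-moving-average (EMA) codebook update, assuming the common VQ-VAE EMA scheme of decayed per-code counts and sums. It is an illustration only, not the vector-quantize-pytorch implementation, and the helper names are hypothetical.

```python
def nearest_code(x, codebook):
    """Return the index of the codebook vector closest to x (squared L2 distance)."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(x, codebook[i]))


def ema_update(codebook, counts, sums, batch, decay=0.8, eps=1e-5):
    """One EMA codebook step (hypothetical helper).

    counts[i] and sums[i] are running, decayed statistics of how many vectors were
    assigned to code i and their element-wise sum; they are updated in place.
    The returned codebook is sums / counts for every code.
    """
    dim = len(codebook[0])
    num_codes = len(codebook)
    batch_counts = [0.0] * num_codes
    batch_sums = [[0.0] * dim for _ in range(num_codes)]
    # assign every input vector to its nearest code
    for x in batch:
        i = nearest_code(x, codebook)
        batch_counts[i] += 1
        for d in range(dim):
            batch_sums[i][d] += x[d]
    # blend this batch's statistics into the running ones with the given decay
    new_codebook = []
    for i in range(num_codes):
        counts[i] = decay * counts[i] + (1 - decay) * batch_counts[i]
        sums[i] = [decay * s + (1 - decay) * b for s, b in zip(sums[i], batch_sums[i])]
        new_codebook.append([s / (counts[i] + eps) for s in sums[i]])
    return new_codebook
```

A higher decay moves the codebook more slowly toward each batch's assignments. Initializing counts to 1 and sums to a copy of the codebook keeps a code that receives no assignments roughly where it was, instead of pulling it to zero.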

Then, with your learned VAE

import torch
from nuwa_pytorch import NUWA, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_resnet_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 12,                    # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256)
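With image_size = 256 and num_layers = 4, and assuming each VAE layer halves the resolution, every frame becomes a 16 x 16 grid of codebook indices, so 10 frames give 2,560 video tokens for the decoder. The sparse_3dna_kernel_size and sparse_3dna_dilation arguments restrict which of those tokens each query attends to. The sketch below is a pure-Python illustration of one reading of that idea, not the library's kernel: a (frames, height, width) window centred on the query, a dilation cycled through the tuple one decoder layer at a time and applied to all three axes (an assumption), and causal masking that keeps only positions at or before the query in raster order. All function names are hypothetical.

```python
from itertools import product


def token_grid_shape(frames, image_size, num_layers):
    """Video token grid, assuming each VAE layer downsamples by a factor of 2."""
    side = image_size // (2 ** num_layers)
    return frames, side, side


def dilation_for_layer(layer_index, dilations):
    """Cycle through the dilation tuple, one entry per decoder layer."""
    return dilations[layer_index % len(dilations)]


def causal_nearby_keys(query, grid, kernel_size, dilation):
    """Positions a query (f, h, w) may attend to under causal 3D nearby attention.

    kernel_size is (frames, height, width); each extent should be odd so the
    window is centred on the query. Out-of-grid positions are dropped, and so
    are positions that come after the query in raster (frame, row, column) order.
    """
    frames, height, width = grid
    radii = [k // 2 for k in kernel_size]
    keys = []
    for df, dh, dw in product(*(range(-r, r + 1) for r in radii)):
        f = query[0] + df * dilation
        h = query[1] + dh * dilation
        w = query[2] + dw * dilation
        if not (0 <= f < frames and 0 <= h < height and 0 <= w < width):
            continue
        # lexicographic tuple comparison matches raster order inside the grid
        if (f, h, w) <= tuple(query):
            keys.append((f, h, w))
    return keys
```

For a query away from the borders, kernel (5, 3, 3) covers 45 positions and causality keeps 23 of them (the 22 earlier positions plus the query itself), independent of the dilation. A real implementation would compute this with masked tensor ops rather than Python loops.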

Conditioning on Sketches

In the paper, they also present a way to condition the video generation on segmentation mask(s). You can easily do this as well, given that you train a VQGanVAE on the sketches beforehand.
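The sketch VAE in the example below uses channels = 5 for a sketch with 5 classes, which suggests each sketch frame is fed as one channel per segmentation class. Assuming your masks start out as a 2D grid of integer class ids, a minimal pure-Python sketch of turning one frame into that (classes, height, width) one-hot layout looks like this. The helper name is hypothetical; with tensors, torch.nn.functional.one_hot followed by a permute does the same job.

```python
def one_hot_mask(class_ids, num_classes):
    """Convert an H x W grid of class ids into a num_classes x H x W one-hot grid."""
    height, width = len(class_ids), len(class_ids[0])
    # independent rows for every channel (no shared list references)
    out = [[[0.0] * width for _ in range(height)] for _ in range(num_classes)]
    for h in range(height):
        for w in range(width):
            c = class_ids[h][w]
            if not 0 <= c < num_classes:
                raise ValueError(f"class id {c} out of range for {num_classes} classes")
            out[c][h][w] = 1.0
    return out
```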

Then, use NUWASketch instead of NUWA; it accepts the sketch VAE through its sketch_vae argument

ex.

import torch
from nuwa_pytorch import NUWASketch, VQGanVAE

# autoencoder, one for main video, the other for the sketch

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_resnet_blocks = 2,
    vq_codebook_size = 8192
)

sketch_vae = VQGanVAE(
    dim = 512,
    channels = 5,                # say the sketch has 5 classes
    num_layers = 4,
    image_size = 256,
    num_resnet_blocks = 2,
    vq_codebook_size = 8192
)

# NUWA transformer for conditioning with sketches

nuwa = NUWASketch(
    vae = vae,
    sketch_vae = sketch_vae,
    dim = 512,                              # model dimensions
    sketch_enc_depth = 12,                  # sketch encoder depth
    sketch_enc_heads = 8,                   # number of attention heads for sketch encoder
    sketch_max_video_frames = 3,            # max number of frames for sketches
    sketch_enc_use_sparse_3dna = True,      # whether to use 3d-nearby attention (or full attention if False) for sketch encoding transformer
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 64,                         # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    dec_reversible = True,                  # reversible networks - from reformer, decoupling memory usage from depth
    enc_reversible = True,                  # reversible encoders, if you need it
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    cross_2dna_kernel_size = 5,             # 2d kernel size of spatial grouping of attention from video frames to sketches
    cross_2dna_dilation = 1,                # 2d dilation of spatial attention from video frames to sketches
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

sketch = torch.randn(2, 2, 5, 256, 256).cuda() # (batch, frames, segmentation classes, height, width)
sketch_mask = torch.ones(2, 2).bool().cuda()   # (batch, frames) [Optional]
video = torch.randn(2, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    sketch = sketch,
    sketch_mask = sketch_mask,
    video = video,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from sketch(es)

video = nuwa.generate(sketch = sketch, num_frames = 5) # (2, 5, 3, 256, 256), batch size follows the sketch
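The optional sketch_mask has shape (batch, frames) and, as read here, marks which sketch frames are real. If your examples have different numbers of sketch frames, one way to batch them is to pad every example to the longest one and mark the padding as False. A minimal pure-Python sketch, assuming padding frames are appended at the end; the helper name and the pad_frame placeholder are hypothetical.

```python
def pad_sketch_frames(batch, pad_frame):
    """Pad per-example lists of sketch frames to a common length.

    Returns (padded, mask): padded[b] has max_len frames, and mask[b][t] is True
    for real frames and False for the padding frames appended at the end.
    Note that pad_frame is shared by reference, not copied.
    """
    max_len = max(len(frames) for frames in batch)
    padded, mask = [], []
    for frames in batch:
        num_pad = max_len - len(frames)
        padded.append(list(frames) + [pad_frame] * num_pad)
        mask.append([True] * len(frames) + [False] * num_pad)
    return padded, mask
```

The resulting mask has the same (batch, frames) layout as sketch_mask in the example above; keep the padded length within sketch_max_video_frames.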

Text to Video and Audio

This repository also offers a variant of NUWA that can produce both video and audio. For now, the audio must be encoded into discrete tokens manually.

import torch
from nuwa_pytorch import NUWAVideoAudio, VQGanVAE

# autoencoder

vae = VQGanVAE(
    dim = 64,
    num_layers = 4,
    image_size = 256,
    num_resnet_blocks = 2,
    vq_codebook_size = 100
)

# NUWA transformer

nuwa = NUWAVideoAudio(
    vae = vae,
    dim = 512,
    num_audio_tokens = 2048,                # codebook size for audio tokens
    num_audio_tokens_per_video_frame = 32,  # number of audio tokens per video frame
    cross_modality_attn_every = 3,          # cross modality attention every N layers
    text_num_tokens = 20000,                # number of text tokens
    text_enc_depth = 1,                     # text encoder depth
    text_enc_heads = 8,                     # number of attention heads for encoder
    text_max_seq_len = 256,                 # max sequence length of text conditioning tokens (keep at 256 as in paper, or shorter, if your text is not that long)
    max_video_frames = 10,                  # number of video frames
    image_size = 256,                       # size of each frame of video
    dec_depth = 4,                          # video decoder depth
    dec_heads = 8,                          # number of attention heads in decoder
    enc_reversible = True,                  # reversible encoders, if you need it
    dec_reversible = True,                  # quad-branched reversible network, making memory usage of the twin video / audio decoder independent of network depth. recommended to be turned on unless you have a ton of memory at your disposal
    attn_dropout = 0.05,                    # dropout for attention
    ff_dropout = 0.05,                      # dropout for feedforward
    sparse_3dna_kernel_size = (5, 3, 3),    # kernel size of the sparse 3dna attention. can be a single value for frame, height, width, or different values (to simulate axial attention, etc)
    sparse_3dna_dilation = (1, 2, 4),       # cycle dilation of 3d conv attention in decoder, for more range
    shift_video_tokens = True               # cheap relative positions for sparse 3dna transformer, by shifting along spatial dimensions by one
).cuda()

# data

text = torch.randint(0, 20000, (1, 256)).cuda()
audio = torch.randint(0, 2048, (1, 32 * 10)).cuda() # (batch, audio tokens per frame * max video frames)
video = torch.randn(1, 10, 3, 256, 256).cuda() # (batch, frames, channels, height, width)

loss = nuwa(
    text = text,
    video = video,
    audio = audio,
    return_loss = True  # set this to True, only for training, to return cross entropy loss
)

loss.backward()

# do above with as much data as possible

# then you can generate a video from text

video, audio = nuwa.generate(text = text, num_frames = 5) # (1, 5, 3, 256, 256), (1, 160) since 32 * 5 == 160
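The audio tokens form one flat sequence of num_audio_tokens_per_video_frame * max_video_frames ids (32 x 10 = 320 above), and generating 5 frames returns 32 x 5 = 160. Assuming that flat sequence is laid out frame by frame, which the shape comments suggest but do not state outright, a small hypothetical helper can split it into per-frame chunks, for example to line each chunk up with its video frame.

```python
def audio_tokens_per_frame(audio_tokens, tokens_per_frame):
    """Split a flat audio-token sequence into consecutive per-frame chunks."""
    if len(audio_tokens) % tokens_per_frame != 0:
        raise ValueError("audio length must be a multiple of tokens_per_frame")
    return [
        audio_tokens[i:i + tokens_per_frame]
        for i in range(0, len(audio_tokens), tokens_per_frame)
    ]
```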

Trainers

This library offers some utilities to make training easier. For starters, the VQGanVAETrainer class takes care of training the VQGanVAE: wrap the model, and pass in the image folder path along with the training hyperparameters.

import torch
from nuwa_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 5,
    vq_codebook_size = 1024,
    vq_use_cosine_sim = True,
    vq_codebook_dim = 32,
    vq_orthogonal_reg_weight = 10,
    vq_orthogonal_reg_max_codes = 128,
).cuda()

trainer = VQGanVAETrainer(
    vae,                           # VAE defined above
    folder = '/path/to/images',    # path to images
    lr = 3e-4,                     # learning rate
    num_train_steps = 100000,      # number of training steps
    batch_size = 8,                # batch size
    grad_accum_every = 4           # gradient accumulation (effective batch size is (batch_size x grad_accum_every))
)

trainer.train()

# results and model checkpoints will be saved periodically to ./results
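With batch_size = 8 and grad_accum_every = 4, gradients from four micro-batches are combined before each optimizer step, for an effective batch of 32. The pure-Python sketch below uses a toy one-parameter least-squares model, not anything from this library, to check the usual bookkeeping: scale each micro-batch's mean-loss gradient by 1 / grad_accum_every and sum, and you recover the full-batch gradient (exactly, when all micro-batches have the same size).

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, batch_size, grad_accum_every):
    """Sum of scaled micro-batch gradients, as in gradient accumulation."""
    total = 0.0
    for step in range(grad_accum_every):
        start = step * batch_size
        end = start + batch_size
        # each micro-batch contributes 1 / grad_accum_every of its mean-loss gradient
        total += grad_mse(w, xs[start:end], ys[start:end]) / grad_accum_every
    return total
```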

To train NUWA, first organize a folder of .gif files, each with a corresponding .txt file containing its caption. It should be organized as follows.

ex.

📂video-and-text-data
 ┣ 📜cat.gif
 ┣ 📜cat.txt
 ┣ 📜dog.gif
 ┣ 📜dog.txt
 ┣ 📜turtle.gif
 ┗ 📜turtle.txt
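Every .gif needs a .txt with the same stem, so a quick check before training can save a failed run. The helper below is a hypothetical sketch using only the standard library; it is not how GifVideoDataset discovers files, which may differ.

```python
from pathlib import Path


def pair_gifs_with_captions(folder):
    """Return sorted (gif_path, caption) pairs; raise if any .gif lacks a .txt."""
    folder = Path(folder)
    pairs, missing = [], []
    for gif in sorted(folder.glob("*.gif")):
        txt = gif.with_suffix(".txt")
        if not txt.exists():
            missing.append(gif.name)
            continue
        pairs.append((gif, txt.read_text(encoding="utf-8").strip()))
    if missing:
        raise FileNotFoundError(f"missing captions for: {', '.join(missing)}")
    return pairs
```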

Then load your previously trained VQGanVAE and train NUWA with the GifVideoDataset and NUWATrainer classes.

import torch
from nuwa_pytorch import NUWA, VQGanVAE
from nuwa_pytorch.train_nuwa import GifVideoDataset, NUWATrainer

# dataset

ds = GifVideoDataset(
    folder = './path/to/videos/',
    channels = 1
)

# autoencoder

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    channels = 1,               # must match the dataset's channels (and the checkpoint loaded below)
    num_layers = 5,
    num_resnet_blocks = 2,
    vq_codebook_size = 512,
    attn_dropout = 0.1
)

vae.load_state_dict(torch.load('./path/to/trained/vae.pt'))

# NUWA transformer

nuwa = NUWA(
    vae = vae,
    dim = 512,
    text_enc_depth = 6,
    text_max_seq_len = 256,
    max_video_frames = 10,
    dec_depth = 12,
    dec_reversible = True,
    enc_reversible = True,
    attn_dropout = 0.05,
    ff_dropout = 0.05,
    sparse_3dna_kernel_size = (5, 3, 3),
    sparse_3dna_dilation = (1, 2, 4),
    shift_video_tokens = True
).cuda()

# trainer

trainer = NUWATrainer(
    nuwa = nuwa,                 # NUWA transformer
    dataset = ds,                # video dataset defined above
    num_train_steps = 1000000,   # number of training steps
    lr = 3e-4,                   # learning rate
    wd = 0.01,                   # weight decay
    batch_size = 8,              # batch size
    grad_accum_every = 4,        # gradient accumulation
    max_grad_norm = 0.5,         # gradient clipping
    num_sampled_frames = 10,     # number of frames to sample
    results_folder = './results' # folder to store checkpoints and samples
)

trainer.train()
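Several configurations above set enc_reversible / dec_reversible, described as reversible networks from Reformer that decouple memory usage from depth. The key property is that a reversible residual block's inputs can be recomputed exactly from its outputs, so per-layer activations need not be stored for backpropagation. Below is a minimal pure-Python sketch of the RevNet-style coupling Reformer uses, with arbitrary scalar functions standing in for attention (f) and feedforward (g); it illustrates the idea and is not this library's implementation.

```python
def rev_forward(x1, x2, f, g):
    """Reversible residual coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def rev_inverse(y1, y2, f, g):
    """Recover the block inputs from its outputs, without stored activations."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


def run_stack(x1, x2, blocks):
    """Apply a stack of reversible blocks; only the final outputs need keeping."""
    for f, g in blocks:
        x1, x2 = rev_forward(x1, x2, f, g)
    return x1, x2


def invert_stack(y1, y2, blocks):
    """Walk the stack backwards, reconstructing each block's inputs in turn."""
    for f, g in reversed(blocks):
        y1, y2 = rev_inverse(y1, y2, f, g)
    return y1, y2
```

During backpropagation, a reversible network reconstructs each layer's inputs on the fly like invert_stack does, trading extra compute for memory that no longer grows with the number of layers.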

VQ improvements

This library depends on the vector-quantize-pytorch library, which comes with a number of improvements (improved vqgan, orthogonal codebook regularization, etc). To use any of these improvements, configure the vector quantizer's keyword arguments by prefixing them with vq_ when initializing VQGanVAE.

ex. cosine sim proposed in improved vqgan

from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 64,
    image_size = 256,
    num_layers = 4,
    vq_use_cosine_sim = True
    # VectorQuantize will be initialized with use_cosine_sim = True
    # https://github.com/lucidrains/vector-quantize-pytorch#cosine-similarity
).cuda()
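The vq_ convention amounts to splitting keyword arguments by prefix: anything starting with vq_ has the prefix stripped and is forwarded to the vector quantizer, so vq_use_cosine_sim = True arrives as use_cosine_sim = True. A minimal sketch of that splitting with a hypothetical helper name; the library's own helper may differ in details.

```python
def split_prefixed_kwargs(prefix, kwargs):
    """Split kwargs into (prefixed ones with the prefix removed, everything else)."""
    prefixed, rest = {}, {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            prefixed[key[len(prefix):]] = value
        else:
            rest[key] = value
    return prefixed, rest
```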

Todo

  • complete 3dna causal attention in decoder
  • write up easy generation functions
  • make sure GAN portion of VQGan is correct, reread paper
  • make sure adaptive weight in vqgan is correctly built
  • offer new vqvae improvements (orthogonal reg and smaller codebook dimensions)
  • batch video tokens -> vae during video generation, to prevent oom
  • query chunking in 3dna attention, to put a cap on peak memory
  • flesh out VAE resnet blocks, offer some choices
  • add all stability tricks from cogview paper by default
  • make VQGan able to accept custom VGG for LPIPS loss (audio)
  • add feedforward chunking
  • add shift token in decoder for cheap powerful RPE
  • add reversible networks, to save memory as depth grows
  • support kernel sizes different along each dimension for sparse 3dna
  • add some autotrainer that takes care of the alternating updates of discriminator and VQVAE generator
  • segmentation mask encoder, make sure embeddings can undergo 3dna attention with decoder during cross attention
  • finish 2d-nearby cross attention for sketches
  • able to add convnext blocks to other layers in vqgan vae
  • offer vqvae training script
  • handle variable-length sketches, accept a mask on the sketch frames dimension
  • take care of audio transformer and cross modality attention
  • add audio transformer, and build audio / video nearby cross attention
  • make dual decoder reversible
  • rotary embeddings for encoder
  • add cycle dilation to audio
  • omit vgg from VAE state dict
  • add cosine sim attention from swinv2 as an option
  • add axial positional embedding to audio
  • Triton kernel for 3dna attention
  • offer a colab with moving mnist example, conditioned on present digits
  • build NUWA controller class that can accept text or sketch
  • key masking for 3dna attention - for variable sketch length masking
  • figure out spec vqgan and fit it into the framework, take care of audio encoding / decoding automatically
  • turn into CLI tool, like stylegan2-pytorch
  • look into integrating https://github.com/lucidrains/RQ-Transformer for both video and audio
  • inference caching

Citations

@misc{wu2021nuwa,
    title   = {N\"UWA: Visual Synthesis Pre-training for Neural visUal World creAtion}, 
    author  = {Chenfei Wu and Jian Liang and Lei Ji and Fan Yang and Yuejian Fang and Daxin Jiang and Nan Duan},
    year    = {2021},
    eprint  = {2111.12417},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{esser2021taming,
    title   = {Taming Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Robin Rombach and Björn Ommer},
    year    = {2021},
    eprint  = {2012.09841},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{iashin2021taming,
    title   = {Taming Visually Guided Sound Generation},
    author  = {Vladimir Iashin and Esa Rahtu},
    year    = {2021},
    eprint  = {2110.08791},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{kitaev2020reformer,
    title   = {Reformer: The Efficient Transformer},
    author  = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
    year    = {2020},
    eprint  = {2001.04451},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{ho2021classifierfree,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho and Tim Salimans},
    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
    year    = {2021},
    url     = {https://openreview.net/forum?id=qw8AKxfYbI}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
}

Attention is the rarest and purest form of generosity. - Simone Weil
