  • Stars: 3,433
  • Rank: 12,422 (Top 0.3%)
  • Language: Python
  • License: MIT License
  • Created over 4 years ago
  • Updated about 1 year ago

Repository Details

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement

Simple StyleGan2 for Pytorch

Simple Pytorch implementation of Stylegan2 based on https://arxiv.org/abs/1912.04958 that can be completely trained from the command-line, no coding needed.

Below are some flowers that do not exist.

Neither do these hands

Nor these cities

Nor these celebrities (trained by @yoniker)

Install

You will need a machine with a GPU and CUDA installed. Then pip install the package like this

$ pip install stylegan2_pytorch

If you are using a Windows machine, the following commands reportedly work.

$ conda install pytorch torchvision -c pytorch
$ pip install stylegan2_pytorch

Use

$ stylegan2_pytorch --data /path/to/images

That's it. Sample images will be saved to results/default and models will be saved periodically to models/default.

Advanced Use

You can specify the name of your project with

$ stylegan2_pytorch --data /path/to/images --name my-project-name

You can also specify the location where intermediate results and model checkpoints should be stored with

$ stylegan2_pytorch --data /path/to/images --name my-project-name --results_dir /path/to/results/dir --models_dir /path/to/models/dir

You can increase the network capacity (which defaults to 16) to improve generation results, at the cost of more memory.

$ stylegan2_pytorch --data /path/to/images --network-capacity 256

By default, if training gets cut off, it will automatically resume from the last checkpointed file. If you want to restart with new settings, just add the --new flag

$ stylegan2_pytorch --new --data /path/to/images --name my-project-name --image-size 512 --batch-size 1 --gradient-accumulate-every 16 --network-capacity 10

Once you have finished training, you can generate images from your latest checkpoint like so.

$ stylegan2_pytorch --generate

To generate a video of an interpolation between two random points in latent space

$ stylegan2_pytorch --generate-interpolation --interpolation-num-steps 100

To save each individual frame of the interpolation

$ stylegan2_pytorch --generate-interpolation --save-frames
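If you are curious what "interpolating between two points in latent space" amounts to, here is a minimal sketch in plain Python. It uses straight linear interpolation for simplicity; the command-line tool may well use a different scheme internally, and the helper names (lerp, interpolation_path) are made up for this illustration.

```python
import random

def lerp(a, b, t):
    """Linearly interpolate between two equal-length vectors at fraction t in [0, 1]."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]

def interpolation_path(start, end, num_steps):
    """Return num_steps latent vectors going from start to end, endpoints included."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    return [lerp(start, end, i / (num_steps - 1)) for i in range(num_steps)]

rng = random.Random(0)
latent_dim = 512  # assumed latent size, chosen to match the 512-dim noise in the Coding section

# two random points in latent space
z0 = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
z1 = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]

# one latent vector per video frame
frames = interpolation_path(z0, z1, num_steps=100)
print(len(frames), len(frames[0]))
```

Each vector in frames would then be fed through the generator to produce one frame of the video.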

If a previous checkpoint contained a better generator (which often happens, as generators start degrading towards the end of training), you can load from a previous checkpoint with another flag

$ stylegan2_pytorch --generate --load-from {checkpoint number}

A technique used in both StyleGAN and BigGAN is truncating the latent values so that they fall close to the mean. The smaller the truncation value, the better the samples will appear, at the cost of sample variety. You can control this with the --trunc-psi flag, where values typically fall between 0.5 and 1. It is set to 0.75 by default.

$ stylegan2_pytorch --generate --trunc-psi 0.5
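To make the trade-off concrete, here is a small sketch of the usual truncation formula, w_mean + psi * (w - w_mean), on toy 3-dimensional vectors. It assumes truncation is applied to style vectors relative to an average style vector, as in the StyleGAN papers; the truncate helper and the toy numbers are illustrative only, not this package's internals.

```python
def truncate(w, w_mean, trunc_psi=0.75):
    """Pull a style vector toward the mean style: w_mean + psi * (w - w_mean)."""
    return [m + trunc_psi * (x - m) for x, m in zip(w, w_mean)]

w_mean = [0.0, 1.0, -1.0]   # toy "average" style vector
w      = [2.0, 3.0, -5.0]   # toy sampled style vector

print(truncate(w, w_mean, trunc_psi=1.0))  # psi = 1 leaves w unchanged
print(truncate(w, w_mean, trunc_psi=0.5))  # psi = 0.5 moves w halfway to the mean
print(truncate(w, w_mean, trunc_psi=0.0))  # psi = 0 collapses w onto the mean
```

Lower psi means every sample sits closer to the "average" image, which is why quality tends to go up while variety goes down.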

Multi-GPU training

If you have one machine with multiple GPUs, the repository offers a way to utilize all of them for training. With multiple GPUs, each batch will be divided evenly amongst the GPUs available. For example, for 2 GPUs, with a batch size of 32, each GPU will see 16 samples.

You simply have to add the --multi-gpus flag; everything else is taken care of. If you would like to restrict training to specific GPUs, you can use the CUDA_VISIBLE_DEVICES environment variable to control which devices can be used (e.g. with CUDA_VISIBLE_DEVICES=0,2,3, only devices 0, 2, and 3 are available).

$ stylegan2_pytorch --data ./data --multi-gpus --batch-size 32 --gradient-accumulate-every 1
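The arithmetic behind the batch split can be sketched as below. Both helpers are hypothetical and only illustrate the description above; in particular, rejecting batch sizes that do not divide evenly is a choice made for this sketch, not documented behavior of the tool, and the parser only handles integer device ids.

```python
def visible_gpu_ids(cuda_visible_devices):
    """Parse a CUDA_VISIBLE_DEVICES-style string such as "0,2,3" into a list of integer ids."""
    return [int(tok) for tok in cuda_visible_devices.split(",") if tok.strip()]

def per_gpu_batch(batch_size, num_gpus):
    """Number of samples each GPU sees when a batch is divided evenly."""
    if num_gpus < 1:
        raise ValueError("need at least one GPU")
    if batch_size % num_gpus != 0:
        # Assumption of this sketch: refuse sizes that cannot be split evenly.
        raise ValueError("batch size must be divisible by the number of GPUs")
    return batch_size // num_gpus

# the example from the text: 2 GPUs, batch size 32
print(per_gpu_batch(32, 2))

# restricting to devices 0, 2 and 3 leaves three GPUs to share the batch
gpus = visible_gpu_ids("0,2,3")
print(gpus, per_gpu_batch(30, len(gpus)))
```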

Low amounts of Training Data

In the past, GANs needed a lot of data to learn how to generate well. The faces model took 70k high quality images from Flickr, as an example.

However, in the month of May 2020, researchers all across the world independently converged on a simple technique to reduce that number to as low as 1-2k. That simple idea was to differentiably augment all images, generated or real, going into the discriminator during training.

If you augment at a low enough probability, the augmentations will not 'leak' into the generations.

In the setting of low data, you can use the feature with a simple flag.

# find a suitable probability between 0. -> 0.7 at maximum
$ stylegan2_pytorch --data ./data --aug-prob 0.25
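Conceptually, the probability acts as a gate in front of the discriminator: with probability aug-prob a batch is augmented, otherwise it passes through untouched, and real and generated batches are treated the same way. The sketch below illustrates that gate with plain floats standing in for images and a toy shift standing in for a real (differentiable) augmentation. All names are hypothetical, and whether the tool rolls the dice per batch or per image is an assumption here.

```python
import random

def maybe_augment(batch, augment, prob, rng):
    """Apply `augment` to every item in the batch with probability `prob`, else return it unchanged."""
    if rng.random() < prob:
        return [augment(img) for img in batch]
    return batch

def discriminator_inputs(real, fake, augment, prob, rng):
    """Both real and generated batches go through the same augmentation gate."""
    return (maybe_augment(real, augment, prob, rng),
            maybe_augment(fake, augment, prob, rng))

def shift(img):
    """Toy stand-in for an image augmentation."""
    return img + 1.0

rng = random.Random(0)
real, fake = [0.0, 0.0], [10.0, 10.0]

# over many steps, roughly a quarter of the batches get augmented at prob = 0.25
trials = 10_000
hits = sum(maybe_augment(real, shift, 0.25, rng) != real for _ in range(trials))
print(hits / trials)
```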

By default, the augmentations used are translation and cutout. If you would like to add color, you can do so with the --aug-types argument.

# make sure there are no spaces between items!
$ stylegan2_pytorch --data ./data --aug-prob 0.25 --aug-types [translation,cutout,color]
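The reason spaces matter is ordinary shell word splitting: a space ends one argument and starts the next, so the list would reach the program in pieces. Python's standard shlex module mimics POSIX shell splitting closely enough to show the difference (this only illustrates the splitting, not how the tool parses the list afterwards).

```python
import shlex

good = shlex.split("stylegan2_pytorch --data ./data --aug-types [translation,cutout,color]")
bad  = shlex.split("stylegan2_pytorch --data ./data --aug-types [translation, cutout, color]")

print(good[-1])   # the whole list arrives as a single argument
print(bad[-3:])   # with spaces, the list is broken into three separate arguments
```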

You can customize it to any combination of the three you would like. The differentiable augmentation code was copied and slightly modified from here.

When do I stop training?

Train for as long as possible, until the adversarial game between the two neural nets falls apart (we call this divergence). By default, the number of training steps is set to 150000 for 128x128 images, but you will certainly want this number to be higher if the GAN doesn't diverge by the end of training, or if you are training at a higher resolution.

$ stylegan2_pytorch --data ./data --image-size 512 --num-train-steps 1000000

Attention

This framework also allows you to add an efficient form of self-attention to designated layers of the discriminator (and the symmetric layers of the generator), which will greatly improve results. The more attention you can afford, the better!

# add self attention after the output of layer 1
$ stylegan2_pytorch --data ./data --attn-layers 1
# add self attention after the output of layers 1 and 2
# do not put a space after the comma in the list!
$ stylegan2_pytorch --data ./data --attn-layers [1,2]

Bonus

Training on transparent images

$ stylegan2_pytorch --data ./transparent/images/path --transparent

Memory considerations

The more GPU memory you have, the bigger and better the image generation will be. Nvidia recommended having up to 16GB for training 1024x1024 images. If you have less than that, there are a couple of settings you can play with so that the model fits.

$ stylegan2_pytorch --data /path/to/data \
    --batch-size 3 \
    --gradient-accumulate-every 5 \
    --network-capacity 16

  1. Batch size - You can decrease the batch-size down to 1, but you should increase the gradient-accumulate-every correspondingly so that the mini-batch the network sees is not too small. This may be confusing to a layperson, so I'll think about how I would automate the choice of gradient-accumulate-every going forward.

  2. Network capacity - You can decrease the neural network capacity to lessen the memory requirements. Just be aware that this has been shown to degrade generation performance.
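As a rule of thumb, the batch the network effectively sees per optimizer step is batch-size times gradient-accumulate-every (3 x 5 = 15 in the command above). Here is a minimal sketch of picking gradient-accumulate-every from a target effective batch; the helper name and the target of 16 are hypothetical, not something the tool provides.

```python
import math

def accumulation_steps(batch_size, target_effective_batch):
    """Smallest gradient-accumulate-every whose effective batch reaches the target."""
    if batch_size < 1 or target_effective_batch < 1:
        raise ValueError("both values must be positive")
    return math.ceil(target_effective_batch / batch_size)

# for a few per-step batch sizes that might fit in memory, find the matching accumulation
for batch_size in (1, 3, 8):
    every = accumulation_steps(batch_size, target_effective_batch=16)
    print(f"--batch-size {batch_size} --gradient-accumulate-every {every}"
          f"  (effective batch {batch_size * every})")
```

So if memory forces you down to a batch size of 1, raising gradient-accumulate-every to 16 keeps the effective batch where it was.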

If none of this works, you can settle for 'Lightweight' GAN, which will allow you to trade off quality to train at greater resolutions in a reasonable amount of time.

Deployment on AWS

Below are some steps which may be helpful for deployment using Amazon Web Services. In order to use this, you will have to provision a GPU-backed EC2 instance. An appropriate instance type would be from a p2 or p3 series. I (iboates) tried a p2.xlarge (the cheapest option) and it was quite slow, slower in fact than using Google Colab. More powerful instance types may be better but they are more expensive. You can read more about them here.

Setup steps

  1. Archive your training data and upload it to an S3 bucket
  2. Provision your EC2 instance (I used an Ubuntu AMI)
  3. Log into your EC2 instance via SSH
  4. Install the aws CLI client and configure it:
sudo snap install aws-cli --classic
aws configure

You will then have to enter your AWS access keys, which you can retrieve from the management console under AWS Management Console > Profile > My Security Credentials > Access Keys

Then, run these commands, or maybe put them in a shell script and execute that:

mkdir data
curl -O https://bootstrap.pypa.io/get-pip.py
sudo apt-get install python3-distutils
python3 get-pip.py
pip3 install stylegan2_pytorch
export PATH=$PATH:/home/ubuntu/.local/bin
aws s3 sync s3://<Your bucket name> ~/data
cd ~/data
tar -xf train.tar.gz

Now you should be able to train by simply calling stylegan2_pytorch [args].

Notes:

  • If you have a lot of training data, you may need to provision extra block storage via EBS.
  • Also, you may need to spread your data across multiple archives.
  • You should run this in a screen session so it won't terminate once you log out of the SSH session.

Research

FID Scores

Thanks to GetsEclectic, you can now calculate the FID score periodically! Again, made super simple with one extra argument, as shown below.

Firstly, install the pytorch_fid package

$ pip install pytorch-fid

Followed by

$ stylegan2_pytorch --data ./data --calculate-fid-every 5000

FID results will be logged to ./results/{name}/fid_scores.txt

Coding

If you would like to sample images programmatically, you can do so with the following simple ModelLoader class.

import torch
from torchvision.utils import save_image
from stylegan2_pytorch import ModelLoader

loader = ModelLoader(
    base_dir = '/path/to/directory',   # path to where you invoked the command line tool
    name = 'default'                   # the project name, defaults to 'default'
)

noise   = torch.randn(1, 512).cuda() # noise
styles  = loader.noise_to_styles(noise, trunc_psi = 0.7)  # pass through mapping network
images  = loader.styles_to_images(styles) # call the generator on intermediate style vectors

save_image(images, './sample.jpg') # save your images, or do whatever you desire

Logging to experiment tracker

To log the losses to an open source experiment tracker (Aim), you simply need to pass an extra flag like so.

$ stylegan2_pytorch --data ./data --log

Then, you need to make sure you have Docker installed. Following the instructions at Aim, you execute the following in your terminal.

$ aim up

Then open up your browser to the address it reports and you should see the logged losses.

Experimental

Top-k Training for Generator

A new paper has produced evidence that by simply zero-ing out the gradient contributions from samples that are deemed fake by the discriminator, the generator learns significantly better, achieving new state of the art.

$ stylegan2_pytorch --data ./data --top-k-training

Gamma controls a decay schedule that slowly decreases the top-k from the full batch size to the target fraction of 50% (also a modifiable hyperparameter).

$ stylegan2_pytorch --data ./data --top-k-training --generate-top-k-frac 0.5 --generate-top-k-gamma 0.99
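To get a feel for the schedule, here is a rough sketch of how k could shrink over time: the kept fraction decays as gamma raised to the elapsed time and is floored at the target fraction. Treating the time unit as epochs and rounding up are assumptions of this sketch, and top_k_size is a made-up helper rather than part of the package.

```python
import math

def top_k_size(batch_size, epoch, gamma=0.99, target_frac=0.5):
    """Number of samples kept for the generator update after `epoch` units of decay."""
    frac = max(gamma ** epoch, target_frac)  # decay from 1.0, never below the target fraction
    return math.ceil(batch_size * frac)

batch_size = 32
for epoch in (0, 10, 50, 100):
    print(epoch, top_k_size(batch_size, epoch))
```

With gamma at 0.99 the kept fraction reaches the 50% floor after roughly 69 units of decay, and stays there for the rest of training.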

Feature Quantization

A recent paper reported improved results if intermediate representations of the discriminator are vector quantized. Although I have not noticed any dramatic changes, I have decided to add this as a feature, so other minds out there can investigate. To use, you have to specify which layer(s) you would like to vector quantize. Default dictionary size is 256 and is also tunable.

# feature quantize layers 1 and 2, with a dictionary size of 512 each
# do not put a space after the comma in the list!
$ stylegan2_pytorch --data ./data --fq-layers [1,2] --fq-dict-size 512

Contrastive Loss Regularization

I have tried contrastive learning on the discriminator (in step with the usual GAN training) and possibly observed improved stability and quality of final results. You can turn on this experimental feature with a simple flag as shown below.

$ stylegan2_pytorch --data ./data --cl-reg

Relativistic Discriminator Loss

This was proposed in the Relativistic GAN paper to stabilize training. I have had mixed results, but will include the feature for those who want to experiment with it.

$ stylegan2_pytorch --data ./data --rel-disc-loss

Non-constant 4x4 Block

By default, the StyleGAN architecture styles a constant learned 4x4 block as it is progressively upsampled. This is an experimental feature that makes it so the 4x4 block is learned from the style vector w instead.

$ stylegan2_pytorch --data ./data --no-const

Dual Contrastive Loss

A recent paper has proposed that a novel contrastive loss between the real and fake logits can improve quality over other types of losses. (The default in this repository is hinge loss, and the paper shows a slight improvement)

$ stylegan2_pytorch --data ./data --dual-contrast-loss

Alternatives

Stylegan2 + Unet Discriminator

I have gotten really good results with a unet discriminator, but the architectural change was too big to fit as an option in this repository. If you are aiming for perfection, feel free to try it.

If you would like me to give the royal treatment to some other GAN architecture (BigGAN), feel free to reach out at my email. Happy to hear your pitch.

Appreciation

Thank you to Matthew Mann for his inspiring simple port for Tensorflow 2.0

References

@article{Karras2019stylegan2,
    title   = {Analyzing and Improving the Image Quality of {StyleGAN}},
    author  = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
    journal = {CoRR},
    volume  = {abs/1912.04958},
    year    = {2019},
}
@misc{zhao2020feature,
    title   = {Feature Quantization Improves GAN Training},
    author  = {Yang Zhao and Chunyuan Li and Ping Yu and Jianfeng Gao and Changyou Chen},
    year    = {2020}
}
@misc{chen2020simple,
    title   = {A Simple Framework for Contrastive Learning of Visual Representations},
    author  = {Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton},
    year    = {2020}
}
@article{,
    title     = {Oxford 102 Flowers},
    author    = {Nilsback, M-E. and Zisserman, A.},
    year      = {2008},
    abstract  = {A 102 category dataset consisting of 102 flower categories, commonly occurring in the United Kingdom. Each class consists of 40 to 258 images. The images have large scale, pose and light variations.}
}
@article{afifi201911k,
    title   = {11K Hands: gender recognition and biometric identification using a large dataset of hand images},
    author  = {Afifi, Mahmoud},
    journal = {Multimedia Tools and Applications}
}
@misc{zhang2018selfattention,
    title   = {Self-Attention Generative Adversarial Networks},
    author  = {Han Zhang and Ian Goodfellow and Dimitris Metaxas and Augustus Odena},
    year    = {2018},
    eprint  = {1805.08318},
    archivePrefix = {arXiv}
}
@article{shen2019efficient,
    author    = {Zhuoran Shen and
               Mingyuan Zhang and
               Haiyu Zhao and
               Shuai Yi and
               Hongsheng Li},
    title     = {Efficient Attention: Attention with Linear Complexities},
    journal   = {CoRR},  
    year      = {2018},
    url       = {http://arxiv.org/abs/1812.01243},
}
@article{zhao2020diffaugment,
    title   = {Differentiable Augmentation for Data-Efficient GAN Training},
    author  = {Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
    journal = {arXiv preprint arXiv:2006.10738},
    year    = {2020}
}
@misc{zhao2020image,
    title  = {Image Augmentations for GAN Training},
    author = {Zhengli Zhao and Zizhao Zhang and Ting Chen and Sameer Singh and Han Zhang},
    year   = {2020},
    eprint = {2006.02595},
    archivePrefix = {arXiv}
}
@misc{karras2020training,
    title   = {Training Generative Adversarial Networks with Limited Data},
    author  = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    year    = {2020},
    eprint  = {2006.06676},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{jolicoeurmartineau2018relativistic,
    title   = {The relativistic discriminator: a key element missing from standard GAN},
    author  = {Alexia Jolicoeur-Martineau},
    year    = {2018},
    eprint  = {1807.00734},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{sinha2020topk,
    title   = {Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author  = {Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year    = {2020},
    eprint  = {2002.06224},
    archivePrefix = {arXiv},
    primaryClass = {stat.ML}
}
@misc{yu2021dual,
    title   = {Dual Contrastive Loss and Attention for GANs},
    author  = {Ning Yu and Guilin Liu and Aysegul Dundar and Andrew Tao and Bryan Catanzaro and Larry Davis and Mario Fritz},
    year    = {2021},
    eprint  = {2103.16748},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
