  • Stars
    1,536
  • Rank 30,468 (Top 0.7 %)
  • Language
    Python
  • License
    MIT License
  • Created about 4 years ago
  • Updated about 2 years ago

Repository Details

To eventually become an unofficial Pytorch implementation / replication of Alphafold2, as details of the architecture get released

Alphafold2 - Pytorch (wip)

To eventually become an unofficial working Pytorch implementation of Alphafold2, the breathtaking attention network that solved CASP14. Will be gradually implemented as more details of the architecture are released.

Once this is replicated, I intend to fold all available amino acid sequences out there in-silico and release the results as an academic torrent, to further science. If you are interested in replication efforts, please drop by #alphafold at this Discord channel.

Update: Deepmind has open sourced the official code in Jax, along with the weights 🙏! This repository will now be geared towards a straight pytorch translation with some improvements on positional encoding.

ArxivInsights video

Install

$ pip install alphafold2-pytorch

Status

lhatsk has reported training a modified trunk of this repository, using the same setup as trRosetta, with competitive results

blue used the trRosetta input (MSA -> potts -> axial attention), green used the ESM embedding (only sequence) -> tiling -> axial attention - lhatsk

Usage

Predicting distogram, like Alphafold-1, but with attention

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    reversible = False  # set this to True for fully reversible self / cross attention for the trunk
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()      # AA length of 128
msa = torch.randint(0, 21, (1, 5, 120)).cuda()   # MSA doesn't have to be the same length as primary sequence
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (1, 128, 128, 37)
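
The last dimension of the distogram (37) is a set of distance buckets. To train against it, true pairwise distances have to be turned into bucket indices. Below is a minimal sketch of that step in plain Python. The 2 to 20 angstrom range and the helper names `distance_thresholds` and `bucketize` are assumptions made for illustration, so check the library's own constants before relying on them.

```python
from bisect import bisect_left

NUM_BUCKETS = 37  # matches the last dimension of the distogram above

def distance_thresholds(start=2.0, end=20.0, steps=NUM_BUCKETS):
    """Evenly spaced bucket boundaries in angstroms (the range is an assumption)."""
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]

def bucketize(distance, thresholds):
    """Index of the first threshold >= distance, clamped to the last bucket."""
    return min(bisect_left(thresholds, distance), len(thresholds) - 1)

thresholds = distance_thresholds()

# distances below the first threshold land in bucket 0,
# distances beyond the last threshold land in the final bucket
labels = [bucketize(d, thresholds) for d in (1.5, 3.8, 10.0, 25.0)]
```

The resulting integer labels could then serve as targets for a cross entropy loss over the 37 channels.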

You can also turn on prediction for the angles by passing predict_angles = True on init. The example below would be equivalent to trRosetta, but with self / cross attention.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    predict_angles = True   # set this to True
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram, theta, phi, omega = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
)

# distogram - (1, 128, 128, 37),
# theta     - (1, 128, 128, 25),
# phi       - (1, 128, 128, 13),
# omega     - (1, 128, 128, 25)

Predicting Coordinates

Fabian's recent paper suggests that iteratively feeding the coordinates back into a weight-shared SE3 Transformer may work. I have decided to build on this idea, even though it is still up in the air how it actually works.

You can also use E(n)-Transformer or EGNN for structural refinement.

Update: Baker's lab has shown that an end-to-end architecture from sequence and MSA embeddings to SE3 Transformers can best trRosetta and close the gap to Alphafold2. We will be using the Graph Transformer, which acts on the trunk embeddings, to generate the initial set of coordinates to be sent to the equivariant network. (This is further corroborated by Costa et al. in their work teasing out 3d coordinates from MSA Transformer embeddings, in a paper predating the Baker lab's.)

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    predict_coords = True,
    structure_module_type = 'se3',          # use SE3 Transformer - if set to False, will use E(n)-Transformer, Victor and Max Welling's new paper
    structure_module_dim = 4,               # se3 transformer dimension
    structure_module_depth = 1,             # depth
    structure_module_heads = 1,             # heads
    structure_module_dim_head = 16,         # dimension of heads
    structure_module_refinement_iters = 2,  # number of equivariant coordinate refinement iterations
    structure_num_global_nodes = 1          # number of global nodes for the structure module, only works with SE3 transformer
).cuda()

seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

coords = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (2, 64 * 3, 3)  <-- 3 atoms per residue

Atoms

The underlying assumption is that the trunk works at the residue level, and the representation is then expanded to the atomic level for the structure module, whether it be SE3 Transformers, E(n)-Transformer, or EGNN doing the refinement. This library defaults to the 3 backbone atoms (C, Ca, N), but you can configure it to include any other atom you like, including Cb and the sidechains.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    predict_coords = True,
    atoms = 'backbone-with-cbeta'
).cuda()

seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

coords = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (2, 64 * 4, 3)  <-- 4 atoms per residue (C, Ca, N, Cb)

Valid choices for atoms include:

  • backbone - 3 backbone atoms (C, Ca, N) [default]
  • backbone-with-cbeta - 3 backbone atoms and C beta
  • backbone-with-oxygen - 3 backbone atoms and oxygen from carboxyl
  • backbone-with-cbeta-and-oxygen - 3 backbone atoms with C beta and oxygen
  • all - backbone and all other atoms from sidechain

You can also pass in a tensor of shape (14,) defining which atoms you would like to include

ex.

atoms = torch.tensor([1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
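
To see how an atom mask affects the output shape, here is a small sketch in plain Python. It assumes the coordinates are flattened residue by residue, as the `64 * 3` comments in the examples above suggest. The helper names `atoms_per_residue` and `flat_atom_index` are hypothetical and not part of the library.

```python
def atoms_per_residue(atom_mask):
    """Number of atoms kept per residue for a 14-slot mask of 0s and 1s."""
    return sum(1 for flag in atom_mask if flag)

def flat_atom_index(residue_idx, atom_slot, num_atoms):
    """Position of an atom in the flattened (residues * atoms) axis.

    Assumes residue-major ordering, which is an assumption about the layout.
    """
    return residue_idx * num_atoms + atom_slot

atom_mask = [1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1]
num_atoms = atoms_per_residue(atom_mask)

# for a batch of 2 sequences of length 64, as in the examples above
coords_shape = (2, 64 * num_atoms, 3)

# third kept atom of residue 10 under the assumed layout
index = flat_atom_index(10, 2, num_atoms)
```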

MSA, ESM, or ProtTrans Embeddings

This repository offers you an easy way to supplement the network with pre-trained embeddings from Facebook AI. It contains wrappers for the pre-trained ESM, MSA Transformer, or Protein Transformer (ProtTrans) models.

There are some prerequisites. You will need to make sure that you have Nvidia's apex library installed, as the pretrained transformers make use of some fused operations.

If it is not installed yet, you can try running the script below

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Next, you simply import and wrap your Alphafold2 instance with an ESMEmbedWrapper, MSAEmbedWrapper, or ProtTranEmbedWrapper, and it will take care of embedding both the sequence and the multiple-sequence alignments for you (and projecting them to the dimensions specified on your model). Nothing needs to be changed save for adding the wrapper.

import torch
from alphafold2_pytorch import Alphafold2
from alphafold2_pytorch.embeds import MSAEmbedWrapper

alphafold2 = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64
)

model = MSAEmbedWrapper(
    alphafold2 = alphafold2
).cuda()

seq = torch.randint(0, 21, (2, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()

msa = torch.randint(0, 21, (2, 5, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
)

By default, even when the wrapper supplies the trunk with the sequence and MSA embeddings, they are summed with the usual token embeddings. If you want to train Alphafold2 without token embeddings (relying only on the pretrained embeddings), set disable_token_embed to True on Alphafold2 init.

alphafold2 = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    disable_token_embed = True
)

Real-Value Distance Prediction

A paper by Jinbo Xu suggests that one doesn't need to bin the distances, and can instead predict the mean and standard deviation directly. You can use this by turning on the predict_real_value_distances flag, in which case the distance prediction returned will have a last dimension of 2, for the mean and standard deviation respectively.

If predict_coords is also turned on, then the multidimensional scaling (MDS) step will accept the mean and standard deviation predictions directly, without having to calculate them from the distogram bins.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    predict_coords = True,
    predict_real_value_distances = True,      # set this to True
    structure_module_type = 'se3',
    structure_module_dim = 4,
    structure_module_depth = 1,
    structure_module_heads = 1,
    structure_module_dim_head = 16,
    structure_module_refinement_iters = 2
).cuda()

seq = torch.randint(0, 21, (2, 64)).cuda()
msa = torch.randint(0, 21, (2, 5, 60)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

coords = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (2, 64 * 3, 3)  <-- 3 atoms per residue
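
For comparison, when distances are binned, a mean and standard deviation have to be derived from the bin probabilities. The sketch below shows one way to do that for a single residue pair in plain Python. The bin centers (2 to 20 angstroms in steps of 0.5) and the function names are assumptions for illustration, and this is not necessarily how the library computes them internally.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distance_mean_std(logits, bin_centers):
    """Expected distance and its standard deviation under the bin distribution."""
    probs = softmax(logits)
    mean = sum(p * c for p, c in zip(probs, bin_centers))
    variance = sum(p * (c - mean) ** 2 for p, c in zip(probs, bin_centers))
    return mean, math.sqrt(variance)

# assumed bin centers: 37 values from 2.0 to 20.0 angstroms
bin_centers = [2.0 + 0.5 * i for i in range(37)]

# a confident prediction: almost all probability mass on the 10.0 angstrom bin
confident_logits = [0.0] * 37
confident_logits[16] = 50.0
confident_mean, confident_std = distance_mean_std(confident_logits, bin_centers)

# an uninformative prediction: uniform over all bins
uniform_mean, uniform_std = distance_mean_std([0.0] * 37, bin_centers)
```

A sharp distribution yields a small standard deviation, while a flat one yields a large one, which is the information a real-valued prediction head would output directly.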

Convolutions

You can add convolutional blocks, for both the primary sequence as well as the MSA, by simply setting one extra keyword argument use_conv = True

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    use_conv = True # set this to True
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (1, 128, 128, 37)

The convolutional kernels follow the lead of this paper, combining 1d and 2d kernels in one resnet-like block. You can fully customize the kernels as such.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    use_conv = True, # set this to True
    conv_seq_kernels = ((9, 1), (1, 9), (3, 3)), # kernels for N x N primary sequence
    conv_msa_kernels = ((1, 9), (3, 3)), # kernels for {num MSAs} x N MSAs
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (1, 128, 128, 37)

You can also do cycle dilation with one extra keyword argument. Default dilation is 1 for all layers.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    use_conv = True, # set this to True
    dilations = (1, 3, 5) # cycle between dilations of 1, 3, 5
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (1, 128, 128, 37)
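
As a rough illustration of what cycling means here, the sketch below assigns one dilation per layer by repeating the given tuple. It assumes one convolutional block per layer, and `dilations_per_layer` is a hypothetical helper, so the library's actual assignment may differ.

```python
from itertools import cycle, islice

def dilations_per_layer(dilations, depth):
    """Repeat the dilation tuple until every layer has a value."""
    return list(islice(cycle(dilations), depth))

# 7 layers cycling through dilations of 1, 3, 5
schedule = dilations_per_layer((1, 3, 5), 7)

# the default of a single dilation of 1 gives every layer the same value
default_schedule = dilations_per_layer((1,), 4)
```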

Finally, instead of repeating the default pattern of convolution, self-attention, and cross-attention at each depth, you can specify any order you wish with the custom_block_types keyword

ex. A network where you do convolutions first, followed by self-attention + cross-attention blocks

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    heads = 8,
    dim_head = 64,
    custom_block_types = (
        *(('conv',) * 6),
        *(('self', 'cross') * 6)
    )
).cuda()

seq = torch.randint(0, 21, (1, 128)).cuda()
msa = torch.randint(0, 21, (1, 5, 120)).cuda()
mask = torch.ones_like(seq).bool().cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
) # (1, 128, 128, 37)
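
The tuple expression passed to custom_block_types is ordinary Python unpacking, so it can be inspected on its own. The short sketch below expands it and counts the block types. Note that the example above does not pass depth, presumably because the tuple itself determines the number of blocks.

```python
from collections import Counter

custom_block_types = (
    *(('conv',) * 6),
    *(('self', 'cross') * 6)
)

# 6 conv blocks followed by 6 alternating self / cross pairs
total_blocks = len(custom_block_types)
counts = Counter(custom_block_types)
first_attention_blocks = custom_block_types[6:10]
```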

Sparse Attention

You can train with Microsoft Deepspeed's Sparse Attention, but you will have to endure the installation process. It takes two steps.

First, you need to install Deepspeed with Sparse Attention

$ sh install_deepspeed.sh

Next, you need to install the pip package triton

$ pip install triton

If both of the above succeeded, now you can train with Sparse Attention!

Sadly, the sparse attention is only supported for self attention, and not cross attention. I will bring in a different solution for making cross attention performant.

model = Alphafold2(
    dim = 256,
    depth = 12,
    heads = 8,
    dim_head = 64,
    max_seq_len = 2048,                   # the maximum sequence length, this is required for sparse attention. the input cannot exceed what is set here
    sparse_self_attn = (True, False) * 6  # interleave sparse and full attention for all 12 layers
).cuda()

Linear Attention

I have also added one of the best linear attention variants, in the hope of lessening the burden of cross attending. I personally have not found Performer to work that well, but since in the paper they reported some ok numbers for protein benchmarks, I thought I'd include it and allow others to experiment.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    cross_attn_linear = True # simply set this to True to use Performer for all cross attention
).cuda()

You can also specify the exact layers in which you wish to use linear attention by passing in a tuple of the same length as the depth

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    cross_attn_linear = (True, False) * 3 # interleave linear and full attention
).cuda()
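
Both forms of the setting, a single boolean or a per-layer tuple, can be thought of as resolving to one flag per layer. The helper below is a hypothetical sketch of that normalization in plain Python. The name `per_layer_flags` is made up for illustration and the library's own handling may differ.

```python
def per_layer_flags(value, depth):
    """Broadcast a single flag to every layer, or validate an explicit tuple."""
    if isinstance(value, tuple):
        if len(value) != depth:
            raise ValueError(f'expected {depth} flags, got {len(value)}')
        return value
    return (value,) * depth

# a single boolean applies to all layers
all_linear = per_layer_flags(True, 2)

# an explicit tuple interleaves linear and full attention
interleaved = per_layer_flags((True, False) * 3, 6)
```

The same reading applies to sparse_self_attn = (True, False) * 6 in the Sparse Attention section, where the tuple covers all 12 layers.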

Kronecker Attention for Cross Attention

This paper suggests that if you have queries or contexts that have defined axials (say an image), you can reduce the amount of attention needed by averaging across those axials (height and width) and concatenating the averaged axials into one sequence. You can turn this on as a memory saving technique for the cross attention, specifically for the primary sequence.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    cross_attn_kron_primary = True # make sure primary sequence undergoes the kronecker operator during cross attention
).cuda()

You can also apply the same operator to the MSAs during cross attention with the cross_attn_kron_msa flag, if your MSAs are aligned and of the same width.
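
The averaging and concatenation described above can be sketched in plain Python with nested lists. For an H x W grid of feature vectors, the reduced sequence has H + W entries instead of H * W. The function name `kronecker_reduce` and the ordering of the concatenation (row averages first, then column averages) are assumptions for illustration.

```python
def kronecker_reduce(grid):
    """Reduce an H x W grid of feature vectors to a sequence of H + W vectors."""
    height, width = len(grid), len(grid[0])
    dim = len(grid[0][0])

    # average across the width axis: one vector per row
    row_means = [
        [sum(grid[i][j][d] for j in range(width)) / width for d in range(dim)]
        for i in range(height)
    ]

    # average across the height axis: one vector per column
    col_means = [
        [sum(grid[i][j][d] for i in range(height)) / height for d in range(dim)]
        for j in range(width)
    ]

    return row_means + col_means

# toy grid with H = 3, W = 4 and 2 features per cell: (row index, column index)
grid = [[[float(i), float(j)] for j in range(4)] for i in range(3)]
reduced = kronecker_reduce(grid)
```

Attention over the reduced sequence touches 7 positions here rather than 12, and the saving grows with the size of the grid.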

Todo

  • offer masked mean reduction method
  • rotary embeddings

Memory Compressed Attention

To save on memory for cross attention, you can set a compression ratio for the key / values, following the scheme laid out in this paper. A compression ratio of 2-4 is usually acceptable.

model = Alphafold2(
    dim = 256,
    depth = 12,
    heads = 8,
    dim_head = 64,
    cross_attn_compress_ratio = 3
).cuda()
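
To give a feel for what a compression ratio does to the key / value sequence, the sketch below shortens a sequence by mean pooling groups of consecutive vectors. Mean pooling is a stand-in chosen for simplicity, and a learned compression would normally be used, so this is not the library's implementation. The function name `compress_sequence` is hypothetical.

```python
def compress_sequence(vectors, ratio):
    """Mean-pool consecutive groups of `ratio` vectors; the last group may be shorter."""
    compressed = []
    for start in range(0, len(vectors), ratio):
        group = vectors[start:start + ratio]
        dim = len(group[0])
        compressed.append([sum(v[d] for v in group) / len(group) for d in range(dim)])
    return compressed

# 10 one-dimensional key vectors compressed with a ratio of 3
keys = [[float(i)] for i in range(10)]
compressed_keys = compress_sequence(keys, 3)
```

With a ratio of 3, queries attend over roughly a third as many keys and values, which is where the memory saving comes from.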

MSA processing in Trunk

A new paper by Roshan Rao proposes using axial attention for pretraining on MSAs. Given the strong results, this repository will use the same scheme in the trunk, specifically for the MSA self-attention.

You can also tie the row attentions of the MSA with the msa_tie_row_attn = True setting on initialization of Alphafold2. However, in order to use this, if you have an uneven number of MSAs per primary sequence, you must make sure that the MSA mask is properly set to False for the rows not in use.

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64,
    msa_tie_row_attn = True # just set this to true
)
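
When the number of MSA rows differs between primary sequences, the rows have to be padded and the padding marked as False in the mask. The sketch below builds such a padded batch and mask from nested Python lists. The helper name `pad_msas` and the pad token of 0 are assumptions for illustration, and all rows are assumed to share the same width.

```python
def pad_msas(msas, pad_token=0):
    """Pad each sample to the same number of MSA rows and build a boolean mask.

    msas: list (batch) of lists (rows) of equal-length token lists.
    Returns (padded, mask), where mask is False for every padded row.
    """
    max_rows = max(len(rows) for rows in msas)
    width = len(msas[0][0])
    padded, mask = [], []
    for rows in msas:
        missing = max_rows - len(rows)
        padded.append(rows + [[pad_token] * width for _ in range(missing)])
        mask.append(
            [[True] * width for _ in rows]
            + [[False] * width for _ in range(missing)]
        )
    return padded, mask

# first sample has 3 MSA rows, second sample has only 1
batch = [
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16]],
]
padded_msa, msa_mask = pad_msas(batch)
```

The nested lists can then be converted with torch.tensor and passed as msa and msa_mask.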

Template processing in Trunk

Template processing is also largely done with axial attention, with cross attention done along the number of templates dimension. This largely follows the same scheme as in the recent all-attention approach to video classification as shown here.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 5,
    heads = 8,
    dim_head = 64,
    reversible = True,
    sparse_self_attn = False,
    max_seq_len = 256,
    cross_attn_compress_ratio = 3
).cuda()

seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()

msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randn(1, 2, 16, 3).cuda()  # template coordinates are real-valued
templates_mask = torch.ones_like(templates_seq).bool().cuda()

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask,
    templates_seq = templates_seq,
    templates_coors = templates_coors,
    templates_mask = templates_mask
)

If sidechain information is also present, in the form of the unit vector between the C and C-alpha coordinates of each residue, you can also pass it in as follows.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 5,
    heads = 8,
    dim_head = 64,
    reversible = True,
    sparse_self_attn = False,
    max_seq_len = 256,
    cross_attn_compress_ratio = 3
).cuda()

seq = torch.randint(0, 21, (1, 16)).cuda()
mask = torch.ones_like(seq).bool().cuda()

msa = torch.randint(0, 21, (1, 10, 16)).cuda()
msa_mask = torch.ones_like(msa).bool().cuda()

templates_seq = torch.randint(0, 21, (1, 2, 16)).cuda()
templates_coors = torch.randn(1, 2, 16, 3).cuda()
templates_mask = torch.ones_like(templates_seq).bool().cuda()

templates_sidechains = torch.randn(1, 2, 16, 3).cuda() # unit vectors of difference of C and C-alpha coordinates

distogram = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask,
    templates_seq = templates_seq,
    templates_mask = templates_mask,
    templates_coors = templates_coors,
    templates_sidechains = templates_sidechains
)
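
The random tensor used for templates_sidechains above is only a shape placeholder. With real data, each vector would be computed from the atom coordinates and normalized. A minimal sketch for a single residue follows. The direction (C minus C-alpha) and the function name `unit_vector` are assumptions for illustration.

```python
import math

def unit_vector(c_coord, ca_coord, eps=1e-8):
    """Normalized difference between the C and C-alpha coordinates of a residue."""
    diff = [c - ca for c, ca in zip(c_coord, ca_coord)]
    norm = math.sqrt(sum(x * x for x in diff))
    # eps guards against division by zero for coincident atoms
    return [x / max(norm, eps) for x in diff]

# toy coordinates with a difference vector of length 5
vec = unit_vector([3.0, 4.0, 0.0], [0.0, 0.0, 0.0])
length = math.sqrt(sum(x * x for x in vec))
```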

Equivariant Attention

I have prepared a reimplementation of SE3 Transformer, as explained by Fabian Fuchs in a speculative blogpost.

In addition, a new paper from Satorras, Hoogeboom, and Welling uses invariant features for E(n) equivariance, reaching SOTA and outperforming SE3 Transformer at a number of benchmarks, while being much faster. I have taken the main ideas from this paper and modified them to become a transformer (added attention to both features and coordinate updates).

All three equivariant networks (SE3 Transformer, E(n)-Transformer, and EGNN) have been integrated and are available in the repository for atomic coordinate refinement by simply setting one hyperparameter, structure_module_type.

Of interest to readers, each of the three frameworks has also been validated by researchers on related problems.

Testing

$ python setup.py test

Data

This library will use the awesome work by Jonathan King at this repository. Thank you Jonathan 🙏!

We also have the MSA data, all ~3.5 TB worth, downloaded and hosted by Archivist, who owns The-Eye project. (They also host the data and models for Eleuther AI.) Please consider a donation if you find them helpful.

$ curl -s https://the-eye.eu/eleuther_staging/globus_stuffs/tree.txt

Speculation

https://xukui.cn/alphafold2.html

https://moalquraishi.wordpress.com/2020/12/08/alphafold2-casp14-it-feels-like-ones-child-has-left-home/

Recent works by competing labs

https://www.biorxiv.org/content/10.1101/2020.12.10.419994v1.full.pdf

https://pubmed.ncbi.nlm.nih.gov/33637700/

tFold presentation, from Tencent AI labs

External packages

  • Final step - Fast Relax - Installation Instructions:
    • Download the pyrosetta wheel from: http://www.pyrosetta.org/dow (select the appropriate version) - beware the file is heavy (approx 1.2 GB)
      • The download should be free for anyone with an academic email
    • In Bash, run cd downloads_folder and then pip install pyrosetta_wheel_filename.whl

OpenMM Amber

Citations

@misc{unpublished2021alphafold2,
    title   = {Alphafold2},
    author  = {John Jumper},
    year    = {2020},
    archivePrefix = {arXiv},
    primaryClass = {q-bio.BM}
}
@article{Rao2021.02.12.430858,
    author  = {Rao, Roshan and Liu, Jason and Verkuil, Robert and Meier, Joshua and Canny, John F. and Abbeel, Pieter and Sercu, Tom and Rives, Alexander},
    title   = {MSA Transformer},
    year    = {2021},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/02/13/2021.02.12.430858},
    journal = {bioRxiv}
}
@article {Rives622803,
    author  = {Rives, Alexander and Goyal, Siddharth and Meier, Joshua and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
    title   = {Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
    year    = {2019},
    doi     = {10.1101/622803},
    publisher = {Cold Spring Harbor Laboratory},
    journal = {bioRxiv}
}
@article {Elnaggar2020.07.12.199554,
    author  = {Elnaggar, Ahmed and Heinzinger, Michael and Dallago, Christian and Rehawi, Ghalia and Wang, Yu and Jones, Llion and Gibbs, Tom and Feher, Tamas and Angerer, Christoph and Steinegger, Martin and BHOWMIK, DEBSINDHU and Rost, Burkhard},
    title   = {ProtTrans: Towards Cracking the Language of Life{\textquoteright}s Code Through Self-Supervised Deep Learning and High Performance Computing},
    elocation-id = {2020.07.12.199554},
    year    = {2021},
    doi     = {10.1101/2020.07.12.199554},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554},
    eprint  = {https://www.biorxiv.org/content/early/2021/05/04/2020.07.12.199554.full.pdf},
    journal = {bioRxiv}
}
@misc{king2020sidechainnet,
    title   = {SidechainNet: An All-Atom Protein Structure Dataset for Machine Learning}, 
    author  = {Jonathan E. King and David Ryan Koes},
    year    = {2020},
    eprint  = {2010.08162},
    archivePrefix = {arXiv},
    primaryClass = {q-bio.BM}
}
@misc{alquraishi2019proteinnet,
    title   = {ProteinNet: a standardized data set for machine learning of protein structure}, 
    author  = {Mohammed AlQuraishi},
    year    = {2019},
    eprint  = {1902.00249},
    archivePrefix = {arXiv},
    primaryClass = {q-bio.BM}
}
@misc{gomez2017reversible,
    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations}, 
    author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
    year      = {2017},
    eprint    = {1707.04585},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{fuchs2021iterative,
    title   = {Iterative SE(3)-Transformers},
    author  = {Fabian B. Fuchs and Edward Wagstaff and Justas Dauparas and Ingmar Posner},
    year    = {2021},
    eprint  = {2102.13419},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks}, 
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Gao_2020,
    title   = {Kronecker Attention Networks},
    ISBN    = {9781450379984},
    url     = {http://dx.doi.org/10.1145/3394486.3403065},
    DOI     = {10.1145/3394486.3403065},
    journal = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
    publisher = {ACM},
    author  = {Gao, Hongyang and Wang, Zhengyang and Ji, Shuiwang},
    year    = {2020},
    month   = {Jul}
}
@article {Si2021.05.10.443415,
    author  = {Si, Yunda and Yan, Chengfei},
    title   = {Improved protein contact prediction using dimensional hybrid residual networks and singularity enhanced loss function},
    elocation-id = {2021.05.10.443415},
    year    = {2021},
    doi     = {10.1101/2021.05.10.443415},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415},
    eprint  = {https://www.biorxiv.org/content/early/2021/05/11/2021.05.10.443415.full.pdf},
    journal = {bioRxiv}
}
@article {Costa2021.06.02.446809,
    author  = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
    title   = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
    year    = {2021},
    doi     = {10.1101/2021.06.02.446809},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
    eprint  = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
    journal = {bioRxiv}
}
@article {Baek2021.06.14.448402,
    author  = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
    title   = {Accurate prediction of protein structures and interactions using a 3-track network},
    year    = {2021},
    doi     = {10.1101/2021.06.14.448402},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
    eprint  = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
    journal = {bioRxiv}
}

More Repositories

1

vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Python
13,633
star
2

DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
Python
11,068
star
3

imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
Python
7,832
star
4

PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
Python
7,611
star
5

DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
Python
5,132
star
6

deep-daze

Simple command line tool for text to image generation using OpenAI's CLIP and Siren (Implicit neural representation network). Technique was originally created by https://twitter.com/advadnoun
Python
4,387
star
7

denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
Python
3,959
star
8

stylegan2-pytorch

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
Python
3,433
star
9

musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch
Python
3,048
star
10

x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
Python
2,707
star
11

big-sleep

A simple command line tool for text to image generation, using OpenAI's CLIP and a BigGAN. Technique was originally created by https://twitter.com/advadnoun
Python
2,446
star
12

audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
Python
2,285
star
13

lion-pytorch

🦁 Lion, new optimizer discovered by Google Brain using genetic algorithms that is purportedly better than Adam(w), in Pytorch
Python
1,933
star
14

toolformer-pytorch

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
Python
1,905
star
15

reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
Python
1,870
star
16

make-a-video-pytorch

Implementation of Make-A-Video, new SOTA text to video generator from Meta AI, in Pytorch
Python
1,853
star
17

gigagan-pytorch

Implementation of GigaGAN, new SOTA GAN out of Adobe. Culmination of nearly a decade of research into GANs
Python
1,632
star
18

lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two
Python
1,526
star
19

lambda-networks

Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute
Python
1,516
star
20

byol-pytorch

Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch
Python
1,497
star
21

self-rewarding-lm-pytorch

Implementation of the training framework proposed in Self-Rewarding Language Model, from MetaAI
Python
1,253
star
22

naturalspeech2-pytorch

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch
Python
1,214
star
23

flamingo-pytorch

Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch
Python
1,155
star
24

video-diffusion-pytorch

Implementation of Video Diffusion Models, Jonathan Ho's new paper extending DDPMs to Video Generation - in Pytorch
Python
1,141
star
25

soundstorm-pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch
Python
1,130
star
26

CoCa-pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch
Python
990
star
27

performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
Python
937
star
28

perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
Python
935
star
29

RETRO-pytorch

Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch
Python
835
star
30

mlp-mixer-pytorch

An All-MLP solution for Vision, from Google AI
Python
833
star
31

muse-maskgit-pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
Python
821
star
32

PaLM-pytorch

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways
Python
812
star
33

vector-quantize-pytorch

Vector Quantization, in Pytorch
Python
810
star
34

phenaki-pytorch

Implementation of Phenaki Video, which uses Mask GIT to produce text guided videos of up to 2 minutes in length, in Pytorch
Python
724
star
35

x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers
Python
658
star
36

bottleneck-transformer-pytorch

Implementation of Bottleneck Transformer in Pytorch
Python
632
star
37

memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
Python
614
star
38

TimeSformer-pytorch

Implementation of TimeSformer from Facebook AI, a pure attention-based solution for video classification
Python
613
star
39

MEGABYTE-pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
Python
594
star
40

meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
Python
564
star
41

nuwa-pytorch

Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch
Python
531
star
42

voicebox-pytorch

Implementation of Voicebox, new SOTA Text-to-speech network from MetaAI, in Pytorch
Python
521
star
43

point-transformer-pytorch

Implementation of the Point Transformer layer, in Pytorch
Python
518
star
44

parti-pytorch

Implementation of Parti, Google's pure attention-based text-to-image neural network, in Pytorch
Python
509
star
45

tab-transformer-pytorch

Implementation of TabTransformer, attention network for tabular data, in Pytorch
Python
485
star
46

alphafold3-pytorch

Implementation of Alphafold 3 in Pytorch
Python
483
star
47

linear-attention-transformer

Transformer based on a variant of attention that has linear complexity with respect to sequence length
Python
468
star
48

magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
Python
436
star
49

ema-pytorch

A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
Python
408
star
50

egnn-pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch
Python
400
star
51

g-mlp-pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Python
391
star
52

recurrent-memory-transformer-pytorch

Implementation of Recurrent Memory Transformer, Neurips 2022 paper, in Pytorch
Python
384
star
53

ring-attention-pytorch

Implementation of 💍 Ring Attention, from Liu et al. at Berkeley AI, in Pytorch
Python
380
star
54

siren-pytorch

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions
Python
377
star
55

enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
Python
352
star
56

iTransformer

Unofficial implementation of iTransformer - SOTA Time Series Forecasting using Attention networks, out of Tsinghua / Ant Group
Python
349
star
57

robotic-transformer-pytorch

Implementation of RT1 (Robotic Transformer) in Pytorch
Python
346
star
58

memory-efficient-attention-pytorch

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"
Python
342
star
59

FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
Python
334
star
60

bit-diffusion

Implementation of Bit Diffusion, Hinton's group's attempt at discrete denoising diffusion, in Pytorch
Python
313
star
61

medical-chatgpt

Implementation of ChatGPT, but tailored towards primary care medicine, with the reward being able to collect patient histories in a thorough and efficient manner and come up with a reasonable differential diagnosis
Python
311
star
62

slot-attention

Implementation of Slot Attention from GoogleAI
Python
303
star
63

q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
Python
293
star
64

BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
Python
289
star
65

classifier-free-guidance-pytorch

Implementation of Classifier Free Guidance in Pytorch, with emphasis on text conditioning, and flexibility to include multiple text embedding models
Python
282
star
66

transformer-in-transformer

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Python
277
star
67

axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
Python
273
star
68

conformer

Implementation of the convolutional module from the Conformer paper, for use in Transformers
Python
272
star
69

mixture-of-experts

A Pytorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models
Python
264
star
70

deformable-attention

Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"
Python
258
star
71

magic3d-pytorch

Implementation of Magic3D, Text to 3D content synthesis, in Pytorch
Python
258
star
72

x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings
Python
252
star
73

routing-transformer

Fully featured implementation of Routing Transformer
Python
251
star
74

Adan-pytorch

Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch
Python
245
star
75

spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
Python
241
star
76

st-moe-pytorch

Implementation of ST-MoE, the latest incarnation of MoE after years of research at Brain, in Pytorch
Python
237
star
77

perfusion-pytorch

Implementation of Key-Locked Rank One Editing, from Nvidia AI
Python
229
star
78

equiformer-pytorch

Implementation of the Equiformer, SE3/E3 equivariant attention network that reaches new SOTA, and adopted for use by EquiFold for protein folding
Python
227
star
79

segformer-pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch
Python
227
star
80

sinkhorn-transformer

Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention
Python
222
star
81

pixel-level-contrastive-learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in Pytorch
Python
220
star
82

lumiere-pytorch

Implementation of Lumiere, SOTA text-to-video generation from Google Deepmind, in Pytorch
Python
216
star
83

local-attention

An implementation of local windowed attention for language modeling
Python
216
star
84

CoLT5-attention

Implementation of the conditionally routed attention in the CoLT5 architecture, in Pytorch
Python
216
star
85

natural-speech-pytorch

Implementation of the neural network proposed in Natural Speech, a text-to-speech generator that is indistinguishable from human recordings for the first time, from Microsoft Research
Python
215
star
86

soft-moe-pytorch

Implementation of Soft MoE, proposed by Brain's Vision team, in Pytorch
Python
211
star
87

se3-transformer-pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.
Python
211
star
88

block-recurrent-transformer-pytorch

Implementation of Block Recurrent Transformer - Pytorch
Python
205
star
89

Mega-pytorch

Implementation of Mega, the Single-head Attention with Multi-headed EMA architecture that currently holds SOTA on Long Range Arena
Python
201
star
90

simple-hierarchical-transformer

Experiments around a simple idea for inducing multiple hierarchical predictive models within a GPT
Python
198
star
91

med-seg-diff-pytorch

Implementation of MedSegDiff in Pytorch - SOTA medical segmentation using DDPM and filtering of features in fourier space
Python
195
star
92

triton-transformer

Implementation of a Transformer, but completely in Triton
Python
195
star
93

jax2torch

Use Jax functions in Pytorch
Python
194
star
94

flash-cosine-sim-attention

Implementation of fused cosine similarity attention in the same style as Flash Attention
Cuda
194
star
95

halonet-pytorch

Implementation of the 😇 Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones
Python
193
star
96

attention

This repository will house a visualization that will attempt to convey instant enlightenment of how Attention works to someone not working in artificial intelligence, with 3Blue1Brown as inspiration
HTML
189
star
97

recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch
Python
188
star
98

electra-pytorch

A simple and working implementation of Electra, the fastest way to pretrain language models from scratch, in Pytorch
Python
186
star
99

PaLM-jax

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax (Equinox framework)
Python
184
star
100

unet-stylegan2

A Pytorch implementation of Stylegan2 with UNet Discriminator
Python
182
star