  • Stars: 236
  • Rank: 164,381 (Top 4%)
  • Language: Python
  • License: MIT License
  • Created: about 2 years ago
  • Updated: about 2 years ago


Repository Details

Vision Transformer Cookbook with Tensorflow

Author

Acknowledgement

Table of Contents

Vision Transformer - Tensorflow (>= 2.3.0)

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Tensorflow. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

Usage

import tensorflow as tf
from vit_tensorflow import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([1, 256, 256, 3])

preds = v(img) # (1, 1000)

Parameters

  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
  • patch_size: int.
    Size of each patch. image_size must be divisible by patch_size.
    The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Last dimension of the output tensor after the patch embedding's linear transformation.
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in the Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either 'cls' (CLS token pooling) or 'mean' (mean pooling).
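
For instance, here is a sketch of the same model configured with mean pooling instead of the CLS token; this assumes the TensorFlow port accepts the pool argument exactly as described above.

import tensorflow as tf
from vit_tensorflow import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,      # n = (256 // 32) ** 2 = 64 patches
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean'         # mean pooling over patch tokens instead of the CLS token (assumed keyword)
)

img = tf.random.normal([1, 256, 256, 3])

preds = v(img) # (1, 1000)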

Distillation

A recent paper has shown that using a distillation token for distilling knowledge from convolutional nets to vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from ResNet50 (or any teacher) to a vision transformer

import tensorflow as tf
from vit_tensorflow.distill import DistillableViT, DistillWrapper

teacher = tf.keras.applications.resnet50.ResNet50()

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

img = tf.random.normal([2, 256, 256, 3])
labels = tf.random.uniform(shape=[2, ], minval=0, maxval=1000, dtype=tf.int32)
labels = tf.one_hot(labels, depth=1000, axis=-1)

loss = distiller([img, labels])

# after lots of training above ...

pred = v(img) # (2, 1000)
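
The example above computes the distillation loss once and then uses v for inference after training. A minimal sketch of one such training step is shown below; the Adam optimizer and the use of distiller.trainable_variables are illustrative assumptions rather than documented API of this repository.

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)

with tf.GradientTape() as tape:
    # combined classification + distillation loss from the wrapper
    loss = distiller([img, labels])

# assumes DistillWrapper exposes its weights via `trainable_variables`
grads = tape.gradient(loss, distiller.trainable_variables)
optimizer.apply_gradients(zip(grads, distiller.trainable_variables))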

Deep ViT

This paper notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the Talking Heads paper from NLP.

You can use it as follows

import tensorflow as tf
from vit_tensorflow.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([1, 256, 256, 3])

preds = v(img) # (1, 1000)

CaiT

This paper also notes difficulty in training vision transformers at greater depths and proposes two solutions. First, it proposes per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and to only allow the CLS token to attend to the patches in the last few layers.

They also add Talking Heads, noting improvements.

You can use this scheme as follows

import tensorflow as tf
from vit_tensorflow.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = tf.random.normal([1, 256, 256, 3])

preds = v(img) # (1, 1000)

Token-to-Token ViT

This paper proposes that the first couple of layers should downsample the image sequence by unfolding, leading to overlapping image data in each token. You can use this variant of the ViT as follows.

import tensorflow as tf
from vit_tensorflow.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
)

img = tf.random.normal([1, 224, 224, 3])

preds = v(img) # (1, 1000)

CCT

CCT proposes compact transformers by using convolutions instead of patching and performing sequence pooling. This allows CCT to achieve high accuracy with a low number of parameters.

You can use it in one of two ways

import tensorflow as tf
from vit_tensorflow.cct import CCT

cct = CCT(
    img_size = (224, 448),
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_radio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

img = tf.random.normal(shape=[1, 224, 448, 3])
preds = cct(img) # (1, 1000)

Alternatively, you can use one of several pre-defined models [2,4,6,7,8,14,16], which pre-define the number of layers, number of attention heads, the mlp ratio, and the embedding dimension.

import tensorflow as tf
from vit_tensorflow.cct import cct_14

cct = cct_14(
    img_size = 224,
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
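
A forward pass with the pre-defined model follows the same pattern as the full CCT example above (shapes inferred from the arguments above):

img = tf.random.normal(shape=[1, 224, 224, 3])
preds = cct(img) # (1, 1000)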

The official repository includes links to pretrained model checkpoints.

Cross ViT

This paper proposes to have two vision transformers processing the image at different scales, cross-attending to one another every so often. They show improvements on top of the base vision transformer.

import tensorflow as tf
from vit_tensorflow.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([1, 256, 256, 3])

pred = v(img) # (1, 1000)

PiT

This paper proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.

import tensorflow as tf
from vit_tensorflow.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# forward pass now returns predictions and the attention maps

img = tf.random.normal([1, 224, 224, 3])

preds = v(img) # (1, 1000)

LeViT

This paper proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection, (2) downsampling in stages, (3) extra non-linearity in attention, (4) 2d relative positional biases instead of an initial absolute positional bias, and (5) batchnorm in place of layernorm.

Official repository

import tensorflow as tf
from vit_tensorflow.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

img = tf.random.normal([1, 224, 224, 3])

levit(img) # (1, 1000)

CvT

This paper proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise convolution is also used to project the queries, keys, and values for attention.

import tensorflow as tf
from vit_tensorflow.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

img = tf.random.normal([1, 224, 224, 3])

pred = v(img) # (1, 1000)

Twins SVT

This paper proposes mixing local and global attention, along with a position encoding generator (proposed in CPVT) and global average pooling, to achieve the same results as Swin, without the extra complexity of shifted windows, CLS tokens, or positional embeddings.

import tensorflow as tf
from vit_tensorflow.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.              # dropout
)

img = tf.random.normal([1, 224, 224, 3])

pred = model(img) # (1, 1000)

RegionViT

This paper proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.

You can use it as follows

import tensorflow as tf
from vit_tensorflow.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

img = tf.random.normal([1, 224, 224, 3])

pred = model(img) # (1, 1000)

CrossFormer

This paper beats PVT and Swin using alternating local and global attention. The global attention is done across the windowing dimension for reduced complexity, much like the scheme used for axial attention.

They also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. A dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.

import tensorflow as tf
from vit_tensorflow.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,                # number of output classes
    dim = (64, 128, 256, 512),         # dimension at each stage
    depth = (2, 2, 8, 2),              # depth of transformer at each stage
    global_window_size = (8, 4, 2, 1), # global window sizes at each stage
    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
)

img = tf.random.normal([1, 224, 224, 3])

pred = model(img) # (1, 1000)

ScalableViT

This Bytedance AI paper proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (reduction_factor), while modulating the dimension of the queries and keys (ssa_dim_key). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).

They claim that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against CrossFormer.

You can use it as follows (ex. ScalableViT-S)

import tensorflow as tf
from vit_tensorflow.scalable_vit import ScalableViT

model = ScalableViT(
    num_classes = 1000,
    dim = 64,                               # starting model dimension. at every stage, dimension is doubled
    heads = (2, 4, 8, 16),                  # number of attention heads at each stage
    depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage
    ssa_dim_key = (40, 40, 40, 32),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
    reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
    window_size = (64, 32, None, None),     # window size of the IWSA at each stage. None means no windowing needed
    dropout = 0.1,                          # attention and feedforward dropout
)

img = tf.random.normal([1, 256, 256, 3])

preds = model(img) # (1, 1000)

NesT

This paper decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which are aggregated as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.

You can use it with the following code (ex. NesT-T)

import tensorflow as tf
from vit_tensorflow.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (2, 2, 8),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = tf.random.normal([1, 224, 224, 3])

pred = nest(img) # (1, 1000)

MobileViT

This paper introduces MobileViT, a light-weight and general-purpose vision transformer for mobile devices. MobileViT presents a different perspective for the global processing of information with transformers.

You can use it with the following code (ex. mobilevit_xs)

import tensorflow as tf
from vit_tensorflow.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = tf.random.normal([1, 256, 256, 3])

pred = mbvit_xs(img) # (1, 1000)

Simple Masked Image Modeling

This paper proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection from the masked tokens into pixel space, followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other, more complicated approaches.

You can use this as follows

import tensorflow as tf
from vit_tensorflow import ViT
from vit_tensorflow.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = tf.random.normal([8, 256, 256, 3])

loss = mim(images)

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
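
As a concrete illustration of that loop, here is a minimal pretraining sketch; the Adam optimizer, the random in-memory dataset, and the mim.trainable_variables attribute are assumptions made for the example, so substitute your own data pipeline.

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)

# stand-in dataset of random images; replace with your own tf.data pipeline
dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal([64, 256, 256, 3])).batch(8)

for images in dataset:
    with tf.GradientTape() as tape:
        loss = mim(images)
    # assumes SimMIM exposes its weights via `trainable_variables`
    grads = tape.gradient(loss, mim.trainable_variables)
    optimizer.apply_gradients(zip(grads, mim.trainable_variables))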

Masked Autoencoder

A new Kaiming He paper proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.

DeepReader quick paper review

AI Coffeebreak with Letitia

You can use it with the following code

import tensorflow as tf
from vit_tensorflow import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

images = tf.random.normal([8, 256, 256, 3])

loss = mae(images)

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
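
After pretraining, the encoder wrapped by MAE is the same v instance created above, so it can be used directly (or fine-tuned) for classification:

# `v` is the encoder that was pretrained inside the MAE wrapper
img = tf.random.normal([1, 256, 256, 3])

preds = v(img) # (1, 1000)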

Masked Patch Prediction

Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code.

import tensorflow as tf
from vit_tensorflow import ViT
from vit_tensorflow.mpp import MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

def sample_unlabelled_images():
    return tf.random.normal([20, 256, 256, 3])

optimizer = tf.keras.optimizers.Adam(learning_rate = 3e-4) # optimizer choice is illustrative

for _ in range(100):
    with tf.GradientTape() as tape:
        images = sample_unlabelled_images()
        loss = mpp_trainer(images)
    # apply the gradients so the masked patch prediction task actually updates the weights
    # (assumes MPP exposes its weights via `trainable_variables`)
    grads = tape.gradient(loss, mpp_trainer.trainable_variables)
    optimizer.apply_gradients(zip(grads, mpp_trainer.trainable_variables))

Adaptive Token Sampling

This paper proposes to use the CLS attention scores, re-weighted by the norms of the value heads, as a means to discard unimportant tokens at different layers.

import tensorflow as tf
from vit_tensorflow.ats_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([4, 256, 256, 3])

preds = v(img) # (4, 1000)

# you can also get a list of the final sampled patch ids
# a value of -1 denotes padding

preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)

Patch Merger

This paper proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.

import tensorflow as tf
from vit_tensorflow.vit_with_patch_merger import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    patch_merge_layer = 6,        # at which transformer layer to do patch merging
    patch_merge_num_tokens = 8,   # the output number of tokens from the patch merge
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([4, 256, 256, 3])

preds = v(img) # (4, 1000)

One can also use the PatchMerger module by itself

import tensorflow as tf
from vit_tensorflow.vit_with_patch_merger import PatchMerger

merger = PatchMerger(
    dim = 1024,
    num_tokens_out = 8   # output number of tokens
)

features = tf.random.normal([4, 256, 1024]) # (batch, num tokens, dimension)

out = merger(features) # (4, 8, 1024)

Vision Transformer for Small Datasets

This paper proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the LSA with the learned temperature and masking out of a token's attention to itself.

You can use it as follows:

import tensorflow as tf
from vit_tensorflow.vit_for_small_dataset import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([4, 256, 256, 3])

preds = v(img) # (4, 1000)

You can also use the SPT from this paper as a standalone module

import tensorflow as tf
from vit_tensorflow.vit_for_small_dataset import SPT

spt = SPT(
    dim = 1024,
    patch_size = 16,
    channels = 3
)

img = tf.random.normal([4, 256, 256, 3])

tokens = spt(img) # (4, 256, 1024)

Parallel ViT

This paper proposes parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.

You can try this variant as follows

import tensorflow as tf
from vit_tensorflow.parallel_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    num_parallel_branches = 2,  # in paper, they claimed 2 was optimal
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([4, 256, 256, 3])

preds = v(img) # (4, 1000)

FAQ

  • How do I pass in non-square images?

You can already pass in non-square images - you just have to make sure your height and width are less than or equal to image_size, and both are divisible by patch_size

ex.

import tensorflow as tf
from vit_tensorflow import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([1, 256, 128, 3]) # <-- not a square

preds = v(img) # (1, 1000)

  • How do I pass in non-square patches?

import tensorflow as tf
from vit_tensorflow import ViT

v = ViT(
    num_classes = 1000,
    image_size = (256, 128),  # image size is a tuple of (height, width)
    patch_size = (32, 16),    # patch size is a tuple of (height, width)
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = tf.random.normal([1, 256, 128, 3])

preds = v(img) # (1, 1000)

Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.

  1. Illustrated Transformer - Jay Alammar

  2. Transformers from Scratch - Peter Bloem

  3. The Annotated Transformer - Harvard NLP

More Repositories

1. UGATIT (Python, 6,182 stars) - Official Tensorflow implementation of U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation (ICLR 2020)
2. Tensorflow-Cookbook (Python, 2,798 stars) - Simple Tensorflow Cookbook for easy-to-use
3. SENet-Tensorflow (Python, 750 stars) - Simple Tensorflow implementation of "Squeeze and Excitation Networks" using Cifar10 (ResNeXt, Inception-v4, Inception-resnet-v2)
4. StarGAN-Tensorflow (Python, 715 stars) - Simple Tensorflow implementation of StarGAN (CVPR 2018 Oral)
5. Self-Attention-GAN-Tensorflow (Python, 540 stars) - Simple Tensorflow implementation of "Self-Attention Generative Adversarial Networks" (SAGAN)
6. Densenet-Tensorflow (Python, 507 stars) - Simple Tensorflow implementation of Densenet using Cifar10, MNIST
7. SPADE-Tensorflow (Python, 359 stars) - Simple Tensorflow implementation of "Semantic Image Synthesis with Spatially-Adaptive Normalization" a.k.a. GauGAN, SPADE (CVPR 2019 Oral)
8. MUNIT-Tensorflow (Python, 299 stars) - Simple Tensorflow implementation of "Multimodal Unsupervised Image-to-Image Translation" (ECCV 2018)
9. Vector_Similarity (Python, 285 stars) - Python, Java implementation of TS-SS called from "A Hybrid Geometric Approach for Measuring Similarity Level Among Documents and Document Clustering"
10. Tensorflow2-Cookbook (Python, 266 stars) - Simple Tensorflow 2.x Cookbook for easy-to-use
11. BigGAN-Tensorflow (Python, 261 stars) - Simple Tensorflow implementation of "Large Scale GAN Training for High Fidelity Natural Image Synthesis" (BigGAN)
12. CartoonGAN-Tensorflow (Python, 217 stars) - Simple Tensorflow implementation of CartoonGAN (CVPR 2018)
13. StyleGAN-Tensorflow (Python, 211 stars) - Simple & Intuitive Tensorflow implementation of StyleGAN (CVPR 2019 Oral)
14. GAN_Metrics-Tensorflow (Python, 205 stars) - Simple Tensorflow implementation of metrics for GAN evaluation (Inception score, Frechet-Inception distance, Kernel-Inception distance)
15. ResNet-Tensorflow (Python, 179 stars) - Simple Tensorflow implementation of pre-activation ResNet18, 34, 50, 101, 152
16. ResNeXt-Tensorflow (Python, 159 stars) - Simple Tensorflow implementation of ResNeXt using Cifar10
17. AdaBound-Tensorflow (Python, 150 stars) - Simple Tensorflow implementation of "Adaptive Gradient Methods with Dynamic Bound of Learning Rate" (ICLR 2019)
18. Spectral_Normalization-Tensorflow (Python, 140 stars) - Simple Tensorflow Implementation of "Spectral Normalization for Generative Adversarial Networks" (ICLR 2018)
19. DRIT-Tensorflow (Python, 117 stars) - Simple Tensorflow implementation of "Diverse Image-to-Image Translation via Disentangled Representations" (ECCV 2018 Oral)
20. StarGAN_v2-Tensorflow (Python, 112 stars) - Simple Tensorflow implementation of StarGAN_v2
21. AMSGrad-Tensorflow (Python, 103 stars) - Simple Tensorflow implementation of "On the Convergence of Adam and Beyond" (ICLR 2018)
22. RAdam-Tensorflow (Python, 97 stars) - Simple Tensorflow implementation of "On The Variance Of The Adaptive Learning Rate And Beyond"
23. UNIT-Tensorflow (Python, 96 stars) - Simple Tensorflow implementation of "Unsupervised Image to Image Translation Networks" (NIPS 2017 Spotlight)
24. Awesome-DeepLearning-Study (Python, 93 stars) - Summary of DeepLearning (Korean and English are included)
25. partial_conv-Tensorflow (Python, 90 stars) - Simple Tensorflow implementation of "Partial Convolution based Padding" (partialconv)
26. FusionGAN-Tensorflow (Python, 79 stars) - Simple Tensorflow implementation of FusionGAN (CVPR 2018)
27. Tensorflow-DatasetAPI (Python, 73 stars) - Simple Tensorflow DatasetAPI Tutorial for reading images
28. TripleGAN-Tensorflow (Python, 68 stars) - Simple Tensorflow implementation of Triple Generative Adversarial Nets (NIPS 2017)
29. FUNIT-Tensorflow (Python, 65 stars) - Simple Tensorflow implementation of "Few-Shot Unsupervised Image-to-Image Translation" (ICCV 2019)
30. SphereGAN-Tensorflow (Python, 57 stars) - Simple Tensorflow implementation of SphereGAN (CVPR 2019 Oral)
31. GDWCT-Tensorflow (Python, 57 stars) - Simple Tensorflow implementation of "Image-to-Image Translation via Group-wise Deep Whitening-and-Coloring Transformation" (CVPR 2019 Oral)
32. RelativisticGAN-Tensorflow (Python, 51 stars) - Simple Tensorflow implementation of RelativisticGAN
33. AdamP-Tensorflow (Python, 47 stars) - Tensorflow Implementation of "Slowing Down the Weight Norm Increase in Momentum-based Optimizers"
34. Batch_Instance_Normalization-Tensorflow (Python, 40 stars) - Simple Tensorflow implementation of Batch-Instance Normalization (NIPS 2018)
35. GCNet-Tensorflow (Python, 39 stars) - Simple Tensorflow implementation of "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond"
36. Diffusion-Tensorflow (Python, 37 stars) - Tensorflow implementations of Diffusion models (DDPM, DDIM)
37. CNN-Architecture-Summary (34 stars) - Simple Summary of CNN Architecture
38. tf-torch-template (Python, 32 stars) - Deep learning project template with tensorflow & pytorch (multi-gpu version)
39. Switchable_Normalization-Tensorflow (Python, 29 stars) - Simple Tensorflow implementation of "Switchable Normalization"
40. AdaConv-Tensorflow (Python, 26 stars) - Simple Tensorflow implementation of "Adaptive Convolutions for Structure-Aware Style Transfer" (CVPR 2021)
41. AttnGAN-Tensorflow (Python, 26 stars) - Simple Tensorflow implementation of "AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks" (CVPR 2018)
42. CLIP-Tensorflow (Python, 24 stars) - Simple Tensorflow implementation of CLIP
43. Image_similarity_with_elastic_search (Python, 22 stars) - Find the original image of the converted image with elastic search
44. RealnessGAN-Tensorflow (Python, 22 stars) - Simple Tensorflow implementation of "RealnessGAN: Real or Not Real, that is the Question" (ICLR 2020 Spotlight)
45. ControlGAN-Tensorflow (Python, 19 stars) - Simple Tensorflow implementation of "ControlGAN: Controllable Text-to-Image Generation" (NeurIPS 2019)
46. CycleGAN-Tensorflow (Python, 18 stars) - Simple Tensorflow implementation of CycleGAN
47. diffusion-pytorch (Python, 18 stars) - Ewha Womans University lecture materials
48. SRM-Tensorflow (Python, 18 stars) - Simple Tensorflow implementation of "SRM: A Style-based Recalibration Module for Convolutional Neural Networks"
49. GAN-Tensorflow (Python, 17 stars) - An implementation of GAN using TensorFlow
50. SDIT-Tensorflow (Python, 16 stars) - Simple Tensorflow implementation of "SDIT: Scalable and Diverse Cross-domain Image Translation" (ACM-MM 2019)
51. StableGAN-Tensorflow (Python, 16 stars) - Simple Tensorflow implementation of "Stabilizing Adversarial Nets With Prediction Methods" (ICLR 2018)
52. Toward_spatial_unbiased-Tensorflow (Python, 16 stars) - Simple Tensorflow implementation of "Toward Spatially Unbiased Generative Models" (ICCV 2021)
53. denoising-diffusion-gan-Tensorflow (Python, 15 stars) - Tensorflow implementation of "Tackling the Generative Learning Trilemma with Denoising Diffusion GANs" (ICLR 2022 Spotlight)
54. MirrorGAN-Tensorflow (Python, 15 stars) - Simple Tensorflow implementation of "MirrorGAN: Learning Text-to-image Generation by Redescription" (CVPR 2019)
55. Word2VecJava (Java, 14 stars) - Word2Vec in Java (2013 google word2vec opensource)
56. StackGAN-Tensorflow (Python, 13 stars) - Simple Tensorflow implementation of "StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks" (ICCV 2017 Oral)
57. MDGAN-Tensorflow (Python, 11 stars) - Simple Tensorflow implementation of "MDGAN: Mixture Density Generative Adversarial Networks" (CVPR 2019)
58. DiscoGAN-Tensorflow (Python, 11 stars) - Simple Tensorflow implementation of DiscoGAN
59. stylegan2-pytorch (Python, 11 stars) - Pytorch implementation of StyleGAN2 in my style
60. pix2pix-Tensorflow (Python, 11 stars) - Simple Tensorflow implementations of "Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017)
61. DCGAN-Tensorflow (Python, 10 stars) - Simple Tensorflow implementation of "Deep Convolutional Generative Adversarial Networks"
62. taki0112 (8 stars)
63. coding_interview (Python, 8 stars) - Programmers coding interview in Korean
64. CIPS-Tensorflow (Python, 7 stars) - Simple Tensorflow implementation of "Image Generators with Conditionally-Independent Pixel Synthesis" (CVPR 2021 Oral)
65. Image_classification_CNN-Tensorflow (Python, 7 stars) - Classify dog and cat images of kaggle data
66. TFIDF_Java (Java, 4 stars) - Get TF-IDF of Words
67. CNN_Tensorflow (Python, 4 stars) - Convolutional Neural Network with Tensorflow, MNIST data
68. grid_sample-Tensorflow (Python, 3 stars) - Tensorflow implementation of the grid_sample of pytorch
69. NiN-Tensorflow (Python, 2 stars) - Simple Tensorflow implementation of Network in Network
70. Naver-Keyword_Analysis (Java, 2 stars) - Keyword Analysis from Naver Hack Day
71. Deep-Q-network (Python, 1 star) - Reinforcement study
72. mnist_embedding (Python, 1 star)
73. Bamboo (1 star)