
Implementing Vi(sion)T(ransformer) in PyTorch

Hi guys, happy new year! Today we are going to implement the famous Vi(sion)T(ransformer) proposed in AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.

The code is here, and an interactive version of this article can be downloaded from here.

ViT will soon be available in my new computer vision library called glasses

This is a technical tutorial, not your usual Medium post where you find out about the top 5 secret pandas functions that will make you rich.

So, before beginning, I highly recommend you get familiar with the original Transformer first (The Illustrated Transformer, mentioned later in this article, is a great place to start).

So, ViT uses a normal Transformer (the one proposed in Attention Is All You Need) that works on images. But how?

The following picture shows ViT's architecture

[Image: ViT architecture]

The input image is decomposed into 16x16 flattened patches (the image is not to scale). They are then embedded using a normal fully connected layer, a special cls token is prepended to them, and the positional encoding is summed. The resulting tensor is passed first to a standard Transformer encoder and then to a classification head. That's it.

The article is structured into the following sections:

  • Data
  • Patches Embeddings
    • CLS Token
    • Position Embedding
  • Transformer
    • Attention
    • Residuals
    • MLP
    • TransformerEncoder
  • Head
  • ViT

We are going to implement the model block by block with a bottom-up approach. We can start by importing all the required packages

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

Nothing fancy here, just PyTorch + stuff

Data

First of all, we need a picture, a cute cat works just fine :)

img = Image.open('./cat.jpg')

fig = plt.figure()
plt.imshow(img)

[Image: the cat picture]

Then, we need to preprocess it

# resize to imagenet size 
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0) # add batch dim
x.shape
torch.Size([1, 3, 224, 224])

Patches Embeddings

The first step is to break down the image into multiple patches and flatten them.

[Image: the input image split into flattened 16x16 patches]

Quoting from the paper:

[Image: excerpt from the paper describing how the image is reshaped into a sequence of flattened patches]

This can be easily done using einops.

patch_size = 16 # 16 pixels
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
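
A quick shape check: given the 1 x 3 x 224 x 224 input above, we expect 14 x 14 = 196 patches, each of size 16 * 16 * 3 = 768.

patches.shape
torch.Size([1, 196, 768])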

Now, we need to project them using a normal linear layer

[Image: linear projection of the flattened patches]

We can create a PatchEmbedding class to keep our code nice and clean

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 196, 768])

Note: after checking out the original implementation, I found out that the authors use a Conv2d layer instead of a Linear one for a performance gain. This is obtained by using a kernel_size and stride equal to the patch_size. Intuitively, the convolution is applied to each patch individually. So, we first apply the conv layer and then flatten the resulting feature map.

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 196, 768])
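
As an optional sanity check (this snippet is mine, not part of the original implementation), the conv-based projection is equivalent to the linear one once the conv kernel is reshaped to match the (s1 s2 c) flattening order used above:

conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
linear = nn.Linear(16 * 16 * 3, 768)
with torch.no_grad():
    # reorder the conv kernel so it matches the (s1 s2 c) flattening of each patch
    linear.weight.copy_(rearrange(conv.weight, 'e c s1 s2 -> e (s1 s2 c)'))
    linear.bias.copy_(conv.bias)

out_conv = rearrange(conv(x), 'b e h w -> b (h w) e')
out_linear = linear(rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=16, s2=16))
torch.allclose(out_conv, out_linear, atol=1e-5)  # expected: True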

CLS Token

The next step is to add the cls token and the position embedding. The cls token is just a learnable embedding prepended to each sequence (of projected patches).

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 197, 768])

cls_token is a randomly initialized torch Parameter; in the forward method it is copied b (batch) times and prepended to the projected patches using torch.cat
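
As a small illustration (the batch size of 8 is arbitrary), this is what repeat does to the token:

cls_token = nn.Parameter(torch.randn(1, 1, 768))
cls_tokens = repeat(cls_token, '() n e -> b n e', b=8)  # copy the token along a new batch dim
cls_tokens.shape
torch.Size([8, 1, 768])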

Position Embedding

So far, the model has no idea about the original position of the patches. We need to pass this spatial information. This can be done in different ways; in ViT we let the model learn it. The position embedding is just a tensor of shape N_PATCHES + 1 (token), EMBED_SIZE that is added to the projected patches.

[Image: position embedding added to the projected patches]

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x
    
PatchEmbedding()(x).shape
torch.Size([1, 197, 768])

We store the position embedding in the .positions field and add it to the patches in the .forward method
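
Note that .positions has shape (N_PATCHES + 1, EMBED_SIZE) while x has an extra batch dimension; the addition works thanks to broadcasting. A minimal illustration with made-up tensors:

positions = torch.randn(197, 768)   # (N_PATCHES + 1, EMBED_SIZE)
x_demo = torch.randn(8, 197, 768)   # (BATCH, N_PATCHES + 1, EMBED_SIZE)
(x_demo + positions).shape          # broadcast over the batch dim
torch.Size([8, 197, 768])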

Transformer

Now we need to implement the Transformer. In ViT only the Encoder is used; its architecture is visualized in the following picture.

[Image: Transformer Encoder architecture]

Let's start with the Attention part

Attention

So, attention takes three inputs, the famous queries, keys, and values, computes the attention matrix using the queries and keys, and uses it to "attend" to the values. In this case, we are using multi-head attention, meaning that the computation is split across n heads with a smaller input size each.

[Image: multi-head attention]

We can use nn.MultiheadAttention from PyTorch or implement our own. For completeness I will show what it looks like:

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        self.scaling = (self.emb_size // num_heads) ** -0.5

    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        # scale the energy before the softmax, as in the original Transformer
        att = F.softmax(energy * self.scaling, dim=-1)
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    
patches_embedded = PatchEmbedding()(x)
MultiHeadAttention()(patches_embedded).shape
torch.Size([1, 197, 768])

So, step by step. We have four fully connected layers: one each for the queries, keys, and values, plus a final projection layer; a dropout is applied to the attention weights.

Okay, the idea (really, go and read The Illustrated Transformer) is to use the product between the queries and the keys to know "how much" each element in the sequence is important with respect to the rest. Then, we use this information to scale the values.

The forward method takes as input the queries, keys, and values from the previous layer and projects them using the three linear layers. Since we are implementing multi-head attention, we have to rearrange the result into multiple heads.

This is done by using rearrange from einops.

Queries, Keys and Values are always the same, so for simplicity, I have only one input (x).

queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

The resulting queries, keys, and values have a shape of BATCH, HEADS, SEQUENCE_LEN, EMBEDDING_SIZE // HEADS (the per-head dimension).

To compute the attention matrix we first have to perform a matrix multiplication between queries and keys, i.e. sum over the last axis. This can be easily done using torch.einsum

energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)

The resulting tensor has the shape BATCH, HEADS, QUERY_LEN, KEY_LEN. The attention is then the softmax of the energy divided by a scaling factor equal to the square root of the per-head dimension.

Lastly, we use the attention to scale the values

torch.einsum('bhal, bhlv -> bhav ', att, values)

and we obtain a tensor of shape BATCH, HEADS, VALUES_LEN, HEAD_SIZE. We concatenate the heads together and finally return the result
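
To make the shapes concrete, here is a minimal sketch with the sizes used in this article (8 heads, so each head has dimension 768 // 8 = 96):

q = torch.randn(1, 8, 197, 96)                     # BATCH, HEADS, SEQ_LEN, HEAD_SIZE
k = torch.randn(1, 8, 197, 96)
v = torch.randn(1, 8, 197, 96)
energy = torch.einsum('bhqd, bhkd -> bhqk', q, k)  # (1, 8, 197, 197)
att = F.softmax(energy * 96 ** -0.5, dim=-1)       # scale, then softmax over the keys
out = torch.einsum('bhal, bhlv -> bhav', att, v)   # (1, 8, 197, 96)
rearrange(out, 'b h n d -> b n (h d)').shape
torch.Size([1, 197, 768])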

Note: we can use a single matrix to compute the queries, keys, and values in one shot.

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        # scale by the square root of the per-head dimension, before the softmax
        scaling = (self.emb_size // self.num_heads) ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    
patches_embedded = PatchEmbedding()(x)
MultiHeadAttention()(patches_embedded).shape
torch.Size([1, 197, 768])

Residuals

The transformer block has residual connections

[Image: residual connection in the Transformer block]

We can create a nice wrapper to perform the residual addition; it will come in handy later on

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
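
A quick check that the wrapper preserves the shape (the inner nn.Linear is just a stand-in for any block):

ResidualAdd(nn.Linear(768, 768))(torch.randn(1, 197, 768)).shape
torch.Size([1, 197, 768])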

MLP

The attention's output is passed to a fully connected block made of two layers: the first upsamples the input by a factor of expansion, the second projects it back to the original size

[Image: the feed-forward (MLP) block]

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

Just a quick side note. I don't know why, but I rarely see people subclass nn.Sequential to avoid writing the forward method. Start doing it, this is how object-oriented programming works!
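
A quick shape check of the block (the tensor is random, just to confirm the dimensions are preserved):

FeedForwardBlock(768)(torch.randn(1, 197, 768)).shape
torch.Size([1, 197, 768])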

Finally, we can create the Transformer Encoder Block.

[Image: Transformer Encoder Block]

ResidualAdd allows us to define this block in an elegant way

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

Let's test it

patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
torch.Size([1, 197, 768])

You can also use PyTorch's built-in multi-head attention, but it expects three inputs: queries, keys, and values. You can subclass it and pass the same input for all three.
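
For example, a minimal sketch of such a wrapper (assuming a PyTorch version where nn.MultiheadAttention supports batch_first):

class SelfAttention(nn.MultiheadAttention):
    # feed x as queries, keys and values
    def forward(self, x: Tensor) -> Tensor:
        out, _ = super().forward(x, x, x, need_weights=False)
        return out

patches_embedded = PatchEmbedding()(x)
SelfAttention(embed_dim=768, num_heads=8, batch_first=True)(patches_embedded).shape
torch.Size([1, 197, 768])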

Transformer

In ViT only the Encoder part of the original Transformer is used. The encoder is simply L stacked blocks of TransformerEncoderBlock.

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
                

Easy peasy!
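
A quick sanity check with a shallow encoder (depth=2 here only to keep it fast):

patches_embedded = PatchEmbedding()(x)
TransformerEncoder(depth=2)(patches_embedded).shape
torch.Size([1, 197, 768])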

Head

The last layer is a normal fully connected layer that outputs the class logits. It first performs a mean over the whole sequence.

[Image: classification head]

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

Vi(sion) T(ransformer)

We can compose PatchEmbedding, TransformerEncoder and ClassificationHead to create the final ViT architecture.

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
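
Before inspecting the summary, a quick forward pass confirms the output shape (the model is untrained, so the logits themselves are meaningless):

ViT()(x).shape
torch.Size([1, 1000])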
        

We can use torchsummary to check the number of parameters

summary(ViT(), (3, 224, 224), device='cpu')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 197, 3072]               0
           Linear-15             [-1, 197, 768]       2,360,064
          Dropout-16             [-1, 197, 768]               0
      ResidualAdd-17             [-1, 197, 768]               0
        LayerNorm-18             [-1, 197, 768]           1,536
           Linear-19            [-1, 197, 2304]       1,771,776
          Dropout-20          [-1, 8, 197, 197]               0
           Linear-21             [-1, 197, 768]         590,592
MultiHeadAttention-22             [-1, 197, 768]               0
          Dropout-23             [-1, 197, 768]               0
      ResidualAdd-24             [-1, 197, 768]               0
        LayerNorm-25             [-1, 197, 768]           1,536
           Linear-26            [-1, 197, 3072]       2,362,368
             GELU-27            [-1, 197, 3072]               0
          Dropout-28            [-1, 197, 3072]               0
           Linear-29             [-1, 197, 768]       2,360,064
          Dropout-30             [-1, 197, 768]               0
      ResidualAdd-31             [-1, 197, 768]               0
        LayerNorm-32             [-1, 197, 768]           1,536
           Linear-33            [-1, 197, 2304]       1,771,776
          Dropout-34          [-1, 8, 197, 197]               0
           Linear-35             [-1, 197, 768]         590,592
MultiHeadAttention-36             [-1, 197, 768]               0
          Dropout-37             [-1, 197, 768]               0
      ResidualAdd-38             [-1, 197, 768]               0
        LayerNorm-39             [-1, 197, 768]           1,536
           Linear-40            [-1, 197, 3072]       2,362,368
             GELU-41            [-1, 197, 3072]               0
          Dropout-42            [-1, 197, 3072]               0
           Linear-43             [-1, 197, 768]       2,360,064
          Dropout-44             [-1, 197, 768]               0
      ResidualAdd-45             [-1, 197, 768]               0
        LayerNorm-46             [-1, 197, 768]           1,536
           Linear-47            [-1, 197, 2304]       1,771,776
          Dropout-48          [-1, 8, 197, 197]               0
           Linear-49             [-1, 197, 768]         590,592
MultiHeadAttention-50             [-1, 197, 768]               0
          Dropout-51             [-1, 197, 768]               0
      ResidualAdd-52             [-1, 197, 768]               0
        LayerNorm-53             [-1, 197, 768]           1,536
           Linear-54            [-1, 197, 3072]       2,362,368
             GELU-55            [-1, 197, 3072]               0
          Dropout-56            [-1, 197, 3072]               0
           Linear-57             [-1, 197, 768]       2,360,064
          Dropout-58             [-1, 197, 768]               0
      ResidualAdd-59             [-1, 197, 768]               0
        LayerNorm-60             [-1, 197, 768]           1,536
           Linear-61            [-1, 197, 2304]       1,771,776
          Dropout-62          [-1, 8, 197, 197]               0
           Linear-63             [-1, 197, 768]         590,592
MultiHeadAttention-64             [-1, 197, 768]               0
          Dropout-65             [-1, 197, 768]               0
      ResidualAdd-66             [-1, 197, 768]               0
        LayerNorm-67             [-1, 197, 768]           1,536
           Linear-68            [-1, 197, 3072]       2,362,368
             GELU-69            [-1, 197, 3072]               0
          Dropout-70            [-1, 197, 3072]               0
           Linear-71             [-1, 197, 768]       2,360,064
          Dropout-72             [-1, 197, 768]               0
      ResidualAdd-73             [-1, 197, 768]               0
        LayerNorm-74             [-1, 197, 768]           1,536
           Linear-75            [-1, 197, 2304]       1,771,776
          Dropout-76          [-1, 8, 197, 197]               0
           Linear-77             [-1, 197, 768]         590,592
MultiHeadAttention-78             [-1, 197, 768]               0
          Dropout-79             [-1, 197, 768]               0
      ResidualAdd-80             [-1, 197, 768]               0
        LayerNorm-81             [-1, 197, 768]           1,536
           Linear-82            [-1, 197, 3072]       2,362,368
             GELU-83            [-1, 197, 3072]               0
          Dropout-84            [-1, 197, 3072]               0
           Linear-85             [-1, 197, 768]       2,360,064
          Dropout-86             [-1, 197, 768]               0
      ResidualAdd-87             [-1, 197, 768]               0
        LayerNorm-88             [-1, 197, 768]           1,536
           Linear-89            [-1, 197, 2304]       1,771,776
          Dropout-90          [-1, 8, 197, 197]               0
           Linear-91             [-1, 197, 768]         590,592
MultiHeadAttention-92             [-1, 197, 768]               0
          Dropout-93             [-1, 197, 768]               0
      ResidualAdd-94             [-1, 197, 768]               0
        LayerNorm-95             [-1, 197, 768]           1,536
           Linear-96            [-1, 197, 3072]       2,362,368
             GELU-97            [-1, 197, 3072]               0
          Dropout-98            [-1, 197, 3072]               0
           Linear-99             [-1, 197, 768]       2,360,064
         Dropout-100             [-1, 197, 768]               0
     ResidualAdd-101             [-1, 197, 768]               0
       LayerNorm-102             [-1, 197, 768]           1,536
          Linear-103            [-1, 197, 2304]       1,771,776
         Dropout-104          [-1, 8, 197, 197]               0
          Linear-105             [-1, 197, 768]         590,592
MultiHeadAttention-106             [-1, 197, 768]               0
         Dropout-107             [-1, 197, 768]               0
     ResidualAdd-108             [-1, 197, 768]               0
       LayerNorm-109             [-1, 197, 768]           1,536
          Linear-110            [-1, 197, 3072]       2,362,368
            GELU-111            [-1, 197, 3072]               0
         Dropout-112            [-1, 197, 3072]               0
          Linear-113             [-1, 197, 768]       2,360,064
         Dropout-114             [-1, 197, 768]               0
     ResidualAdd-115             [-1, 197, 768]               0
       LayerNorm-116             [-1, 197, 768]           1,536
          Linear-117            [-1, 197, 2304]       1,771,776
         Dropout-118          [-1, 8, 197, 197]               0
          Linear-119             [-1, 197, 768]         590,592
MultiHeadAttention-120             [-1, 197, 768]               0
         Dropout-121             [-1, 197, 768]               0
     ResidualAdd-122             [-1, 197, 768]               0
       LayerNorm-123             [-1, 197, 768]           1,536
          Linear-124            [-1, 197, 3072]       2,362,368
            GELU-125            [-1, 197, 3072]               0
         Dropout-126            [-1, 197, 3072]               0
          Linear-127             [-1, 197, 768]       2,360,064
         Dropout-128             [-1, 197, 768]               0
     ResidualAdd-129             [-1, 197, 768]               0
       LayerNorm-130             [-1, 197, 768]           1,536
          Linear-131            [-1, 197, 2304]       1,771,776
         Dropout-132          [-1, 8, 197, 197]               0
          Linear-133             [-1, 197, 768]         590,592
MultiHeadAttention-134             [-1, 197, 768]               0
         Dropout-135             [-1, 197, 768]               0
     ResidualAdd-136             [-1, 197, 768]               0
       LayerNorm-137             [-1, 197, 768]           1,536
          Linear-138            [-1, 197, 3072]       2,362,368
            GELU-139            [-1, 197, 3072]               0
         Dropout-140            [-1, 197, 3072]               0
          Linear-141             [-1, 197, 768]       2,360,064
         Dropout-142             [-1, 197, 768]               0
     ResidualAdd-143             [-1, 197, 768]               0
       LayerNorm-144             [-1, 197, 768]           1,536
          Linear-145            [-1, 197, 2304]       1,771,776
         Dropout-146          [-1, 8, 197, 197]               0
          Linear-147             [-1, 197, 768]         590,592
MultiHeadAttention-148             [-1, 197, 768]               0
         Dropout-149             [-1, 197, 768]               0
     ResidualAdd-150             [-1, 197, 768]               0
       LayerNorm-151             [-1, 197, 768]           1,536
          Linear-152            [-1, 197, 3072]       2,362,368
            GELU-153            [-1, 197, 3072]               0
         Dropout-154            [-1, 197, 3072]               0
          Linear-155             [-1, 197, 768]       2,360,064
         Dropout-156             [-1, 197, 768]               0
     ResidualAdd-157             [-1, 197, 768]               0
       LayerNorm-158             [-1, 197, 768]           1,536
          Linear-159            [-1, 197, 2304]       1,771,776
         Dropout-160          [-1, 8, 197, 197]               0
          Linear-161             [-1, 197, 768]         590,592
MultiHeadAttention-162             [-1, 197, 768]               0
         Dropout-163             [-1, 197, 768]               0
     ResidualAdd-164             [-1, 197, 768]               0
       LayerNorm-165             [-1, 197, 768]           1,536
          Linear-166            [-1, 197, 3072]       2,362,368
            GELU-167            [-1, 197, 3072]               0
         Dropout-168            [-1, 197, 3072]               0
          Linear-169             [-1, 197, 768]       2,360,064
         Dropout-170             [-1, 197, 768]               0
     ResidualAdd-171             [-1, 197, 768]               0
          Reduce-172                  [-1, 768]               0
       LayerNorm-173                  [-1, 768]           1,536
          Linear-174                 [-1, 1000]         769,000
================================================================
Total params: 86,415,592
Trainable params: 86,415,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 364.33
Params size (MB): 329.65
Estimated Total Size (MB): 694.56
----------------------------------------------------------------

(tensor(86415592), tensor(86415592), tensor(329.6493), tensor(694.5562))

et voilà


I checked the parameter count against other implementations and it matches!

Conclusions

In this article, we have seen how to implement ViT in a nice, scalable, and customizable way. I hope it was useful.

By the way, I am working on a new computer vision library called glasses, check it out if you like

Take care :)

Francesco
