PyTorch reimplementation of the paper "MaxViT: Multi-Axis Vision Transformer" [arXiv 2022].

MaxViT: Multi-Axis Vision Transformer

License: MIT

Unofficial PyTorch reimplementation of the paper MaxViT: Multi-Axis Vision Transformer by Zhengzhong Tu et al. (Google Research).

Figure taken from paper.

Note: timm offers MaxViT weights pre-trained on ImageNet!

Installation

You can simply install the MaxViT implementation as a Python package by using pip.

pip install git+https://github.com/ChristophReich1996/MaxViT

Alternatively, you can clone the repository and use the implementation in maxvit directly in your project.

This implementation relies only on PyTorch and timm (see requirements.txt).

Usage

This implementation provides the pre-configured models of the paper (tiny, small, base, and large, each for 224 × 224 inputs), which can be used as follows:

import torch
import maxvit

# Tiny model
network: maxvit.MaxViT = maxvit.max_vit_tiny_224(num_classes=1000)
input = torch.rand(1, 3, 224, 224)
output = network(input)

# Small model
network: maxvit.MaxViT = maxvit.max_vit_small_224(num_classes=365, in_channels=1)
input = torch.rand(1, 1, 224, 224)
output = network(input)

# Base model
network: maxvit.MaxViT = maxvit.max_vit_base_224(in_channels=4)
input = torch.rand(1, 4, 224, 224)
output = network(input)

# Large model
network: maxvit.MaxViT = maxvit.max_vit_large_224()
input = torch.rand(1, 3, 224, 224)
output = network(input)

To access the names of the weights that should not be used with weight decay, call nwd: Set[str] = network.no_weight_decay().
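The returned set of names can be used to build separate optimizer parameter groups. The following is a minimal sketch, not part of the maxvit package: split_weight_decay_groups is a hypothetical helper, the weight decay value is only illustrative, and it assumes that the names returned by no_weight_decay() match the names yielded by network.named_parameters().

```python
from typing import Any, Dict, Iterable, List, Set, Tuple


def split_weight_decay_groups(
    named_parameters: Iterable[Tuple[str, Any]],
    no_decay_names: Set[str],
    weight_decay: float = 0.05,
) -> List[Dict[str, Any]]:
    """Build two optimizer parameter groups: one with and one without weight decay.

    Hypothetical helper for illustration; it is not provided by maxvit.
    """
    decay: List[Any] = []
    no_decay: List[Any] = []
    for name, parameter in named_parameters:
        # Assumption: names in no_decay_names match the named_parameters() names exactly
        if name in no_decay_names:
            no_decay.append(parameter)
        else:
            decay.append(parameter)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

With a network from the examples above, the groups could then be passed to an optimizer, e.g. torch.optim.AdamW(split_weight_decay_groups(network.named_parameters(), network.no_weight_decay()), lr=1e-3).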

If you want to use a custom configuration, you can instantiate the MaxViT class directly. The constructor takes the following parameters.

| Parameter | Description | Type |
| --- | --- | --- |
| in_channels | Number of input channels to the convolutional stem. Default: 3 | int, optional |
| depths | Depth of each network stage. Default: (2, 2, 5, 2) | Tuple[int, ...], optional |
| channels | Number of channels in each network stage. Default: (64, 128, 256, 512) | Tuple[int, ...], optional |
| num_classes | Number of classes to be predicted. Default: 1000 | int, optional |
| embed_dim | Embedding dimension of the convolutional stem. Default: 64 | int, optional |
| num_heads | Number of attention heads. Default: 32 | int, optional |
| grid_window_size | Grid/window size to be utilized. Default: (7, 7) | Tuple[int, int], optional |
| attn_drop | Dropout ratio of the attention weights. Default: 0.0 | float, optional |
| drop | Dropout ratio of the output. Default: 0.0 | float, optional |
| drop_path | Dropout ratio of the path. Default: 0.0 | float, optional |
| mlp_ratio | Ratio of MLP hidden dimension to embedding dimension. Default: 4.0 | float, optional |
| act_layer | Type of activation layer to be utilized. Default: nn.GELU | Type[nn.Module], optional |
| norm_layer | Type of normalization layer to be utilized. Default: nn.BatchNorm2d | Type[nn.Module], optional |
| norm_layer_transformer | Normalization layer in the Transformer. Default: nn.LayerNorm | Type[nn.Module], optional |
| global_pool | Global pooling type to be utilized. Default: "avg" | str, optional |
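As a rough sketch of how a custom configuration might be assembled from the parameters listed above: TABLE_DEFAULTS and make_config are hypothetical names introduced here for illustration, the layer-type parameters (act_layer, norm_layer, norm_layer_transformer) are left at the constructor defaults, and the length check assumes one depth and one channel count per stage.

```python
from typing import Any, Dict

# Plain-Python defaults copied from the parameter list above
# (layer-type parameters such as act_layer are left to the constructor).
TABLE_DEFAULTS: Dict[str, Any] = {
    "in_channels": 3,
    "depths": (2, 2, 5, 2),
    "channels": (64, 128, 256, 512),
    "num_classes": 1000,
    "embed_dim": 64,
    "num_heads": 32,
    "grid_window_size": (7, 7),
    "attn_drop": 0.0,
    "drop": 0.0,
    "drop_path": 0.0,
    "mlp_ratio": 4.0,
    "global_pool": "avg",
}


def make_config(**overrides: Any) -> Dict[str, Any]:
    """Return the listed defaults with overrides applied (hypothetical helper)."""
    unknown = set(overrides) - set(TABLE_DEFAULTS)
    if unknown:
        raise KeyError(f"Unknown parameter(s): {sorted(unknown)}")
    config = {**TABLE_DEFAULTS, **overrides}
    # Assumption: depths and channels each hold one entry per network stage
    if len(config["depths"]) != len(config["channels"]):
        raise ValueError("depths and channels must have the same length")
    return config


# Example: a 10-class model for single-channel inputs with stochastic depth
config = make_config(num_classes=10, in_channels=1, drop_path=0.1)
```

Assuming the constructor accepts these parameters as keyword arguments, the network could then be created with network = maxvit.MaxViT(**config).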

Disclaimer

This is a very experimental implementation based only on the MaxViT paper. Since an official implementation of MaxViT is not yet published, it is not possible to say to what extent this implementation differs from the original one. If you run into problems with this implementation, please raise an issue.

Reference

@article{Liu2021,
    title={{MaxViT: Multi-Axis Vision Transformer}},
    author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan
            and Li, Yinxiao},
    journal={arXiv preprint arXiv:2204.01697},
    year={2022}
}