Language: Python
License: MIT License



Official PyTorch code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation" - MICCAI 2021

Medical-Transformer

PyTorch code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation", MICCAI 2021

Paper | Poster

News:

🚀 Check out our latest work, UNeXt, a faster and more efficient segmentation architecture that is also easy to train and implement! Code is available here.

About this repo:

This repo hosts the code for the following networks:

  1. Gated Axial Attention U-Net
  2. MedT

Introduction

The majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, medical imaging datasets contain relatively few samples compared to those used for general vision applications, which makes it difficult to train transformers efficiently for medical applications. To this end, we propose a Gated Axial-Attention model that extends existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strategy (LoGo) that further improves performance. Specifically, we operate on the whole image and on patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) uses the LoGo training strategy on the Gated Axial Attention U-Net.
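To make the "additional control mechanism" concrete, here is a minimal single-head NumPy sketch of gated axial attention along one axis (for example, one image row). It assumes, as we read the paper, an axial-deeplab style layer with relative positional terms for queries, keys, and values, where learnable scalar gates (called `q`, `k`, `v1`, `v2` here) scale how much each positional term contributes. The function name, argument layout, and the fixed gate values are illustrative assumptions; multi-head projections, normalization layers, and training are omitted, so this is not the repo's PyTorch implementation. A full axial block would apply such a layer along the height axis and then along the width axis.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gated_axial_attention_1d(x, wq, wk, wv, rq, rk, rv, gates):
    """Hypothetical single-head gated self-attention along one axis.

    x:      (L, C_in) features of the L pixels on the axis
    wq, wk: (C_in, D) query/key projections
    wv:     (C_in, C_out) value projection
    rq, rk: (2L-1, D) relative positional embeddings for queries/keys
    rv:     (2L-1, C_out) relative positional embeddings for values
    gates:  dict of scalars "q", "k", "v1", "v2" (learnable in a real model)
    """
    L = x.shape[0]
    q, k, v = x @ wq, x @ wk, x @ wv
    # idx[o, p] = p - o, shifted into [0, 2L-2] to index the embedding tables.
    idx = np.arange(L)[None, :] - np.arange(L)[:, None] + (L - 1)
    r_q, r_k, r_v = rq[idx], rk[idx], rv[idx]  # (L, L, D), (L, L, D), (L, L, C_out)

    logits = (
        q @ k.T                                         # content term q_o . k_p
        + gates["q"] * np.einsum("od,opd->op", q, r_q)  # gated query-position term
        + gates["k"] * np.einsum("pd,opd->op", k, r_k)  # gated key-position term
    )
    attn = softmax(logits, axis=-1)  # (L, L); each row sums to 1
    # Output o = sum_p attn[o, p] * (G_V1 * v_p + G_V2 * r^v_{p-o})
    return gates["v1"] * attn @ v + gates["v2"] * np.einsum("op,opc->oc", attn, r_v)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    L, C, D = 8, 16, 8
    x = rng.standard_normal((L, C))
    out = gated_axial_attention_1d(
        x,
        wq=rng.standard_normal((C, D)), wk=rng.standard_normal((C, D)),
        wv=rng.standard_normal((C, C)),
        rq=rng.standard_normal((2 * L - 1, D)), rk=rng.standard_normal((2 * L - 1, D)),
        rv=rng.standard_normal((2 * L - 1, C)),
        gates={"q": 0.1, "k": 0.1, "v1": 1.0, "v2": 0.1},  # illustrative values only
    )
    print(out.shape)  # (8, 16)
```

Setting the `q`, `k`, and `v2` gates to zero (and `v1` to one) reduces the layer to plain content-only attention, which is one way to see how the gates let the model down-weight positional terms that are hard to learn from small datasets.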

Using the code:

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer

The code is stable with Python 3.6.10 and PyTorch 1.4.0.

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate medt

To install all the dependencies using pip:

pip install -r requirements.txt

Links for downloading the public datasets:

  1. MoNuSeG Dataset - Link (Original)
  2. GLAS Dataset - Link (Original)
  3. The Brain Anatomy US dataset from the paper will be made public soon!

Using the Code for your dataset

Dataset Preparation

Prepare the dataset in the following format for easy use of the code. The train, validation, and test folders should each contain two subfolders: img and labelcol. Place the images and their corresponding segmentation masks in these subfolders, and give each image and its mask the same file name so they can be matched. If you prefer not to prepare the dataset in this format, change the data loaders to suit your needs.

Train Folder-----
      img----
          0001.png
          0002.png
          .......
      labelcol---
          0001.png
          0002.png
          .......
Validation Folder-----
      img----
          0001.png
          0002.png
          .......
      labelcol---
          0001.png
          0002.png
          .......
Test Folder-----
      img----
          0001.png
          0002.png
          .......
      labelcol---
          0001.png
          0002.png
          .......
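A small standard-library check can catch unpaired files before training starts. The helper below is hypothetical (it is not part of this repo); it assumes the img/labelcol layout shown above and only compares file names, not image contents or sizes.

```python
from pathlib import Path


def check_split(split_dir, img_sub="img", label_sub="labelcol"):
    """Hypothetical helper: report unpaired files in one split folder.

    Returns (images_without_mask, masks_without_image) as sorted name lists.
    A file counts as paired when the same file name exists in both subfolders.
    Raises FileNotFoundError if either subfolder is missing.
    """
    split = Path(split_dir)
    imgs = {p.name for p in (split / img_sub).iterdir() if p.is_file()}
    labels = {p.name for p in (split / label_sub).iterdir() if p.is_file()}
    return sorted(imgs - labels), sorted(labels - imgs)
```

For example, call `check_split` on each of your train, validation, and test folders and fix any names it reports before launching train.py.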
  • The ground truth images should have pixel values corresponding to the labels. For example, in binary segmentation the pixels in the GT should be either 0 or 255.
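If your binary masks are stored as {0, 1} label maps or as grayscale images with intermediate values, a conversion like the one sketched below produces the 0/255 convention. Both functions are hypothetical helpers built on NumPy, and the threshold of 128 is an illustrative assumption; reading and writing the PNG files is left to your imaging library of choice.

```python
import numpy as np


def to_binary_gt(mask, threshold=128):
    """Map a grayscale mask to 0/255: pixels >= threshold become 255, others 0."""
    mask = np.asarray(mask)
    return np.where(mask >= threshold, 255, 0).astype(np.uint8)


def from_01_labels(mask):
    """Convert a {0, 1} label map to {0, 255}, rejecting any other values."""
    mask = np.asarray(mask)
    if not np.isin(mask, (0, 1)).all():
        raise ValueError("expected only 0 and 1 in the mask")
    return mask.astype(np.uint8) * 255
```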

Training Command:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
Change --modelname to MedT or logo to train those models.
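MedT applies the LoGo strategy from the introduction: a global branch sees the whole image while a local branch sees patches of it. The NumPy sketch below only illustrates that patch bookkeeping. The 4 x 4 grid, the function names, and the plain summation of the two branch outputs are illustrative assumptions, and the branches are stand-in callables rather than the repo's networks.

```python
import numpy as np


def split_patches(img, grid=4):
    """Split an (H, W, C) image into grid*grid non-overlapping patches, row-major."""
    H, W = img.shape[:2]
    if H % grid or W % grid:
        raise ValueError("image height and width must be divisible by grid")
    ph, pw = H // grid, W // grid
    return [img[i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
            for i in range(grid) for j in range(grid)]


def stitch_patches(patches, grid=4):
    """Inverse of split_patches: reassemble row-major patches into one array."""
    rows = [np.concatenate(patches[i * grid:(i + 1) * grid], axis=1) for i in range(grid)]
    return np.concatenate(rows, axis=0)


def logo_forward(img, global_branch, local_branch, grid=4):
    """Hypothetical LoGo-style forward pass.

    global_branch sees the whole image; local_branch sees each patch on its own.
    The per-pixel outputs are simply summed here.
    """
    global_out = global_branch(img)
    local_out = stitch_patches([local_branch(p) for p in split_patches(img, grid)], grid)
    return global_out + local_out
```

Because the local branch only ever sees one patch at a time, it can focus on fine detail, while the global branch is responsible for long-range context across the full image.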

Testing Command:

python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"

The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU.
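For a quick sanity check in Python, per-image F1 (Dice) and IoU on binary masks can be computed as sketched below. These are hypothetical NumPy helpers, not a port of the repo's MATLAB script; that script may threshold predictions or average over images differently, so the numbers are not guaranteed to match.

```python
import numpy as np


def f1_and_iou(pred, gt, threshold=127):
    """F1 (Dice) and IoU for one pair of binary masks of the same shape.

    Values > threshold count as foreground, matching 0/255 masks.
    """
    p = np.asarray(pred) > threshold
    g = np.asarray(gt) > threshold
    tp = np.logical_and(p, g).sum()
    fp = np.logical_and(p, ~g).sum()
    fn = np.logical_and(~p, g).sum()
    if tp + fp + fn == 0:  # both masks empty: treat as perfect agreement
        return 1.0, 1.0
    return float(2 * tp / (2 * tp + fp + fn)), float(tp / (tp + fp + fn))


def mean_scores(pairs):
    """Average per-image F1 and IoU over an iterable of (pred, gt) pairs."""
    f1s, ious = zip(*(f1_and_iou(p, g) for p, g in pairs))
    return float(np.mean(f1s)), float(np.mean(ious))
```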

Notes:

1) These experiments were conducted on an Nvidia Quadro 8000 with 48 GB of memory.
2) The Google Colab code is an unofficial implementation for quick training and testing. Please follow the original code for proper training.

Acknowledgement:

The dataloader code is inspired by pytorch-UNet. The axial attention code is based on axial-deeplab.

Citation:

@InProceedings{jose2021medical,
author="Valanarasu, Jeya Maria Jose
and Oza, Poojan
and Hacihaliloglu, Ilker
and Patel, Vishal M.",
title="Medical Transformer: Gated Axial-Attention for Medical Image Segmentation",
booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2021",
year="2021",
publisher="Springer International Publishing",
address="Cham",
pages="36--46",
isbn="978-3-030-87193-2"
}

Open an issue or email me directly with any questions or suggestions.
