ResViT

Official PyTorch implementation of Residual Vision Transformers (ResViT), described in the following paper:

O. Dalmaz, M. Yurt and T. Çukur, "ResViT: Residual Vision Transformers for Multimodal Medical Image Synthesis," in IEEE Transactions on Medical Imaging, vol. 41, no. 10, pp. 2598-2614, Oct. 2022, doi: 10.1109/TMI.2022.3167808.

Dependencies

python>=3.6.9
torch>=1.7.1
torchvision>=0.8.2
visdom
dominate
scikit-image
h5py
scipy
ml_collections
cuda>=11.2
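A quick way to check an environment against the minimum versions above is a small numeric version comparison. The sketch below is hypothetical (the helper names are illustrative and not part of the repository); it compares numeric components, so "1.10.0" correctly counts as newer than "1.7.1", which a plain string comparison gets wrong.

```python
import re
import sys

# Minimum versions taken from the dependency list above.
MINIMUM_VERSIONS = {
    "python": "3.6.9",
    "torch": "1.7.1",
    "torchvision": "0.8.2",
}


def version_tuple(version):
    """Turn a version string such as '1.7.1+cu110' into (1, 7, 1).

    The local suffix after '+' is dropped, and only the leading digits of
    each dot-separated component are kept, so '2.0.0a0' becomes (2, 0, 0).
    """
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum` (numeric comparison)."""
    return version_tuple(installed) >= version_tuple(minimum)


def python_version():
    """Return the running interpreter version as 'major.minor.micro'."""
    return ".".join(str(n) for n in sys.version_info[:3])
```

For example, `meets_minimum(torch.__version__, MINIMUM_VERSIONS["torch"])` checks the installed PyTorch, and `torch.version.cuda` reports the CUDA version that PyTorch was built against.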

Installation

  • Clone this repo:
git clone https://github.com/icon-lab/ResViT
cd ResViT

Download the pre-trained R50+ViT-B_16 model from Google:

wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv R50+ViT-B_16.npz ../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz
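After downloading, a quick sanity check can confirm that the file is a complete .npz archive. The sketch below is illustrative (the helper name is hypothetical) and relies only on the fact that NumPy's .npz format is a ZIP archive holding one `.npy` member per array; it does not check for any particular array names.

```python
import zipfile
from pathlib import Path

# Destination used by the download commands above (relative to the repo directory).
CHECKPOINT = Path("../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz")


def list_npz_arrays(path):
    """Return the array names stored in an .npz file without loading them.

    An .npz file is a ZIP archive with one '<name>.npy' member per array, so
    the standard-library zipfile module is enough for a quick check. Raises
    zipfile.BadZipFile if the download is truncated or not a ZIP archive.
    """
    with zipfile.ZipFile(path) as archive:
        return [
            name[: -len(".npy")]
            for name in archive.namelist()
            if name.endswith(".npy")
        ]
```

For example, `len(list_npz_arrays(CHECKPOINT))` should be non-zero for a successful download.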

Dataset

To reproduce the results reported in the paper, we recommend the following dataset processing steps:

  • Sequentially select subjects from the dataset.
  • Apply skull-stripping to the 3D volumes.
  • Select 2D cross-sections from each subject.
  • Normalize the selected 2D cross-sections before training and before metric calculation.

You should structure your aligned dataset in the following way:

/Datasets/BRATS/
  ├── T1_T2
  ├── T2_FLAIR
  .
  .
  └── T1_FLAIR_T2
/Datasets/BRATS/T2__FLAIR/
  ├── train
  ├── val
  └── test
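The per-task split folders in the layout above can be created with a few lines of `pathlib`. This is a hypothetical helper, not part of the repository; it only creates the `train`, `val` and `test` folders and makes no assumption about the file format the data loader expects inside them.

```python
from pathlib import Path

# Split folder names expected inside every task directory (see the tree above).
SPLITS = ("train", "val", "test")


def make_task_dirs(root, task):
    """Create <root>/<task>/{train,val,test} and return the split paths.

    `task` is a task folder name such as "T1_T2"; existing folders are kept.
    """
    task_dir = Path(root) / task
    paths = []
    for split in SPLITS:
        split_dir = task_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        paths.append(split_dir)
    return paths
```

For example, `make_task_dirs("Datasets/BRATS", "T1_T2")` creates the three split folders for the T1 to T2 task.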

Note that for many-to-one tasks with two input modalities, the source modalities should be placed in the red and green channels.
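The normalization step and the red/green channel packing might look roughly like the following. This is a hypothetical, dependency-free sketch: it assumes min-max scaling of each 2D cross-section to [0, 1] (the exact normalization used for the paper is not stated here) and leaves the blue channel at zero, which is also an assumption. Images are plain lists of rows to keep the example self-contained; real pipelines would typically use NumPy arrays.

```python
def normalize_slice(slice_2d):
    """Min-max scale a 2D cross-section (list of rows) to the range [0, 1].

    ASSUMPTION: per-slice min-max scaling; a constant slice maps to all zeros.
    """
    values = [v for row in slice_2d for v in row]
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [[0.0 for _ in row] for row in slice_2d]
    return [[(v - lo) / span for v in row] for row in slice_2d]


def pack_two_sources(source_a, source_b):
    """Pack two same-sized source slices into an H x W x 3 RGB-style image.

    Source A goes to the red channel and source B to the green channel, as
    the note above describes. ASSUMPTION: the blue channel is zero-filled.
    """
    if len(source_a) != len(source_b) or any(
        len(ra) != len(rb) for ra, rb in zip(source_a, source_b)
    ):
        raise ValueError("source slices must have the same shape")
    return [
        [[a, b, 0.0] for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(source_a, source_b)
    ]
```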

Pre-training ART blocks without transformers

We recommend pre-training the convolutional parts of the ResViT model before inserting the transformer modules and fine-tuning. This significantly improves ResViT's performance.

For many-to-one tasks:

python3 train.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_pre_trained --gpu_ids 0 --model resvit_many --which_model_netG res_cnn \
--which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 3 --loadSize 256 --fineSize 256 \
--niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002

For one-to-one tasks:

python3 train.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn \
--which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 \
--niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --lr 0.0002
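The two pre-training commands above differ only in `--model`, `--input_nc`, and the dataset and run names. A hypothetical helper that builds the argument list makes that explicit; every flag is copied from the commands above, and the function name is illustrative only.

```python
# Flags shared by both pre-training commands above (copied verbatim).
COMMON_PRETRAIN_FLAGS = [
    "--gpu_ids", "0", "--which_model_netG", "res_cnn",
    "--which_direction", "AtoB", "--lambda_A", "100",
    "--dataset_mode", "aligned", "--norm", "batch", "--pool_size", "0",
    "--output_nc", "1", "--loadSize", "256", "--fineSize", "256",
    "--niter", "50", "--niter_decay", "50", "--save_epoch_freq", "5",
    "--checkpoints_dir", "checkpoints/", "--display_id", "0", "--lr", "0.0002",
]

# The two task types differ only in --model and --input_nc.
TASK_FLAGS = {
    "many": ["--model", "resvit_many", "--input_nc", "3"],
    "one": ["--model", "resvit_one", "--input_nc", "1"],
}


def pretrain_command(task, dataroot, name):
    """Build the argv list for pre-training a 'many'-to-one or 'one'-to-one task."""
    if task not in TASK_FLAGS:
        raise ValueError(f"unknown task type: {task!r}")
    return (
        ["python3", "train.py", "--dataroot", dataroot, "--name", name]
        + COMMON_PRETRAIN_FLAGS
        + TASK_FLAGS[task]
    )
```

From the repository root, `subprocess.run(pretrain_command("one", "Datasets/IXI/T1_T2/", "T1_T2_IXI_pre_trained"), check=True)` would launch the one-to-one run.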


Fine-tune ResViT

For many-to-one tasks:

python3 train.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_resvit --gpu_ids 0 --model resvit_many --which_model_netG resvit \
--which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 3 --loadSize 256 --fineSize 256 \
--niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 \
--pre_trained_path checkpoints/T1_T2_PD_IXI_pre_trained/latest_net_G.pth --lr 0.001

For one-to-one tasks:

python3 train.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit \
--which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 \
--niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 \
--pre_trained_path checkpoints/T1_T2_IXI_pre_trained/latest_net_G.pth --lr 0.001
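In the fine-tuning commands above, `--pre_trained_path` points at `latest_net_G.pth` inside the checkpoint folder of the pre-training run, i.e. `<checkpoints_dir>/<pre-training --name>/latest_net_G.pth`. The hypothetical helper below derives that path instead of typing it by hand, assuming the pre-training run used the same checkpoints directory; all flags are copied from the commands above.

```python
from pathlib import PurePosixPath

# Flags shared by both fine-tuning commands above (copied verbatim).
COMMON_FINETUNE_FLAGS = [
    "--gpu_ids", "0", "--which_model_netG", "resvit",
    "--which_direction", "AtoB", "--lambda_A", "100",
    "--dataset_mode", "aligned", "--norm", "batch", "--pool_size", "0",
    "--output_nc", "1", "--loadSize", "256", "--fineSize", "256",
    "--niter", "25", "--niter_decay", "25", "--save_epoch_freq", "5",
    "--display_id", "0", "--pre_trained_transformer", "1",
    "--pre_trained_resnet", "1", "--lr", "0.001",
]

# The two task types differ only in --model and --input_nc.
TASK_FLAGS = {
    "many": ["--model", "resvit_many", "--input_nc", "3"],
    "one": ["--model", "resvit_one", "--input_nc", "1"],
}


def finetune_command(task, dataroot, name, pretrain_name, checkpoints_dir="checkpoints/"):
    """Build the argv list for fine-tuning, loading the pre-trained generator.

    ASSUMPTION: the pre-training run wrote <checkpoints_dir>/<pretrain_name>/
    latest_net_G.pth, matching the paths used in the commands above.
    """
    if task not in TASK_FLAGS:
        raise ValueError(f"unknown task type: {task!r}")
    weights = PurePosixPath(checkpoints_dir) / pretrain_name / "latest_net_G.pth"
    return (
        ["python3", "train.py", "--dataroot", dataroot, "--name", name]
        + COMMON_FINETUNE_FLAGS
        + TASK_FLAGS[task]
        + ["--checkpoints_dir", checkpoints_dir, "--pre_trained_path", str(weights)]
    )
```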


Testing

For many-to-one tasks:

python3 test.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_resvit --gpu_ids 0 --model resvit_many --which_model_netG resvit \
--dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 3 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 \
--results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest

For one-to-one tasks:

python3 test.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit \
--dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 1 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 \
--results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest
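The dataset notes recommend normalizing cross-sections before metric calculation. As an illustration only, a peak signal-to-noise ratio (PSNR) between a ground-truth slice and a synthesized slice could be computed as below; PSNR is one common choice, and this sketch is not the repository's evaluation code. It assumes both images are already normalized so that `data_range` is the full intensity span.

```python
import math


def psnr(reference, synthesized, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two 2D images (lists of rows).

    ASSUMPTION: intensities are normalized so that `data_range` is the maximum
    possible value span. Returns math.inf for identical images.
    """
    ref = [v for row in reference for v in row]
    syn = [v for row in synthesized for v in row]
    if len(ref) != len(syn) or not ref:
        raise ValueError("images must be non-empty and have the same number of pixels")
    mse = sum((r - s) ** 2 for r, s in zip(ref, syn)) / len(ref)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)
```

Since scikit-image is already a dependency, `skimage.metrics.peak_signal_noise_ratio(image_true, image_test, data_range=1.0)` offers an equivalent computation on NumPy arrays.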

Citation

You are encouraged to modify and distribute this code. However, please acknowledge it and cite the paper appropriately.

@ARTICLE{9758823,
  author={Dalmaz, Onat and Yurt, Mahmut and Çukur, Tolga},
  journal={IEEE Transactions on Medical Imaging}, 
  title={ResViT: Residual Vision Transformers for Multimodal Medical Image Synthesis}, 
  year={2022},
  volume={41},
  number={10},
  pages={2598-2614},
  doi={10.1109/TMI.2022.3167808}}

For any questions, comments and contributions, please contact Onat Dalmaz (onat[at]ee.bilkent.edu.tr)

(c) ICON Lab 2021

Acknowledgments

This code uses libraries from the pGAN and pix2pix repositories.
