3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Implementation for the paper 3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation by Shizhan Gong, Yuan Zhong, Wenao Ma, Jinpeng Li, Zhao Wang, Jingyang Zhang, Pheng-Ann Heng, and Qi Dou.

Details

Although the Segment Anything Model (SAM) achieves impressive results on general-purpose semantic segmentation with strong generalization to everyday images, its performance on medical image segmentation is less precise and less stable, especially on tumor segmentation tasks that involve objects of small size, irregular shape, and low contrast. Notably, the original SAM architecture is designed for 2D natural images and therefore cannot effectively extract 3D spatial information from volumetric medical data. In this paper, we propose a novel adaptation method for transferring SAM from 2D to 3D for promptable medical image segmentation. Through a holistically designed scheme of architecture modification, we transfer SAM to support volumetric inputs while retaining the majority of its pre-trained parameters for reuse. The fine-tuning process is conducted in a parameter-efficient manner: most of the pre-trained parameters remain frozen, and only a few lightweight spatial adapters are introduced and tuned. Despite the domain gap between natural and medical data and the disparity in spatial arrangement between 2D and 3D, the transformer trained on natural images can effectively capture the spatial patterns present in volumetric medical images with only lightweight adaptations. We conduct experiments on four open-source tumor segmentation datasets, and with a single click prompt, our model outperforms state-of-the-art medical image segmentation models on 3 out of 4 tasks, specifically by 8.25%, 29.87%, and 10.11% for kidney tumor, pancreas tumor, and colon cancer segmentation, while achieving similar performance for liver tumor segmentation. We also compare our adaptation method with existing popular adapters and observe significant performance improvement on most datasets.
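
The parameter-efficient tuning idea described above can be illustrated with a minimal PyTorch sketch: the pre-trained weights are frozen and only newly introduced adapter parameters receive gradients. The SpatialAdapter class and helper below are hypothetical placeholders for illustration, not the repository's actual modules.

import torch
import torch.nn as nn

class SpatialAdapter(nn.Module):
    """Hypothetical lightweight adapter: a bottleneck MLP with a residual connection."""
    def __init__(self, dim: int, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.act = nn.GELU()
        self.up = nn.Linear(bottleneck, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual form keeps the frozen pre-trained behavior as the starting point.
        return x + self.up(self.act(self.down(x)))

def trainable_adapter_params(model: nn.Module):
    """Freeze all pre-trained parameters and return only the adapter parameters."""
    for p in model.parameters():
        p.requires_grad = False
    params = []
    for m in model.modules():
        if isinstance(m, SpatialAdapter):
            for p in m.parameters():
                p.requires_grad = True
                params.append(p)
    return params

# Usage sketch: only the adapter parameters go to the optimizer.
# optimizer = torch.optim.AdamW(trainable_adapter_params(model), lr=1e-4)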

Datasets

We use four open-source tumor segmentation datasets (kidney, liver, pancreas, and colon) for training and evaluating our model.

Sample Results

(Figure: sample qualitative segmentation results on the four datasets.)

Get Started

Main Requirements

  • python==3.9.16
  • cuda==11.3
  • torch==1.12.1
  • torchvision==0.13.1

Installation

We suggest using Anaconda to set up the environment on Linux. If you have already installed Anaconda, you can skip this step.

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh && bash Anaconda3-2020.11-Linux-x86_64.sh

Then, we can create the environment and install the packages using the provided requirements.txt:

conda create -n med_sam python=3.9.16
conda activate med_sam
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install git+https://github.com/deepmind/surface-distance.git
pip install -r requirements.txt
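
After installation, a quick sanity check (sketch) can confirm that the installed versions match the Main Requirements listed above:

import torch
import torchvision

# Compare the printed versions against the pinned requirements.
print("torch:", torch.__version__)              # expected 1.12.1 (+cu113)
print("torchvision:", torchvision.__version__)  # expected 0.13.1 (+cu113)
print("CUDA runtime:", torch.version.cuda)      # expected 11.3
print("CUDA available:", torch.cuda.is_available())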

Our implementation is based on a single-GPU setting (NVIDIA A40), but it can be easily adapted to use multiple GPUs. About 35 GB of GPU memory is needed to run.
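
For reference, one straightforward way to spread the batch over several GPUs in plain PyTorch is nn.DataParallel; the placeholder network below is purely illustrative and should be replaced with the actual model (DistributedDataParallel is the other common option).

import torch
import torch.nn as nn

# Illustrative placeholder; substitute the real 3DSAM-adapter model here.
model = nn.Sequential(nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.ReLU())

if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        # Replicates the model on every visible GPU and splits the batch dimension.
        model = nn.DataParallel(model)
    model = model.cuda()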

3DSAM-adapter (Ours)

To use the code, first go to the folder 3DSAM-adapter

cd 3DSAM-adapter

Type the command below to train the 3DSAM-adapter:

python train.py --data kits --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/" 

The pre-trained weights of SAM-B can be downloaded here and should be put under the folder ckpt. Users with powerful GPUs can also adapt the model with SAM-L or SAM-H.
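
For reference, the downloaded weights can be loaded through the segment-anything package installed above. The checkpoint filename below is the standard SAM ViT-B release name placed under ckpt as an assumption; the exact path expected by train.py may differ.

from segment_anything import sam_model_registry

# "vit_b" selects SAM-B; "vit_l" and "vit_h" are the larger variants mentioned above.
sam = sam_model_registry["vit_b"](checkpoint="ckpt/sam_vit_b_01ec64.pth")
image_encoder = sam.image_encoder  # the component adapted from 2D to 3D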

Type the command below to evaluate the 3DSAM-adapter:

python test.py --data kits --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/"  --num_prompts 1

Use --num_prompts to indicate the number of points used as the prompt; the default value is 1.
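
To illustrate what a point prompt is, the sketch below samples num_prompts foreground voxels from a ground-truth mask; the helper name and interface are hypothetical and not the repository's actual prompt-generation code.

import numpy as np

def sample_point_prompts(label: np.ndarray, num_prompts: int = 1, seed: int = 0) -> np.ndarray:
    """Randomly pick `num_prompts` voxel coordinates inside the tumor mask (label > 0)."""
    rng = np.random.default_rng(seed)
    coords = np.argwhere(label > 0)  # (K, 3) array of (z, y, x) indices
    if coords.shape[0] < num_prompts:
        raise ValueError("Not enough foreground voxels for the requested number of prompts.")
    idx = rng.choice(coords.shape[0], size=num_prompts, replace=False)
    return coords[idx]  # (num_prompts, 3) click coordinates

# Example: a single click prompt from a toy 3D mask.
toy_label = np.zeros((8, 8, 8), dtype=np.uint8)
toy_label[3:5, 3:5, 3:5] = 1
print(sample_point_prompts(toy_label, num_prompts=1))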

Our pretrained checkpoint can be downloaded through OneDrive. For all four datasets, the crop size is 128.
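
Tumor segmentation is typically evaluated with Dice and surface-based metrics; the sketch below uses the surface-distance package installed earlier on toy masks (the exact metrics reported by test.py may differ).

import numpy as np
import surface_distance

# Toy boolean masks standing in for ground truth and prediction.
mask_gt = np.zeros((64, 64, 64), dtype=bool)
mask_pred = np.zeros((64, 64, 64), dtype=bool)
mask_gt[20:40, 20:40, 20:40] = True
mask_pred[22:42, 20:40, 20:40] = True

dice = surface_distance.compute_dice_coefficient(mask_gt, mask_pred)
surf = surface_distance.compute_surface_distances(mask_gt, mask_pred, spacing_mm=(1.0, 1.0, 1.0))
hd95 = surface_distance.compute_robust_hausdorff(surf, 95)
print(f"Dice: {dice:.4f}, HD95: {hd95:.2f} mm")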

Baselines

We provide our implementations of the following baselines: Swin UNETR, UNETR, 3D UX-Net, nnFormer, UNETR++, and TransBTS.

To use the code, first go to the folder baselines

cd baselines

Type the command below to train the baselines:

python train.py --data kits -m swin_unetr --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/"

Use --data to indicate the dataset; it can be one of kits, pancreas, lits, colon.

Use -m to indicate the method; it can be one of swin_unetr, unetr, 3d_uxnet, nnformer, unetr++, transbts.

For training Swin-UNETR, download the checkpoint and put it under the folder ckpt.

We use different hyper-parameters for each dataset; for more details, please refer to datasets.py. The crop size is set to (64, 160, 160) for all datasets.
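
For illustration only, a MONAI-style dictionary transform producing crops of this size could look like the sketch below; MONAI, the key names, and the transform choice are assumptions, and the actual pipeline in datasets.py may differ.

import numpy as np
from monai.transforms import Compose, RandSpatialCropd

# "image" and "label" are assumed dictionary keys; arrays are (C, D, H, W).
crop = Compose([
    RandSpatialCropd(keys=["image", "label"], roi_size=(64, 160, 160), random_size=False),
])

sample = {
    "image": np.random.rand(1, 128, 256, 256).astype(np.float32),
    "label": np.zeros((1, 128, 256, 256), dtype=np.uint8),
}
cropped = crop(sample)
print(cropped["image"].shape)  # (1, 64, 160, 160)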

Type the command below to evaluate the baselines:

python test.py --data kits -m swin_unetr --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/"

Feedback and Contact

For any questions, please contact [email protected]

Acknowledgement

Our code is based on Segment-Anything, 3D UX-Net, and Swin UNETR.

Citation

If you find this code useful, please cite our paper:

@article{Gong20233DSAMadapterHA,
  title={3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation},
  author={Shizhan Gong and Yuan Zhong and Wenao Ma and Jinpeng Li and Zhao Wang and Jingyang Zhang and Pheng-Ann Heng and Qi Dou},
  journal={arXiv preprint arXiv:2306.13465},
  year={2023}
}
