  • Language
    Python
  • License
    Apache License 2.0

[MICCAI 2023] MedNeXt is a fully ConvNeXt architecture for 3D medical image segmentation.

MedNeXt

Copyright © German Cancer Research Center (DKFZ), Division of Medical Image Computing (MIC). Please make sure that your usage of this code is in compliance with the code license.

MedNeXt is a fully ConvNeXt architecture for 3D medical image segmentation designed to leverage the scalability of the ConvNeXt block while being customized to the challenges of sparsely annotated medical image segmentation datasets. MedNeXt is a model under development and is expected to be updated periodically in the near future.

The current training framework is built on top of nnUNet (v1) - the module name nnunet_mednext reflects this. You are free to adopt the architecture for your own training pipeline or use the one in this repository. Instructions are provided for both paths.

Please cite the following work if you find this model useful for your research:

Roy, S., Koehler, G., Ulrich, C., Baumgartner, M., Petersen, J., Isensee, F., Jaeger, P. F., & Maier-Hein, K. (2023).
MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation. arXiv preprint arXiv:2303.09975.

Please also cite the following work if you use this pipeline for training:

Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2020). 
nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods, 1-9.

Current Versions and notable features:

  • v1 (MICCAI 2023): Fully 3D ConvNeXt architecture, residual ConvNeXt resampling, UpKern for large kernels, gradient checkpointing for training large models

As mentioned earlier, MedNeXt is actively under development, and further improvements to the pipeline are anticipated in future versions.

Installation

The repository can be cloned and installed using the following commands.

git clone https://github.com/MIC-DKFZ/MedNeXt.git mednext
cd mednext
pip install -e .

MedNeXt Architecture and Usage in external pipelines

MedNeXt can be used in external training pipelines for 3D volumetric segmentation, like any PyTorch nn.Module. It is functionally decoupled from nnUNet when used simply as an architecture. It is sufficient to install the repository and import either the architecture or the block. In theory, the network can be freely customized by using MedNeXt either as an encoder-decoder style network or as a block.

MedNeXt v1

MedNeXt v1 is the first version of MedNeXt and incorporates the architectural features described here.

Important: MedNeXt v1 was trained with 1.0 mm isotropic spacing, as favored by architectures such as UNETR and SwinUNETR. Alternate spacings, such as the median spacing favored by native nnUNet, are perfectly usable in theory but are currently untested with MedNeXt v1 and may affect performance.

The use of MedNeXt v1 as a complete architecture, as well as the use of individual MedNeXt blocks (in external architectures, for example), is described below.

Usage as whole MedNeXt v1 architecture:

The architecture can be imported as follows with a number of arguments.

from nnunet_mednext.mednextv1 import MedNeXt

model = MedNeXt(
          in_channels: int,                         # input channels
          n_channels: int,                          # number of base channels
          n_classes: int,                           # number of classes
          exp_r: int = 4,                           # Expansion ratio in Expansion Layer
          kernel_size: int = 7,                     # Kernel Size in Depthwise Conv. Layer
          enc_kernel_size: int = None,              # (Separate) Kernel Size in Encoder
          dec_kernel_size: int = None,              # (Separate) Kernel Size in Decoder
          deep_supervision: bool = False,           # Enable Deep Supervision
          do_res: bool = False,                     # Residual connection in MedNeXt block
          do_res_up_down: bool = False,             # Residual conn. in Resampling blocks
          checkpoint_style: bool = None,            # Enable Gradient Checkpointing
          block_counts: list = [2,2,2,2,2,2,2,2,2], # Depth-first no. of blocks per layer 
          norm_type = 'group',                      # Type of Norm: 'group' or 'layer'
          dim = '3d'                                # Supports '3d', '2d' arguments
)

Please note that both 1) deep supervision and 2) residual connections in the MedNeXt blocks and in the up/downsampling blocks were used for training in the publication.

Gradient checkpointing can be used to train larger models on low-memory devices by trading compute for activation storage. The checkpointing implemented in this version is at the MedNeXt block level.

MedNeXt v1 has been tested with 4 defined architecture sizes and 2 defined kernel sizes. Their particulars are as follows:

Name (Model ID)   Kernel Size   Parameters   GFlops
Small (S)         3x3x3         5.6M         130
Small (S)         5x5x5         5.9M         169
Base (B)          3x3x3         10.5M        170
Base (B)          5x5x5         11.0M        208
Medium (M)        3x3x3         17.6M        248
Medium (M)        5x5x5         18.3M        308
Large (L)         3x3x3         61.8M        500
Large (L)         5x5x5         63.0M        564
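
The table shows that moving from 3x3x3 to 5x5x5 kernels adds comparatively few parameters. A likely reason is that the kernel size only affects the depthwise convolution layers, whose weight count grows with the number of channels rather than with its square. The following is a rough sketch of that arithmetic, not code from this repository: the helper name conv3d_weight_count and the channel width of 32 are assumptions chosen for illustration, while the parameter counts in PARAMS_M are copied from the table above.

```python
def conv3d_weight_count(in_ch, out_ch, k, groups=1):
    """Number of weights (no bias) in a 3D convolution with a cubic k x k x k kernel."""
    return out_ch * (in_ch // groups) * k ** 3


# Illustrative channel width (an assumption, not a MedNeXt setting).
C = 32

dw3 = conv3d_weight_count(C, C, 3, groups=C)   # depthwise, 3x3x3
dw5 = conv3d_weight_count(C, C, 5, groups=C)   # depthwise, 5x5x5
dense5 = conv3d_weight_count(C, C, 5)          # dense 5x5x5, for comparison

# Parameter counts in millions, (kernel 3, kernel 5), copied from the table.
PARAMS_M = {"S": (5.6, 5.9), "B": (10.5, 11.0), "M": (17.6, 18.3), "L": (61.8, 63.0)}

# Relative parameter growth in percent when switching from kernel 3 to kernel 5.
growth_pct = {mid: (k5 - k3) / k3 * 100 for mid, (k3, k5) in PARAMS_M.items()}

if __name__ == "__main__":
    print("depthwise 3x3x3 weights:", dw3)
    print("depthwise 5x5x5 weights:", dw5)
    print("dense 5x5x5 weights:", dense5)
    for mid, pct in growth_pct.items():
        print(f"{mid}: +{pct:.1f}% parameters")
```

In this sketch the depthwise layer grows by a factor of 125/27 between the two kernel sizes, yet it stays far smaller than a dense convolution of the same width, which is consistent with the single-digit percentage growth computed from the table.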

Utility functions have been defined for re-creating these architectures (with or without deep supervision). They are customized by the number of input channels, the number of target classes, the model ID as used in the publication, the kernel size and whether deep supervision is enabled:

from nnunet_mednext import create_mednext_v1

model = create_mednext_v1(
  num_channels = 3,
  num_classes = 10,
  model_id = 'B',             # S, B, M and L are valid model ids
  kernel_size = 3,            # 3x3x3 and 5x5x5 were tested in publication
  deep_supervision = True     # was used in publication
)

Individual Usage of MedNeXt blocks

MedNeXt blocks can be imported for individual use, similar to the entire architecture. The following blocks can be imported directly.

from nnunet_mednext import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock

# Standard MedNeXt block
block = MedNeXtBlock(
    in_channels:int,              # no. of input channels
    out_channels:int,             # no. of output channels
    exp_r:int=4,                  # channel expansion ratio in Expansion Layer
    kernel_size:int=7,            # kernel size in Depthwise Conv. Layer
    do_res:bool=True,             # residual connection on or off. Default: True
    norm_type:str = 'group',      # type of norm: 'group' or 'layer'
    n_groups:int or None = None,  # no. of groups in Depthwise Conv. Layer
                                  # (keep 'None' in most cases)
)


# 2x Downsampling with MedNeXt block
block_down = MedNeXtDownBlock(
    in_channels:int,              # no. of input channels
    out_channels:int,             # no. of output channels
    exp_r:int=4,                  # channel expansion ratio in Expansion Layer
    kernel_size:int=7,            # kernel size in Depthwise Conv. Layer
    do_res:bool=True,             # residual connection on or off. Default: True
    norm_type:str = 'group',      # type of norm: 'group' or 'layer'
)


# 2x Upsampling with MedNeXt block
block_up = MedNeXtUpBlock(
    in_channels:int,              # no. of input channels
    out_channels:int,             # no. of output channels
    exp_r:int=4,                  # channel expansion ratio in Expansion Layer
    kernel_size:int=7,            # kernel size in Depthwise Conv. Layer
    do_res:bool=True,             # residual connection on or off. Default: True
    norm_type:str = 'group',      # type of norm: 'group' or 'layer'
)

UpKern weight loading

UpKern is a simple algorithm for initializing a large kernel MedNeXt network with an equivalent small kernel MedNeXt. Equivalent refers to a network of the same configuration with the only difference being kernel size in the Depthwise Convolution layers. Large kernels are initialized by trilinear interpolation of their smaller counterparts. The following is an example of using this weight loading style.

from nnunet_mednext import create_mednext_v1
from nnunet_mednext.run.load_weights import upkern_load_weights

# Target network with kernel size 5 and equivalent source network with kernel size 3
m_net_ = create_mednext_v1(1, 3, 'S', 5)
m_pre = create_mednext_v1(1, 3, 'S', 3)

# Generally m_pre would be pretrained
m3 = upkern_load_weights(m_net_, m_pre)
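
To make the interpolation step concrete, the following is a minimal pure-Python sketch that resizes a single 3x3x3 kernel to 5x5x5 by applying linear interpolation along each axis in turn. It is not the repository's implementation, which operates on the PyTorch weights of whole networks through upkern_load_weights. The function names lerp_1d and resize_kernel_3d and the endpoint-preserving sampling convention are assumptions made for illustration.

```python
def lerp_1d(values, new_len):
    """Linearly resample a 1D list to new_len samples, keeping both endpoints."""
    old_len = len(values)
    if new_len == 1 or old_len == 1:
        return [values[0]] * new_len
    out = []
    for i in range(new_len):
        pos = i * (old_len - 1) / (new_len - 1)  # position in source coordinates
        lo = min(int(pos), old_len - 2)          # left neighbour index
        frac = pos - lo                          # distance from left neighbour
        out.append(values[lo] * (1 - frac) + values[lo + 1] * frac)
    return out


def resize_kernel_3d(kernel, m):
    """Resize a cubic kernel stored as nested lists [z][y][x] to m x m x m."""
    n = len(kernel)
    # Interpolate along x.
    kx = [[lerp_1d(kernel[z][y], m) for y in range(n)] for z in range(n)]
    # Interpolate along y.
    ky = []
    for z in range(n):
        cols = [lerp_1d([kx[z][y][x] for y in range(n)], m) for x in range(m)]
        ky.append([[cols[x][y] for x in range(m)] for y in range(m)])
    # Interpolate along z.
    lines = [[lerp_1d([ky[z][y][x] for z in range(n)], m) for x in range(m)]
             for y in range(m)]
    return [[[lines[y][x][z] for x in range(m)] for y in range(m)] for z in range(m)]


# Toy 3x3x3 kernel whose value is a linear function of its coordinates.
small = [[[9 * z + 3 * y + x for x in range(3)] for y in range(3)] for z in range(3)]
big = resize_kernel_3d(small, 5)

if __name__ == "__main__":
    print(len(big), len(big[0]), len(big[0][0]))
    print(big[2][2][2], small[1][1][1])
```

With this convention the corner and center values of the small kernel are preserved, and the new positions are filled with blends of their neighbours, so the enlarged kernel starts out as a smoothed copy of the trained small one.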

Usage of internal training pipeline

Plan and Preprocess

To preprocess your datasets as in the MICCAI 2023 version, please run

mednextv1_plan_and_preprocess -t YOUR_TASK -pl3d ExperimentPlanner3D_v21_customTargetSpacing_1x1x1 -pl2d ExperimentPlanner2D_v21_customTargetSpacing_1x1x1

As in nnUNet, you can set -pl3d or -pl2d to None if you do not require preprocessed data in those dimensions. Please note that YOUR_TASK in this repo is designed to be in the old nnUNet (v1) format. If you want to use the latest nnUNet (v2), you will have to adapt the preprocessor on your own.

The custom ExperimentPlanner3D_v21_customTargetSpacing_1x1x1 is designed to set patch size to 128x128x128 and spacing to 1mm isotropic since those are the experimental conditions used in the MICCAI 2023 version.
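
Resampling to 1 mm isotropic spacing changes the voxel grid of each image, which in turn determines how many 128x128x128 patches are needed to cover it. The sketch below shows the arithmetic under stated assumptions: the helper names, the rounding rule and the example volume are hypothetical and are not taken from the planner's code.

```python
import math


def resampled_shape(shape, spacing, target_spacing=(1.0, 1.0, 1.0)):
    """Voxel grid size after resampling, keeping the physical extent fixed."""
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing))


def patches_per_axis(shape, patch_size=(128, 128, 128)):
    """Minimum number of non-overlapping patches needed to cover each axis."""
    return tuple(math.ceil(n / p) for n, p in zip(shape, patch_size))


# Hypothetical CT volume: 100 slices at 3.0 mm, 512x512 in-plane at 0.8 mm.
new_shape = resampled_shape((100, 512, 512), (3.0, 0.8, 0.8))
tiles = patches_per_axis(new_shape)

if __name__ == "__main__":
    print("resampled shape:", new_shape)
    print("patches per axis:", tiles)
```

The patch count is a lower bound, since sliding-window inference typically uses overlapping patches.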

Train MedNeXt using nnUNet (v1) training

MedNeXt has custom nnUNet (v1) trainers that allow it to be trained similarly to the base architecture. Please check the old nnUNet (v1) branch in the nnUNet repo if you are unfamiliar with this code format. Please look here for all available trainers to recreate the MICCAI 2023 experiments. Please note that all trainers are in 3D since the architecture was tested in 3D. You can, of course, create your own custom trainers if you want (including 2D trainers for 2D architectures).

mednextv1_train 3d_fullres TRAINER TASK_NUMBER FOLD -p nnUNetPlansv2.1_trgSp_1x1x1

There are trainers for 4 architectures (S, B, M, L) and 2 kernel sizes (3, 5) to replicate the experiments from MICCAI 2023. The following is an example of training an nnUNetTrainerV2_MedNeXt_S_kernel3 trainer on the task Task040_KiTS2019, on fold 0 of the 5-fold split generated by nnUNet's data preprocessor.

mednextv1_train 3d_fullres nnUNetTrainerV2_MedNeXt_S_kernel3 Task040_KiTS2019 0 -p nnUNetPlansv2.1_trgSp_1x1x1

A kernel 5x5x5 version from scratch can also be trained this way, although we recommend initially training a kernel 3x3x3 version and using UpKern.

Train a kernel 5x5x5 version using UpKern

To train a kernel 5x5x5 version using UpKern, a kernel 3x3x3 version must already be trained. To train using UpKern, simply run the following:

mednextv1_train 3d_fullres TRAINER TASK FOLD -p nnUNetPlansv2.1_trgSp_1x1x1 -pretrained_weights YOUR_MODEL_CHECKPOINT_FOR_KERNEL_3_FOR_SAME_TASK_AND_FOLD -resample_weights

The following is an example of using UpKern to train an nnUNetTrainerV2_MedNeXt_S_kernel5 trainer on the task Task040_KiTS2019, on fold 0 of the 5-fold split generated by nnUNet's data preprocessor.

mednextv1_train 3d_fullres nnUNetTrainerV2_MedNeXt_S_kernel5 Task040_KiTS2019 0 -p nnUNetPlansv2.1_trgSp_1x1x1 -pretrained_weights SOME_PATH/nnUNet/3d_fullres/Task040_KiTS2019/nnUNetTrainerV2_MedNeXt_S_kernel3__nnUNetPlansv2.1_trgSp_1x1x1/fold_0/model_final_checkpoint.model -resample_weights

Please note that the -resample_weights flag is responsible for triggering the UpKern algorithm.
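
The two-stage workflow (train kernel 3 first, then kernel 5 with UpKern) can be scripted by assembling the commands shown above. The following is a minimal sketch that only builds the argument lists and does not execute them. The helper names are hypothetical, and the trainer-name pattern for the B, M and L models is an assumption generalized from the two S examples above; please check the available trainers before relying on it.

```python
import shlex

PLANS = "nnUNetPlansv2.1_trgSp_1x1x1"


def trainer_name(model_id, kernel_size):
    """Trainer name following the pattern of the S examples (assumed for B, M, L)."""
    if model_id not in ("S", "B", "M", "L") or kernel_size not in (3, 5):
        raise ValueError("expected model_id in S/B/M/L and kernel_size in 3/5")
    return f"nnUNetTrainerV2_MedNeXt_{model_id}_kernel{kernel_size}"


def train_command(model_id, kernel_size, task, fold, pretrained_weights=None):
    """Build the mednextv1_train argument list; adds the UpKern flags if weights are given."""
    cmd = ["mednextv1_train", "3d_fullres", trainer_name(model_id, kernel_size),
           task, str(fold), "-p", PLANS]
    if pretrained_weights is not None:
        cmd += ["-pretrained_weights", pretrained_weights, "-resample_weights"]
    return cmd


# Stage 1: kernel 3 from scratch. Stage 2: kernel 5 initialized through UpKern.
stage1 = train_command("S", 3, "Task040_KiTS2019", 0)
stage2 = train_command(
    "S", 5, "Task040_KiTS2019", 0,
    pretrained_weights=(
        "SOME_PATH/nnUNet/3d_fullres/Task040_KiTS2019/"
        "nnUNetTrainerV2_MedNeXt_S_kernel3__nnUNetPlansv2.1_trgSp_1x1x1/"
        "fold_0/model_final_checkpoint.model"
    ),
)

if __name__ == "__main__":
    print(shlex.join(stage1))
    print(shlex.join(stage2))
```

Building the commands as lists keeps each argument separate, so they can be passed to subprocess.run without shell quoting issues once SOME_PATH is replaced with a real location.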

A note on 2D MedNeXt:

Please note that while MedNeXt can run in 2D, it has not been tested in 2D mode.
