ΟΞ±ΠΌΞ± mβοΈdules
*A PyTorch Computer Vision (CV) module library for building n-D networks flexibly ~*
(Simple-to-use & Function-rich!)
Highlights
- Simple code and detailed annotation
- Easy and flexible to integrate with other codes
- Support 1D / 2D / 3D networks (for 1D signal, 2D image, 3D video, 3D medical image...)
- Abundant structure: CNN, GNN, Transformer
- Friendly demo codes
- π Abundant Pretrained weights: including 80000+
2D weights
and 80+3D weights
Pretrained weights:
[Google Drive]
[Baidu Netdisk psw: wama
]
Quick start with demo codes of 6 different novel multi-label network structures
all model codes are re-constructed in a very simple way
Network | Publication | Demo code | Paper link | Support multi-class per label |
---|---|---|---|---|
CNNRNN | CVPR2016 | code | link | Γ |
ML-GCN | CVPR2019 | code | link | Γ |
SSGRL | ICCV2019 | code | link | β |
C-tran | CVPR2021 | code | link | β |
ML-decoder | arxiv2021 | code | link | β |
Q2L | arxiv2021 | code | link | β |
1. Installation
π₯ wama_modules
Basic
1D/2D/3D
Install wama_modules
with command β
pip install git+https://github.com/WAMAWAMA/wama_modules.git
Other ways to install (or use) wama_modules
- Way1: Download code and run
python setup.py install
- Way2: Directly copy the folder
wama_modules
into your project path
π§ segmentation_models_pytorch
Optional
2D
100+ pretrained weights
Introduction and installation command
segmentation_models_pytorch
(called smp
)
is a 2D CNN lib including many backbones and decoders, which is highly recommended to install for cooperating with this library.
*Our codes have already contained smp
, but you can still install the latest version with the code below.
Install with pip β
pip install segmentation-models-pytorch
Install the latest version β
pip install git+https://github.com/rwightman/pytorch-image-models.git
Optional
2D
80000+ pretrained weights
Introduction and installation command
transformer
(powered by Huggingface) is a lib including super abundant CNN and Transformer structures, which is highly recommended to install for cooperating with this library.
Install transformer
with pip β
pip install transformers
π§ timm
Optional
2D
400+ pretrained weights
Introduction and installation command
timm
is a lib includes abundant CNN and Transformer structures, which is highly recommended to install for cooperating with this library.
Install with pip β
pip install timm
Install the latest version β
pip install git+https://github.com/rwightman/pytorch-image-models.git
2. Update list
- 2022/11/11: The birthday of this code, version
v0.0.1
- 2023/01/23: Add demo code of 6 multi-label network structures (Happy Chinese New Year
π ) - ...
3. Main modules and network architectures
An overview of this repo (let's call wama_modules
as wm
)
File | Description | Main class or function |
---|---|---|
wm.utils |
Some operations on tensors and pre-training weights | resizeTensor() tensor2array() load_weights() |
wm.thirdparty_lib |
2D/3D network structures (CNN/GNN/Transformer) from other repositories, and all are with pre-trained weights |
MedicalNet C3D 3D_DenseNet 3D_shufflenet transformers.ConvNextModel transformers.SwinModel Radimagenet |
wm.Attention |
Some attention-based plugins | SCSEModule NonLocal |
wm.BaseModule |
Basic modules(layers). For example, BottleNeck block (ResBlock) in ResNet, and DenseBlock in DenseNet, etc. | MakeNorm() MakeConv() MakeActive() VGGBlock ResBlock DenseBlock |
wm.Encoder |
Some encoders such like ResNet or DenseNet, but with more flexibility for building the network modularly, and 1D/2D/3D are all supported | VGGEncoder ResNetEncoder DenseNetEncoder |
wm.Decoder |
Some encoders with more flexibility for building the network modularly, and 1D/2D/3D are all supported | UNet_decoder |
wm.Neck |
Modules for making the multi-scale features (from encoder) interact with each other to generate stronger features | FPN |
wm.Transformer |
Some self-attention or cross-attention modules, which can be used to build ViT, DETR or TransUnet | TransformerEncoderLayer TransformerDecoderLayer |
wm.PositionEmbedding |
Some position embedding modules, such as the learnable embedding, or 1/2/3D sincos embedding in MAE or Vit | PositionalEncoding_1D_sincos PositionalEncoding_2D_sincos PositionalEncoding_3D_sincos |
- How to build your networks modularly and freely?
π See 'Guideline 1: Build networks modularly' below ~ - How to use pretrained model with
wm.thirdparty_lib
? π See 'Guideline 2: Use pretrained weights' below ~
4. Guideline 1: Build networks modularly
How to build a network modularly? Here's a paradigm β
'Design architecture according to tasks, pick modules according to architecture'
So, network architectures for different tasks can be viewed modularly such as:
- VGG for classification
=
VGG_encoder+
classification_head - ResNet for classification
=
ResNet_encoder+
classification_head - Unet for segmentation
=
encoder+
decoder+
segmentation_head - A multi-task net for classification and segmentation
=
encoder+
decoder+
cls_head+
seg_head
For example, build a 3D resnet50 encoder to output multi-scale feature maps β
from wama_modules.Encoder import ResNetEncoder
import torch
dim = 3 # input is 3D volume
in_channels = 3
input = torch.ones([2, in_channels, 64, 64, 64])
encoder = ResNetEncoder(
in_channels,
stage_output_channels=[64, 128, 256],
blocks=[6, 12, 24],
downsample_ration=[0.5, 0.5, 0.5], # set your downsampling speed
dim=dim
)
multi_scale_f = encoder(input)
_ = [print(i.shape) for i in multi_scale_f]
# --------------------------------
# output π
# torch.Size([2, 64, 15, 15, 15])
# torch.Size([2, 128, 7, 7, 7])
# torch.Size([2, 256, 3, 3, 3])
# --------------------------------
Here are more demos shown below β (Click to view codes, or visit the demo
folder)
Demo0: Build a 3D VGG for Single Label Classification
import torch
import torch.nn as nn
from wama_modules.Encoder import VGGEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = VGGEncoder(
in_channel,
stage_output_channels=f_channel_list,
blocks=[1, 2, 3, 4],
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
self.pooling = GlobalMaxPool()
def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(is_malignant=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('single-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# output π
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
Demo1: Build a 3D ResNet for Single Label Classification
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=f_channel_list,
stage_middle_channels=f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
self.pooling = GlobalMaxPool()
def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(is_malignant=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('single-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# output π
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
Demo2: Build a ResNet for Multi-Label Classification
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=f_channel_list,
stage_middle_channels=f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
self.pooling = GlobalMaxPool()
def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(shape=4, color=3, other=13)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of shape : torch.Size([2, 4])
# logits of color : torch.Size([2, 3])
# logits of other : torch.Size([2, 13])
Demo3: Build a ResNetUnet for Single Label Segmentation
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
def forward(self, x):
multi_scale_f1 = self.encoder(x)
multi_scale_f2 = self.decoder(multi_scale_f1)
f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
Demo4: Build a ResNetUnet for Multi-Label Segmentation
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
def forward(self, x):
multi_scale_f1 = self.encoder(x)
multi_scale_f2 = self.decoder(multi_scale_f1)
f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
Demo5: Build a MultiTask net for Segmentation and Classfification
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead, ClassificationHead
from wama_modules.utils import resizeTensor
from wama_modules.BaseModule import GlobalMaxPool
class Model(nn.Module):
def __init__(self,
in_channel,
seg_label_category_dict,
cls_label_category_dict,
dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
seg_label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
# cls head
self.cls_head = ClassificationHead(cls_label_category_dict, Encoder_f_channel_list[-1])
# pooling
self.pooling = GlobalMaxPool()
def forward(self, x):
# get encoder features
multi_scale_encoder = self.encoder(x)
# get decoder features
multi_scale_decoder = self.decoder(multi_scale_encoder)
# perform segmentation
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
seg_logits = self.seg_head(f_for_seg)
# perform classification
cls_logits = self.cls_head(self.pooling(multi_scale_encoder[-1]))
return seg_logits, cls_logits
if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
seg_label_category_dict = dict(organ=3, tumor=2)
cls_label_category_dict = dict(shape=4, color=3, other=13)
model = Model(
in_channel=1,
cls_label_category_dict=cls_label_category_dict,
seg_label_category_dict=seg_label_category_dict,
dim=3)
seg_logits, cls_logits = model(x)
print('multi_label predicted logits')
_ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
print('-'*30)
_ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]
# out
# multi_label predicted logits
# seg logits of organ : torch.Size([2, 3, 128, 128, 128])
# seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
# ------------------------------
# cls logits of shape : torch.Size([2, 4])
# cls logits of color : torch.Size([2, 3])
# cls logits of other : torch.Size([2, 13])
Demo6: Build a Unet with a resnet encoder and a FPN neck
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from wama_modules.Neck import FPN
class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# neck
FPN_output_channel = 256
FPN_channels = [FPN_output_channel]*len(Encoder_f_channel_list)
self.neck = FPN(in_channels_list=Encoder_f_channel_list,
c1=FPN_output_channel//2,
c2=FPN_output_channel,
mode='AddSmall2Big',
dim=dim,)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=FPN_channels,
skip_connection=[True, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
def forward(self, x):
multi_scale_encoder = self.encoder(x)
multi_scale_neck = self.neck(multi_scale_encoder)
multi_scale_decoder = self.decoder(multi_scale_neck)
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
Demo7: Build a 2D TransUnet for Segmentation
From paper : TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
Proposed by Jieneng Chen
[paper]
[official code]
[Structure of TransUnet]
Demo code π
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from transformers import ViTModel
from wama_modules.utils import load_weights, tmp_class
class TransUNet(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# neck
neck_out_channel = 768
transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
configuration = transformer.config
self.trans_downsample_size = configuration.image_size = [8, 8]
configuration.patch_size = [1, 1]
configuration.num_channels = Encoder_f_channel_list[-1]
configuration.encoder_stride = 1 # just for MAE decoder, otherwise this paramater is not used
self.neck = ViTModel(configuration, add_pooling_layer=False)
pretrained_weights = transformer.state_dict()
pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
'embeddings.position_embeddings']
pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
'embeddings.patch_embeddings.projection.weight']
pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
'embeddings.patch_embeddings.projection.bias']
self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
skip_connection=[True, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
def forward(self, x):
# encoder forward
multi_scale_encoder = self.encoder(x)
# neck forward
f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size))
f_neck = f_neck.last_hidden_state
f_neck = f_neck[:, 1:] # remove class token
f_neck = f_neck.permute(0, 2, 1)
f_neck = f_neck.reshape(
f_neck.shape[0],
f_neck.shape[1],
self.trans_downsample_size[0],
self.trans_downsample_size[1]
) # reshape
f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
multi_scale_encoder[-1] = f_neck
# decoder forward
multi_scale_decoder = self.decoder(multi_scale_encoder)
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
# seg_head forward
logits = self.seg_head(f_for_seg)
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 256, 256])
label_category_dict = dict(organ=3, tumor=4)
model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
with torch.no_grad():
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 256, 256])
# logits of tumor : torch.Size([2, 4, 256, 256])
Demo8: Build a 3D TransUnet for Segmentation
*Original TransUnet only recieves 2D input.
So if we want to build a 3D TransUnet, with the tensor.reshape
operation in torch,
we can temporarily convert 3D featuremap to 2D featuremap in the middle process,
and then convert it back to 3D featuremap. You can find this process in the code of neck forward
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from transformers import ViTModel
from wama_modules.utils import load_weights, tmp_class
class TransUnet(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# neck
neck_out_channel = 768
transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
configuration = transformer.config
self.trans_size_3D = [8, 8, 4]
self.trans_size = configuration.image_size = [
self.trans_size_3D[0], self.trans_size_3D[1]*self.trans_size_3D[2]
]
configuration.patch_size = [1, 1]
configuration.num_channels = Encoder_f_channel_list[-1]
configuration.encoder_stride = 1 # just for MAE decoder, otherwise this paramater is not used
self.neck = ViTModel(configuration, add_pooling_layer=False)
pretrained_weights = transformer.state_dict()
pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
'embeddings.position_embeddings']
pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
'embeddings.patch_embeddings.projection.weight']
pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
'embeddings.patch_embeddings.projection.bias']
self.neck = load_weights(self.neck, pretrained_weights) # reload pretrained weights
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
skip_connection=[True, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
def forward(self, x):
# encoder forward
multi_scale_encoder = self.encoder(x)
# neck forward
neck_input = resizeTensor(multi_scale_encoder[-1], size=self.trans_size_3D)
neck_input = neck_input.reshape(neck_input.shape[0], neck_input.shape[1], *self.trans_size) # 3D to 2D
f_neck = self.neck(neck_input)
f_neck = f_neck.last_hidden_state
f_neck = f_neck[:, 1:] # remove class token
f_neck = f_neck.permute(0, 2, 1)
f_neck = f_neck.reshape(
f_neck.shape[0],
f_neck.shape[1],
self.trans_size[0],
self.trans_size[1]
) # reshape
f_neck = f_neck.reshape(f_neck.shape[0], f_neck.shape[1], *self.trans_size_3D) # 2D to 3D
f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
multi_scale_encoder[-1] = f_neck
# decoder forward
multi_scale_decoder = self.decoder(multi_scale_encoder)
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
# seg_head forward
logits = self.seg_head(f_for_seg)
return logits
if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 96])
label_category_dict = dict(organ=3, tumor=4)
model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3)
with torch.no_grad():
logits = model(x)
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]
# out
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 96])
# logits of tumor : torch.Size([2, 4, 128, 128, 96])
Demo: Multi-label network structure π’
6 different novel multi-label network structures
all model codes are re-constructed in a very simple way
Network | Publication | Demo code | Paper link | Support multi-class per label |
---|---|---|---|---|
CNNRNN | CVPR2016 | code | link | Γ |
ML-GCN | CVPR2019 | code | link | Γ |
SSGRL | ICCV2019 | code | link | β |
C-tran | CVPR2021 | code | link | β |
ML-decoder | arxiv2021 | code | link | β |
Q2L | arxiv2021 | code | link | β |
*Todo-demo list (
Demo: Build a UCTransNet model for segmentation
Demo: Build a model for multiple inputs (1D signal and 2D image)
Demo: Build a 2D Unet with pretrained Resnet50 encoder (1D signal and 2D image)
Demo: Build a 3D DETR model for object detection
Demo: Build a 3D VGG with SE-attention module for multi-instanse classification
5. Guideline 2: Use pretrained weights
(*All pretrained weights are from third-party codes or repos)
Currently available pre-training models are shown below β
Module name | Number of pretrained weights | Pretrained data | Dimension | |
---|---|---|---|---|
1 | .ResNets3D_kenshohara |
21 | video | 3D |
2 | .VC3D_kenshohara |
13 | video | 3D |
3 | .Efficient3D_okankop |
39 | video | 3D |
4 | .MedicalNet_tencent |
11 | medical image | 3D |
5 | .C3D_jfzhang95 |
1 | video | 3D |
6 | .C3D_yyuanad |
1 | video | 3D |
7 | .SMP_qubvel |
119 | image | 2D |
8 | timm |
400+ | image | 2D |
9 | transformers |
80000+ | video/image | 2D/3D |
10 | radimagenet |
1 | medical image | 2D |
*Download all pretrained weights from
[Google Drive]
or [Baidu Netdisk psw: wama
]
21 weights
3D
5.1 ResNets3D_kenshohara ResNets3D_kenshohara (21 weights)
Demo code ---------------------------------
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\kenshohara_ResNets3D_weights\resnet\r3d18_KM_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\kenshohara_ResNets3D_weights\resnet2p1d\r2p1d18_K_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
13 weights
3D
5.2 VC3D_kenshohara VC3D_kenshohara (13 weights)
Demo code ---------------------------------
# resnet
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
# resnext
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
from wama_modules.utils import load_weights
m = generate_model(101)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
# wide_resnet
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model()
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
39 weights
3D
5.3 Efficient3D_okankop Efficient3D_okankop (39 weights)
Demo code ---------------------------------
# c3d
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.c3d import get_model
m = get_model() # c3d has no pretrained weights
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# mobilenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.mobilenet import get_model
from wama_modules.utils import load_weights
m = get_model(width_mult = 1.) # e.g. width_mult = 1 when mobilenet_1.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenet\jester_mobilenet_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = get_model(width_mult = 2.) # e.g. width_mult = 2 when mobilenet_2.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenet\jester_mobilenet_2.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# mobilenetv2
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.mobilenetv2 import get_model
from wama_modules.utils import load_weights
m = get_model(width_mult = 1.) # e.g. width_mult = 1 when mobilenet_1.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenetv2\jester_mobilenetv2_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = get_model(width_mult = 0.45) # e.g. width_mult = 1 when mobilenet_1.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenetv2\jester_mobilenetv2_0.45x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# resnet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.resnet import resnet18, resnet50, resnet101
from wama_modules.utils import load_weights
m = resnet18()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_18_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = resnet50()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_50_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = resnet101()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_101_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# resnext
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.resnext import resnext101
from wama_modules.utils import load_weights
m = resnext101()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnext\jester_resnext_101_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# shufflenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.shufflenet import get_model
from wama_modules.utils import load_weights
m = get_model(groups=3, width_mult=1)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenet\jester_shufflenet_1.0x_G3_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = get_model(groups=3, width_mult=1.5)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenet\jester_shufflenet_1.5x_G3_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# shufflenetv2
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.shufflenetv2 import get_model
from wama_modules.utils import load_weights
m = get_model(width_mult=1)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenetv2\jester_shufflenetv2_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
m = get_model(width_mult=2)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenetv2\jester_shufflenetv2_2.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
# squeezenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.squeezenet import get_model
from wama_modules.utils import load_weights
m = get_model()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\squeezenet\jester_squeezenet_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
11 weights
3D
medical
5.4 MedicalNet_tencent MedicalNet_tencent (11 weights)
Demo code ---------------------------------
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.MedicalNet_Tencent.model import generate_model
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\MedicalNet_Tencent\MedicalNet_weights\resnet_18_23dataset.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 1, 64, 64, 64])) # input channel is 1 (not 3 for video)
_ = [print(i.shape) for i in f_list]
1 weights
3D
5.5 C3D_jfzhang95 C3D_jfzhang95 (1 weight)
Demo code ---------------------------------
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.C3D_jfzhang95.c3d import C3D
m = C3D()
pretrain_path = r"D:\pretrainedweights\C3D_jfzhang95\C3D_jfzhang95_weights\C3D_jfzhang95_C3D.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
1 weights
3D
5.6 C3D_yyuanad C3D_yyuanad (1 weight)
Demo code ---------------------------------
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.C3D_yyuanad.c3d import C3D
m = C3D()
pretrain_path = r"D:\pretrainedweights\C3D_yyuanad\C3D_yyuanad_weights\C3D_yyuanad.pickle"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
119 weights
2D
5.7 SMP_qubvel SMP_qubvel (119 weights, Automatic online download)
Demo code ---------------------------------
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder
m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
400+ weights
2D
5.8 timm timm (400+ weights, Automatic online download)
Demo code ---------------------------------
import torch
import timm
m = timm.create_model(
'adv_inception_v3',
features_only=True,
pretrained=True,)
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
80000+ weights
2D
5.9 transformers transformers (80000+ weights, Automatic online download)
All models please go to Huggingface [ModelHub]
Demo code ---------------------------------
import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights
# Initializing a model (with random weights) from the convnext-tiny-224 style configuration
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2,3,224,224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]
# reload weights
weights = m.state_dict()
m1 = ConvNextModel(m.config)
m = load_weights(m, weights)
import torch
from transformers import SwinModel
from wama_modules.utils import load_weights
m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2,3,384,384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states # For transformer, should use reshaped_hidden_states
_ = [print(i.shape) for i in f_list]
# reload weights
weights = m.state_dict()
m1 = SwinModel(m.config)
m = load_weights(m, weights)
1 weights
2D
medical
5.10 radimagenet radimagenet (1 weight)
Demo code ---------------------------------
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder
m = get_encoder('resnet50', in_channels=3, depth=5, weights=None)
pretrain_path = r"D:\pretrainedweights\radimagnet\RadImageNet_models-20221104T172755Z-001\RadImageNet_models\RadImageNet-ResNet50_notop_torch.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
π₯°
6. Acknowledgment Thanks to these authors and their codes:
- https://github.com/ZhugeKongan/torch-template-for-deep-learning
- pytorch vit: https://github.com/lucidrains/vit-pytorch
- SMP: https://github.com/qubvel/segmentation_models.pytorch
- transformers: https://github.com/huggingface/transformers
- medicalnet: https://github.com/Tencent/MedicalNet
- timm: https://github.com/rwightman/pytorch-image-models
- ResNets3D_kenshohara: https://github.com/kenshohara/3D-ResNets-PyTorch
- VC3D_kenshohara: https://github.com/kenshohara/video-classification-3d-cnn-pytorch
- Efficient3D_okankop: https://github.com/okankop/Efficient-3DCNNs
- C3D_jfzhang95: https://github.com/jfzhang95/pytorch-video-recognition
- C3D_yyuanad: https://github.com/yyuanad/Pytorch_C3D_Feature_Extractor
- radimagenet: https://github.com/BMEII-AI/RadImageNet
- BMEII-AI/RadImageNet#3
- TransUnet: https://github.com/Beckschen/TransUNet
- ML-Decoder: https://github.com/Alibaba-MIIL/ML_Decoder
- Q2L: https://github.com/SlongLiu/query2labels
- https://github.com/AmrMaghraby/CNN-RNN-A-Unified-Framework-for-Multi-label-Image-Classification
- https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning
- ML-GCN: code:https://github.com/megvii-research/ML-GCN
- C-tran: https://github.com/QData/C-Tran
- SSGRL: https://github.com/HCPLab-SYSU/SSGRL