  • Language
    Python
  • License
    GNU General Publi...
  • Created over 1 year ago
  • Updated 8 months ago


Repository Details

Low-rank adaptation (LoRA) for Vision Transformer

MeLo: Low-rank Adaptation is Better than Fine-tuning for Medical Image Diagnosis

Intro

Useful links

[Homepage] [arXiv]

Features

  • Supported DeepLab segmentation for lukemelas/PyTorch-Pretrained-ViT. 2023-03-15
  • Supported timm. 2023-03-16
  • Supported multi-LoRA. 2023-11-15
  • Repository cleanup.

Installation

Git clone this repository. I use torch 1.13.0 (torch.__version__ == 1.13.0); other versions newer than 1.10.0 should probably also work. You may also need safetensors from Hugging Face to save and load weights.
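The version requirement above can be checked before running the examples. This is a minimal sketch that assumes only that `torch.__version__` is a string such as `1.13.0` or `2.1.0+cu118`; the helper name `torch_version_ok` is hypothetical and not part of this repository.

```python
from itertools import takewhile


def torch_version_ok(version: str, minimum: tuple = (1, 10)) -> bool:
    """Return True if a version string such as '1.13.0' or '2.1.0+cu118' is >= `minimum`."""
    release = version.split("+")[0]  # drop local build tags like '+cu118'
    parts = []
    for piece in release.split(".")[: len(minimum)]:
        # keep only leading digits so pre-release tags like '0a0' still parse
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts) >= tuple(minimum)


if __name__ == "__main__":
    print(torch_version_ok("1.13.0"))  # True
    print(torch_version_ok("1.9.1"))   # False
```

In practice you would call it as `torch_version_ok(torch.__version__)`. If the `packaging` library is installed, `packaging.version.Version` is a more complete parser.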

Examples

You can find examples in examples.ipynb.

Usage

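You can use a Vision Transformer from timm:

import timm
import torch
from lora import LoRA_ViT_timm  # provided by this repository

img = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images at 224x224
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# wrap the pretrained ViT with rank-4 LoRA and a 10-class head
lora_vit = LoRA_ViT_timm(vit_model=model, r=4, num_classes=10)
pred = lora_vit(img)
print(pred.shape)  # expected: torch.Size([2, 10])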
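The `r` argument is, as in standard LoRA, the rank of the weight update. As a rough illustration of the idea (not this repository's implementation), a frozen weight `W` is augmented with a trainable product `B @ A` of rank at most `r`, scaled by `alpha / r`; `B` starts at zero, so the wrapped layer initially reproduces the frozen one. The class name `LoRALinear`, the `alpha` parameter and the initialization scale are assumptions made for this NumPy sketch.

```python
import numpy as np


class LoRALinear:
    """Frozen linear layer y = x W^T + b plus a low-rank update (alpha / r) * x A^T B^T."""

    def __init__(self, weight, bias, r=4, alpha=4.0, seed=0):
        rng = np.random.default_rng(seed)
        out_features, in_features = weight.shape
        self.weight = weight  # frozen, shape (out, in)
        self.bias = bias      # frozen, shape (out,)
        # Trainable factors: A is small random, B is zero, so B @ A == 0 at init.
        self.A = rng.normal(0.0, 0.01, size=(r, in_features))
        self.B = np.zeros((out_features, r))
        self.scale = alpha / r

    def __call__(self, x):
        base = x @ self.weight.T + self.bias
        update = (x @ self.A.T) @ self.B.T  # project down to r dims, then back up
        return base + self.scale * update

    def num_trainable(self):
        return self.A.size + self.B.size


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    W = rng.normal(size=(768, 768))
    b = np.zeros(768)
    layer = LoRALinear(W, b, r=4)
    x = rng.normal(size=(2, 768))
    # With B initialized to zero, the output matches the frozen layer.
    print(np.allclose(layer(x), x @ W.T + b))  # True
    print(layer.num_trainable())               # 6144 = 4*768 + 768*4
```

Only `A` and `B` would be updated during training. Because `r` is much smaller than 768, each adapted 768 x 768 projection adds 2 x 768 x r parameters instead of 768 x 768.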

If timm is too complicated, you can use a simpler ViT implementation from lukemelas/PyTorch-Pretrained-ViT. Wrap your ViT with LoRA_ViT; here is a simple classifier example:

from base_vit import ViT
import torch
from lora import LoRA_ViT

model = ViT('B_16_imagenet1k')
model.load_state_dict(torch.load('B_16_imagenet1k.pth'))
img = torch.randn(1, 3, 384, 384)  # B_16_imagenet1k takes 384x384 input
preds = model(img) # preds.shape = torch.Size([1, 1000])

# count the parameters that receive gradients
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params}") # trainable parameters: 86859496

# wrap the ViT with rank-4 LoRA and a 10-class head
lora_model = LoRA_ViT(model, r=4, num_classes=10)
num_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params}") # trainable parameters: 147456
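The LoRA count above is consistent with adapting the query and value projections of each of the 12 transformer blocks of ViT-B/16 (hidden size 768) at rank 4. That breakdown is an assumption checked here only arithmetically; the hypothetical helper below counts the `A` (r x d) and `B` (d x r) factors and does not include a classification head.

```python
def lora_param_count(num_blocks: int, hidden_dim: int, rank: int,
                     adapted_per_block: int = 2) -> int:
    """Parameters added by LoRA when `adapted_per_block` square projections
    (e.g. query and value) in each block get factors A (rank x d) and B (d x rank)."""
    per_projection = rank * hidden_dim + hidden_dim * rank
    return num_blocks * adapted_per_block * per_projection


if __name__ == "__main__":
    print(lora_param_count(num_blocks=12, hidden_dim=768, rank=4))  # 147456
```

Under this assumption, 147,456 is about 0.17% of the 86,859,496 parameters of the full model, and doubling `r` doubles the LoRA count.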

This is an example for segmentation tasks, using DeepLabv3:

model = ViT('B_16_imagenet1k')
model.load_state_dict(torch.load('B_16_imagenet1k.pth'))
lora_model = LoRA_ViT(model, r=4)  # rank-4 LoRA around the pretrained ViT
# SegWrapForViT is provided by this repository; it adds a DeepLabv3 segmentation head
seg_lora_model = SegWrapForViT(vit_model=lora_model, image_size=384,
                               patches=16, dim=768, n_classes=10)

num_params = sum(p.numel() for p in seg_lora_model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params/2**20:.3f}") # Number of trainable parameters: 6.459 (count / 2**20)
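Note that the segmentation example divides the count by 2**20 (1,048,576) rather than by 10**6, so the printed 6.459 is in binary millions; in decimal millions it is roughly 6.77 million. The hypothetical helper below shows both conventions side by side.

```python
def in_millions(num_params: int) -> tuple:
    """Return a parameter count in binary (2**20) and decimal (10**6) millions."""
    return num_params / 2**20, num_params / 10**6


if __name__ == "__main__":
    binary_m, decimal_m = in_millions(147456)
    print(f"{binary_m:.3f} x 2^20 = {decimal_m:.3f}M")  # 0.141 x 2^20 = 0.147M
```

Whichever convention you pick, using it consistently makes parameter counts comparable across experiments.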

Save and load the LoRA weights (requires safetensors):

lora_model.save_lora_parameters('mytask.lora.safetensors') # save
lora_model.load_lora_parameters('mytask.lora.safetensors') # load
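Because only the low-rank factors are trained, a task-specific checkpoint can hold just those tensors and be loaded on top of the same pretrained backbone. The sketch below illustrates that idea with NumPy arrays. It is not the repository's `save_lora_parameters`/`load_lora_parameters`: `np.savez` stands in for the safetensors format, and the `lora_` key prefix is a hypothetical naming convention.

```python
import os
import tempfile

import numpy as np


def save_lora(params: dict, path: str, prefix: str = "lora_") -> None:
    """Save only the entries whose names start with `prefix` (the LoRA factors)."""
    lora_only = {k: v for k, v in params.items() if k.startswith(prefix)}
    np.savez(path, **lora_only)  # path should end in .npz


def load_lora(params: dict, path: str) -> dict:
    """Return a copy of `params` with the saved LoRA tensors written over it."""
    updated = dict(params)
    with np.load(path) as data:
        for name in data.files:
            if name not in updated:
                raise KeyError(f"unexpected LoRA tensor: {name}")
            updated[name] = data[name]
    return updated


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    params = {
        "backbone_weight": rng.normal(size=(8, 8)),  # frozen, not saved
        "lora_A_0": rng.normal(size=(2, 8)),
        "lora_B_0": rng.normal(size=(8, 2)),
    }
    path = os.path.join(tempfile.mkdtemp(), "mytask_lora.npz")
    save_lora(params, path)
    with np.load(path) as data:
        print(sorted(data.files))  # ['lora_A_0', 'lora_B_0']
    fresh = {k: np.zeros_like(v) for k, v in params.items()}
    restored = load_lora(fresh, path)
    print(np.array_equal(restored["lora_A_0"], params["lora_A_0"]))  # True
```

The resulting file is a small fraction of a full checkpoint, which is what makes keeping one LoRA file per task practical.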

Performance

On an M1 Pro, LoRA is about 1.8x to 1.9x faster. Running python performance_profile.py now profiles the timing. More tests will come soon.
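A speedup figure like this is a ratio of measured wall-clock times. The helper below is a generic timing sketch using only the standard library; it is not the contents of `performance_profile.py`, and `full_step`/`lora_step` in the usage note are hypothetical stand-ins for one full fine-tuning step and one LoRA step.

```python
import time
from typing import Callable


def mean_runtime(fn: Callable[[], object], repeats: int = 10, warmup: int = 2) -> float:
    """Average wall-clock seconds per call of `fn`, after a few untimed warm-up calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def speedup(baseline: Callable[[], object], candidate: Callable[[], object],
            repeats: int = 10) -> float:
    """How many times faster `candidate` runs than `baseline` (> 1 means faster)."""
    return mean_runtime(baseline, repeats) / mean_runtime(candidate, repeats)
```

Usage would look like `speedup(full_step, lora_step)`. On an accelerator, each step function should synchronize the device before returning so that queued work is included in the measurement.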

Citation

Use this BibTeX entry to cite this repository:

@misc{zhu2023melo,
      title={MeLo: Low-rank Adaptation is Better than Fine-tuning for Medical Image Diagnosis}, 
      author={Yitao Zhu and Zhenrong Shen and Zihao Zhao and Sheng Wang and Xin Wang and Xiangyu Zhao and Dinggang Shen and Qian Wang},
      year={2023},
      eprint={2311.08236},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Credit

The ViT code and ImageNet-pretrained weights come from lukemelas/PyTorch-Pretrained-ViT.
