PyTorch implementation of UNet++ (Nested U-Net)
This repository contains code for an image segmentation model based on UNet++: A Nested U-Net Architecture for Medical Image Segmentation, implemented in PyTorch.
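The core idea of UNet++ is to re-design the skip pathways: each decoder node receives the upsampled output of the node below it concatenated with the outputs of all same-level nodes before it. Below is a toy two-level sketch of that wiring, assuming made-up module names and channel widths; it is illustrative only, not the repository's actual model code.

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    # Two 3x3 convolutions with ReLU, as in a typical U-Net block (illustrative).
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

class TinyNestedUNet(nn.Module):
    # Minimal two-level nested U-Net: nodes X00 (encoder), X10 (deeper
    # encoder), and X01 (decoder node fed by the nested skip connection).
    def __init__(self, in_ch=3, num_classes=1):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.x00 = ConvBlock(in_ch, 32)
        self.x10 = ConvBlock(32, 64)
        # X01 sees X00 plus the upsampled X10 -- the dense skip pathway.
        self.x01 = ConvBlock(32 + 64, 32)
        self.head = nn.Conv2d(32, num_classes, 1)

    def forward(self, x):
        x00 = self.x00(x)
        x10 = self.x10(self.pool(x00))
        x01 = self.x01(torch.cat([x00, self.up(x10)], dim=1))
        return self.head(x01)
```

The full NestedUNet in this repository extends this pattern to four levels, with intermediate nodes X01, X02, X11, etc., each concatenating all earlier same-level outputs.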
[NEW] Added support for multi-class segmentation datasets.
[NEW] Added support for PyTorch 1.x.
Requirements
- PyTorch 1.x or 0.4.1
Installation
- Create an anaconda environment.
conda create -n <env_name> python=3.6 anaconda
conda activate <env_name>
- Install PyTorch.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
- Install pip packages.
pip install -r requirements.txt
Training on 2018 Data Science Bowl dataset
- Download the dataset from here to inputs/ and unzip it. The file structure is the following:
inputs
└── data-science-bowl-2018
    └── stage1_train
        ├── 00ae65...
        │   ├── images
        │   │   └── 00ae65...
        │   └── masks
        │       └── 00ae65...
        ├── ...
        ...
- Preprocess.
python preprocess_dsb2018.py
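DSB2018 ships one binary mask image per nucleus, so a preprocessing step has to merge all instance masks of a sample into a single foreground mask before resizing to 96x96. A rough sketch of that merge step (the actual preprocess_dsb2018.py may differ in details; the function name here is made up):

```python
import numpy as np

def merge_instance_masks(masks):
    """Combine a list of HxW binary instance masks into one foreground mask.

    Each input array marks a single nucleus; taking the elementwise maximum
    yields a mask where any nucleus pixel is foreground.
    """
    merged = np.zeros_like(masks[0])
    for m in masks:
        merged = np.maximum(merged, m)
    return merged
```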
- Train the model.
python train.py --dataset dsb2018_96 --arch NestedUNet
- Evaluate.
python val.py --name dsb2018_96_NestedUNet_woDS
(Optional) Using LovaszHingeLoss
- Clone LovaszSoftmax from bermanmaxim/LovaszSoftmax.
git clone https://github.com/bermanmaxim/LovaszSoftmax.git
- Train the model with LovaszHingeLoss.
python train.py --dataset dsb2018_96 --arch NestedUNet --loss LovaszHingeLoss
Training on original dataset
Make sure the files are arranged in the following structure (e.g. when the number of classes is 2):
inputs
└── <dataset name>
    ├── images
    │   ├── 0a7e06.jpg
    │   ├── 0aab0a.jpg
    │   ├── 0b1761.jpg
    │   └── ...
    │
    └── masks
        ├── 0
        │   ├── 0a7e06.png
        │   ├── 0aab0a.png
        │   ├── 0b1761.png
        │   └── ...
        │
        └── 1
            ├── 0a7e06.png
            ├── 0aab0a.png
            ├── 0b1761.png
            └── ...
- Train the model.
python train.py --dataset <dataset name> --arch NestedUNet --img_ext .jpg --mask_ext .png
- Evaluate.
python val.py --name <dataset name>_NestedUNet_woDS
Results
DSB2018 (96x96)
Here are the results on the DSB2018 dataset (96x96) with LovaszHingeLoss.
| Model | IoU | Loss |
| --- | --- | --- |
| U-Net | 0.839 | 0.365 |
| Nested U-Net | 0.842 | 0.354 |
| Nested U-Net w/ Deep Supervision | 0.843 | 0.362 |
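For reference, an IoU score like those above is typically computed for binary segmentation as intersection over union of the predicted and ground-truth masks. A generic sketch (not necessarily the repository's exact metric code; the smoothing constant eps is an assumption to avoid division by zero):

```python
import numpy as np

def iou_score(pred, target, eps=1e-7):
    """IoU between two binary masks given as arrays of 0/1 values."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    # eps keeps the score defined when both masks are empty.
    return (intersection + eps) / (union + eps)
```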