UNet3plus_pth

UNet 3+, UNet++ and UNet implemented in PyTorch, used for Deep Automatic Portrait Matting.

Dependencies

  • Python 3.6
  • PyTorch >= 1.1.0
  • Torchvision >= 0.3.0
  • future 0.18.2
  • matplotlib 3.1.3
  • numpy 1.16.0
  • Pillow 6.2.0
  • protobuf 3.11.3
  • tensorboard 1.14.0
  • tqdm 4.42.1

Data

This model was trained from scratch on 18,000 images, augmented from 2,000 images. The training dataset comes from Deep Automatic Portrait Matting. You can download it from Baidu Cloud: http://pan.baidu.com/s/1dE14537 (password: ndg8). For academic use only; if you quote it, please inform the original authors.

We augment the number of images by perturbing them with rotation and scaling. Four rotation angles {−45°, −22°, 22°, 45°} and four scales {0.6, 0.8, 1.2, 1.5} are used. We also apply four different gamma transforms to increase color variation. The gamma values are {0.5, 0.8, 1.2, 1.5}. After these transforms, we have 18K training images.
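The augmentation described above can be sketched with Pillow and NumPy. This is a minimal sketch, not the original pipeline: it assumes each transform is applied independently to an (image, mask) pair, that gamma is applied as out = 255 · (in / 255)^γ, and that the geometric transforms (rotation, scaling) must also be applied to the mask while the gamma change is not. The names `augment` and `adjust_gamma` are hypothetical.

```python
import numpy as np
from PIL import Image

ANGLES = [-45, -22, 22, 45]    # rotation angles in degrees
SCALES = [0.6, 0.8, 1.2, 1.5]  # resize factors
GAMMAS = [0.5, 0.8, 1.2, 1.5]  # gamma values (applied to the photo only)


def adjust_gamma(img, gamma):
    """Hypothetical helper: out = 255 * (in / 255) ** gamma on every channel."""
    arr = np.asarray(img, dtype=np.float32) / 255.0
    out = np.clip(255.0 * arr ** gamma, 0, 255).astype(np.uint8)
    return Image.fromarray(out)


def augment(img, mask):
    """Hypothetical helper: return the original pair plus one pair per transform.

    Assumption: transforms are applied one at a time, never combined.
    """
    pairs = [(img, mask)]
    for angle in ANGLES:
        # Bilinear for the photo; nearest for the mask, assuming a binary mask
        # (use BILINEAR instead for a soft alpha matte).
        pairs.append((img.rotate(angle, resample=Image.BILINEAR),
                      mask.rotate(angle, resample=Image.NEAREST)))
    for s in SCALES:
        size = (max(1, round(img.width * s)), max(1, round(img.height * s)))
        pairs.append((img.resize(size, Image.BILINEAR),
                      mask.resize(size, Image.NEAREST)))
    for g in GAMMAS:
        # Color-only change: the mask is reused unchanged.
        pairs.append((adjust_gamma(img, g), mask))
    return pairs
```

Applied independently like this, each source image yields 13 samples (the original plus 12 variants); the exact recipe the original authors used to reach 18K is not stated here.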

Run locally

Note: use Python 3.

Training

> python train.py -h
usage: train.py [-h] [-g G] [-u U] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  -g G, --gpu_id G      Index of the GPU to use
  -u U, --unet_type U   UNet type (unet/unet2/unet3)
  -e E, --epochs E      Number of epochs (default: 5)
  -b [B], --batch-size [B]
                        Batch size (default: 1)
  -l [LR], --learning-rate [LR]
                        Learning rate (default: 0.1)
  -f LOAD, --load LOAD  Load model from a .pth file (default: False)
  -s SCALE, --scale SCALE
                        Downscaling factor of the images (default: 0.5)
  -v VAL, --validation VAL
                        Percent of the data that is used as validation (0-100)
                        (default: 10.0)
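The -v option holds out a percentage of the data for validation. A minimal sketch of such a split using only the standard library, assuming a seeded random shuffle; `split_train_val` is a hypothetical name, and train.py may perform the split differently.

```python
import random


def split_train_val(items, val_percent, seed=0):
    """Hypothetical helper: hold out `val_percent` percent (0-100) of `items`."""
    items = list(items)
    random.Random(seed).shuffle(items)  # seeded so the split is reproducible
    n_val = int(len(items) * val_percent / 100)
    return items[n_val:], items[:n_val]  # (train, validation)


train, val = split_train_val(range(18000), 15.0)
print(len(train), len(val))
```

With the 18K augmented images and -v 15.0, this holds out 2,700 images and trains on the remaining 15,300.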

By default, the scale is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
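The memory cost follows from how the scale is applied: each image is resized to (scale · W, scale · H) before it enters the network, so the activation memory, roughly proportional to the pixel count, grows with scale². At 0.5 the network sees a quarter of the pixels it would see at 1. A minimal sketch of such a preprocessing step, assuming Pillow resizing followed by conversion to a CHW float array in [0, 1]; `preprocess` is a hypothetical name, not necessarily what train.py does.

```python
import numpy as np
from PIL import Image


def preprocess(pil_img, scale):
    """Hypothetical helper: downscale by `scale`, then HWC uint8 -> CHW float in [0, 1]."""
    w, h = pil_img.size
    new_w, new_h = int(scale * w), int(scale * h)
    if new_w <= 0 or new_h <= 0:
        raise ValueError("scale is too small for this image")
    pil_img = pil_img.resize((new_w, new_h))
    arr = np.asarray(pil_img, dtype=np.float32)
    if arr.ndim == 2:             # grayscale mask: add a channel axis
        arr = arr[:, :, None]
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW, the layout PyTorch conv layers expect
    if arr.max() > 1:
        arr = arr / 255.0
    return arr


img = Image.new("RGB", (600, 800))  # width 600, height 800
print(preprocess(img, 0.5).shape)
```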

The input images and target masks should be in the data/imgs and data/masks folders respectively.
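Before training, it can help to check that every image in data/imgs has a matching mask in data/masks. The sketch below assumes an image and its mask share a file stem (for example data/imgs/0001.jpg and data/masks/0001.png); the repository's actual naming rule is not documented here, so adjust the matching if it differs. `pair_images_and_masks` is a hypothetical helper.

```python
from pathlib import Path


def pair_images_and_masks(img_dir="data/imgs", mask_dir="data/masks"):
    """Hypothetical helper: match each image to the mask with the same stem.

    Returns (pairs, missing): matched (image, mask) paths, and the names of
    images that have no mask.
    """
    masks = {p.stem: p for p in Path(mask_dir).iterdir() if p.is_file()}
    pairs, missing = [], []
    for img in sorted(Path(img_dir).iterdir()):
        if not img.is_file():
            continue
        if img.stem in masks:
            pairs.append((img, masks[img.stem]))
        else:
            missing.append(img.name)
    return pairs, missing
```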

Notes on memory

A small batch size and a downscaling factor below 1 keep GPU memory usage low, for example:

$ python train.py -g 0 -u v3 -e 200 -b 1 -l 0.1 -s 0.5 -v 15.0

Prediction

You can easily test the output masks on your images via the CLI.

To predict a single image and save it:

$ python predict.py -i image.jpg -o output.jpg

To predict multiple images and display them without saving:

$ python predict.py -i image1.jpg image2.jpg --viz --no-save

> python predict.py -h
usage: predict.py [-h] [--gpu_id 0] [--unet_type unet/unet2/unet3] [--model FILE] --input INPUT [INPUT ...] [--output INPUT [INPUT ...]] [--viz] [--no-save] [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  -g G, --gpu_id G      Index of the GPU to use
  --unet_type U, -u U   UNet type (unet/unet2/unet3)
  --model FILE, -m FILE
                        Specify the file in which the model is stored
                        (default: MODEL.pth)
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        filenames of input images (default: None)
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of output images (default: None)
  --viz, -v             Visualize the images as they are processed (default:
                        False)
  --no-save, -n         Do not save the output masks (default: False)
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel
                        white (default: 0.5)
  --scale SCALE, -s SCALE
                        Scale factor for the input images (default: 0.5)
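The --mask-threshold option turns per-pixel probabilities into a black-and-white mask. A minimal sketch, assuming the network outputs one raw logit per pixel that is passed through a sigmoid; `logits_to_mask_image` is a hypothetical helper, not part of predict.py.

```python
import numpy as np
from PIL import Image


def logits_to_mask_image(logits, threshold=0.5):
    """Hypothetical helper: (H, W) logits -> 8-bit black/white mask image."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid: logits -> probabilities in (0, 1)
    mask = probs > threshold               # pixels above the threshold become white
    return Image.fromarray((mask * 255).astype(np.uint8))


logits = np.array([[-2.0, 0.5], [3.0, -0.1]])
print(np.asarray(logits_to_mask_image(logits)).tolist())
```

Raising the threshold keeps only pixels the model is more confident about. Note that a portrait matte is naturally a soft alpha map, so thresholding discards the partial transparency along hair and edges.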

Reference

[2015] U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI)

[2018] UNet++: A Nested U-Net Architecture for Medical Image Segmentation (MICCAI)

[2020] UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation (ICASSP)
