Clothes Segmentation using U2NET
This repo contains training code, inference code and a pre-trained model for clothes parsing from human portraits.
Here clothes are parsed into 3 categories: upper body (red), lower body (green) and full body (yellow).
This model works well with any background and almost all poses. For more samples, see samples.md.
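The color legend works roughly like the minimal pure-Python sketch below, which is not the repo's code. For each pixel it picks the highest-scoring of the 4 output channels and paints the result. It assumes channel 0 is background and channels 1, 2 and 3 are upper body, lower body and full body, and that background is drawn black. The channel order, the background color and the names PALETTE, scores_to_labels and colorize are all hypothetical.

```python
# Hypothetical channel order: 0 = background, 1 = upper body, 2 = lower body, 3 = full body.
PALETTE = {
    0: (0, 0, 0),        # background -> black (assumed)
    1: (255, 0, 0),      # upper body -> red
    2: (0, 255, 0),      # lower body -> green
    3: (255, 255, 0),    # full body -> yellow
}


def scores_to_labels(scores):
    """Per-pixel argmax over channel scores.

    scores is indexed as scores[channel][row][col] (a 4 x H x W nested list).
    Returns an H x W nested list of class indices.
    """
    num_channels = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [
        [max(range(num_channels), key=lambda c: scores[c][r][x]) for x in range(width)]
        for r in range(height)
    ]


def colorize(labels):
    """Map an H x W grid of class indices to an H x W grid of RGB tuples."""
    return [[PALETTE[label] for label in row] for row in labels]


# Tiny 4 x 1 x 2 example: pixel (0, 0) scores highest on channel 1, pixel (0, 1) on channel 3.
scores = [
    [[0.1, 0.2]],   # background
    [[0.7, 0.1]],   # upper body
    [[0.1, 0.1]],   # lower body
    [[0.1, 0.6]],   # full body
]
labels = scores_to_labels(scores)
print(labels)            # [[1, 3]]
print(colorize(labels))  # [[(255, 0, 0), (255, 255, 0)]]
```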
Technical details
- U2NET: This project uses the amazing U2NET as its deep learning model. Typical salient object detection with U2NET produces a single-channel output. This model instead outputs 4 channels, representing upper-body clothes, lower-body clothes, full-body clothes and background. This version of the checkpoint was trained with categorical cross-entropy loss only.
- Dataset: U2NET is trained on 45k images from the iMaterialist (Fashion) 2019 at FGVC6 dataset. To reduce complexity, I have merged the original 42 categories from the dataset labels into 3 categories (upper body, lower body and full body). All images are resized to a square 768 x 768 px for training ¯\_(ツ)_/¯. (This experiment was conducted at 768 px, but around 384 px will work fine too if you are retraining on another dataset.)
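The category merge can be sketched as follows. The sketch assumes that each train.csv row gives a Kaggle-style run-length mask plus an integer class ID. In that format, masks are space-separated 1-based start/length pairs, and pixels are numbered top to bottom, then left to right. The class IDs in CLASS_TO_CATEGORY are placeholders, not the actual grouping of the 42 categories. rle_decode and merge_masks are hypothetical helpers written for illustration.

```python
# Merged label values; 0 is background.
UPPER_BODY, LOWER_BODY, FULL_BODY = 1, 2, 3

# Hypothetical mapping from dataset class IDs to merged categories (placeholder IDs).
CLASS_TO_CATEGORY = {
    0: UPPER_BODY,
    1: UPPER_BODY,
    6: LOWER_BODY,
    7: LOWER_BODY,
    10: FULL_BODY,
}


def rle_decode(encoded, height, width):
    """Decode an assumed Kaggle-style run-length string into an H x W 0/1 grid.

    encoded is "start length start length ..." with 1-based starts, and pixels
    are numbered column-major (top to bottom, then left to right).
    """
    flat = [0] * (height * width)
    numbers = [int(n) for n in encoded.split()]
    for start, length in zip(numbers[0::2], numbers[1::2]):
        for i in range(start - 1, start - 1 + length):
            flat[i] = 1
    # Column-major: flat index i -> row i % height, column i // height.
    return [[flat[col * height + row] for col in range(width)] for row in range(height)]


def merge_masks(annotations, height, width):
    """Build an H x W label map from (encoded_mask, class_id) pairs."""
    label_map = [[0] * width for _ in range(height)]
    for encoded, class_id in annotations:
        category = CLASS_TO_CATEGORY.get(class_id)
        if category is None:
            continue  # classes outside the mapping stay background
        mask = rle_decode(encoded, height, width)
        for r in range(height):
            for c in range(width):
                if mask[r][c]:
                    label_map[r][c] = category  # later annotations overwrite earlier ones
    return label_map


# 3 x 2 image: class 1 covers column 0 (rows 0-1), class 6 covers column 1 (rows 1-2).
print(merge_masks([("1 2", 1), ("5 2", 6)], 3, 2))  # [[1, 0], [1, 2], [0, 2]]
```

Where masks overlap, the later annotation wins; that tie-breaking rule is also an assumption.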
Training
- Training requires:
  - PyTorch > 1.3.0
  - tensorboardX
  - gdown
- Download the dataset from this link and extract all items.
- In options/base_options.py, set the paths of the train folder, which contains the training images, and of train.csv, which is the label CSV file.
- To port the weights of all layers of the original U2NET except the last layer, run:
  python setup_model_weights.py
  This generates the weights after model surgery in the prev_checkpoints folder.
- You can explore various options in options/base_options.py, such as the checkpoint saving folder, logs folder, etc.
- For a single GPU, set distributed = False in options/base_options.py; for multiple GPUs, set it to True.
- For a single GPU, run:
  python train.py
- For multiple GPUs, run:
  python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env train.py
  This command is for a single node with 4 GPUs; it has been tested only on a single node.
- You can watch loss graphs and samples in TensorBoard by running the tensorboard command in the logs folder.
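One common way to do this kind of model surgery is to copy every pretrained parameter whose name and shape still match the new 4-channel model. The remaining parameters keep their fresh initialization. The sketch below shows that filter on plain shape tuples. The parameter names and shapes are illustrative only, select_portable is a hypothetical helper, and it is an assumption that setup_model_weights.py works this way.

```python
def select_portable(pretrained_shapes, target_shapes):
    """Split pretrained parameter names into (kept, skipped).

    A pretrained parameter is kept only if the target model has a parameter
    with the same name and the same shape. Everything else, such as output
    convolutions whose channel count changes from 1 to 4, is skipped and
    keeps its fresh initialization in the target model.
    """
    kept, skipped = [], []
    for name, shape in pretrained_shapes.items():
        if target_shapes.get(name) == shape:
            kept.append(name)
        else:
            skipped.append(name)
    return kept, skipped


# Illustrative shapes only, as (out_channels, in_channels, kH, kW); names are hypothetical.
pretrained = {
    "encoder.conv.weight": (64, 3, 3, 3),
    "side1.weight": (1, 64, 3, 3),      # 1-channel saliency output
    "outconv.weight": (1, 6, 1, 1),
}
target = {
    "encoder.conv.weight": (64, 3, 3, 3),
    "side1.weight": (4, 64, 3, 3),      # 4-channel clothes output
    "outconv.weight": (4, 24, 1, 1),
}
kept, skipped = select_portable(pretrained, target)
print(kept)     # ['encoder.conv.weight']
print(skipped)  # ['side1.weight', 'outconv.weight']
```

With PyTorch, you could run the same comparison on model.state_dict() entries via tensor.shape. You would then load the filtered dictionary with model.load_state_dict(..., strict=False). A pure shape filter also skips any side-output layers whose channel count changed, not only the final layer.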
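With --use_env, torch.distributed.launch passes each worker's local rank through the LOCAL_RANK environment variable instead of a --local_rank argument. It sets RANK and WORLD_SIZE alongside it. The sketch below reads those values and falls back to a single process when they are absent. read_launch_env is a hypothetical helper, and this does not show how train.py actually consumes these variables.

```python
import os


def read_launch_env(env=None):
    """Return (local_rank, rank, world_size) as set by the PyTorch launcher.

    Defaults describe a single-process run, as with distributed = False.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    # In a real training script the values would then typically be used as:
    #   torch.cuda.set_device(local_rank)
    #   torch.distributed.init_process_group(backend="nccl", init_method="env://")
    return local_rank, rank, world_size


# What worker 2 of the single-node, 4-GPU command above would see:
print(read_launch_env({"LOCAL_RANK": "2", "RANK": "2", "WORLD_SIZE": "4"}))  # (2, 2, 4)
print(read_launch_env({}))  # (0, 0, 1)
```

Newer PyTorch versions deprecate torch.distributed.launch in favor of torchrun (for example, torchrun --nnodes=1 --nproc_per_node=4 train.py). torchrun sets LOCAL_RANK in the environment by default.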
Testing/Inference
- Download the pretrained model (165 MB) from this link into the trained_checkpoint folder.
- Put input images in the input_images folder.
- To run inference:
  python infer.py
- Output will be saved in the output_images folder.
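Training used square 768 x 768 inputs, so a predicted label map usually has to be resized back to the input image's size. Nearest-neighbour sampling keeps class indices intact, whereas bilinear interpolation would blend them into meaningless in-between values. This is an assumed post-processing step rather than a description of infer.py. resize_nearest is a hypothetical helper that uses simple floor-based index scaling.

```python
def resize_nearest(labels, out_height, out_width):
    """Resize an H x W grid of class indices with nearest-neighbour sampling.

    Each output pixel copies the input pixel found by scaling its coordinates
    and flooring, so no new class values can appear.
    """
    in_height, in_width = len(labels), len(labels[0])
    return [
        [labels[r * in_height // out_height][c * in_width // out_width] for c in range(out_width)]
        for r in range(out_height)
    ]


small = [[1, 3],
         [2, 0]]
print(resize_nearest(small, 4, 4))
# [[1, 1, 3, 3], [1, 1, 3, 3], [2, 2, 0, 0], [2, 2, 0, 0]]
```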
Acknowledgements
- The U2NET model is from the original U2NET repo. Thanks to Xuebin Qin for the amazing repo.
- The complete repo follows the structure of the Pix2pixHD repo.