  • Stars: 268
  • Rank: 152,291 (Top 4%)
  • Language: Python
  • License: MIT License
  • Created: almost 5 years ago
  • Updated: over 1 year ago

Repository Details

Region Mutual Information Loss for Semantic Segmentation

Table of Contents

  • Introduction
  • Features and TODO
  • Installation
  • Training
  • Evaluation and Inference
  • Experiments
  • Citations
  • Acknowledgements

Introduction

This is the code for the NeurIPS 2019 paper Region Mutual Information Loss for Semantic Segmentation.

This paper proposes a region mutual information (RMI) loss to model the dependencies among pixels. RMI represents each pixel by the pixel itself together with its neighboring pixels. Each pixel in an image thus becomes a multi-dimensional point that encodes the relationships between pixels, and the image is cast into a multi-dimensional distribution of these high-dimensional points. The prediction and the ground truth can then achieve high-order consistency by maximizing the mutual information (MI) between their multi-dimensional distributions.

img_intro
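
To make this construction concrete, here is a minimal sketch (an illustration of the idea, not the repository's implementation) that turns a prediction or ground-truth map into such multi-dimensional region points with PyTorch; the function name region_points and the toy shapes are only for illustration.

import torch
import torch.nn.functional as F

def region_points(x, radius=3):
    # x: (N, C, H, W) probability or one-hot map.
    # Every spatial location is represented by the radius x radius patch
    # around it, giving points of dimension C * radius * radius.
    # F.unfold extracts sliding local patches; the padding keeps the spatial size.
    return F.unfold(x, kernel_size=radius, padding=radius // 2)  # (N, C*r*r, H*W)

# Toy usage: soft predictions and a binary ground-truth map.
pred = torch.softmax(torch.randn(2, 21, 65, 65), dim=1)
target = torch.randint(0, 2, (2, 21, 65, 65)).float()
p_points, t_points = region_points(pred), region_points(target)
print(p_points.shape)  # torch.Size([2, 189, 4225])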

Features and TODO

  • Support different segmentation models, i.e., DeepLabv3, DeepLabv3+, and PSPNet
  • Multi-GPU training
  • Multi-GPU Synchronized BatchNorm
  • Support different backbones, e.g., Mobilenet, Xception
  • Model pretrained on MS-COCO
  • Distributed training

We are open to pull requests.

Installation

Install dependencies

Please install PyTorch 1.1.0 and Python 3.6.5. We highly recommend using our prebuilt PyTorch Docker image, zhaosssss/torch_lab.

docker pull zhaosssss/torch_lab:1.1.0

If you have not installed docker, see https://docs.docker.com/.

After you have installed docker and pulled our image, cd to the script directory and run

./docker.sh

to create a running docker container.

If you do not want to use docker, try

pip install -r requirements.txt

However, this is not recommended.

Prepare data

Generally, directories are organized as follows:

|
|--dataset (save the dataset) 
|--models  (save the output checkpoints)
|--github  (save the code)
|--|
|--|--RMI  (the RMI code repository)
|--|--|--crf
|--|--|--dataloaders
|--|--|--losses
...

You can download the CamVid dataset from SegNet-Tutorial, which provides a processed version of the original CamVid dataset.

Training

See script/train.sh for detailed information. Before starting training, you should specify some variables in script/train.sh.

  • pre_dir, where you save your output checkpoints. If you organize the directories as suggested above, it should be pre_dir=models.

  • data_dir, where you save your dataset. You should also put the lists of the images in the dataset in a certain directory; check dataloaders/datasets/pascal.py to see how we organize the input pipeline (a hypothetical example of such a list reader is sketched below).
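
The exact list format is defined by the dataloader; as a purely hypothetical illustration (the file layout below is an assumption, not the repository's documented format), a reader for a list with one "image_path label_path" pair per line could look like this:

import os

def read_image_list(list_path, data_dir):
    # Hypothetical format: each line holds "relative/image.jpg relative/label.png".
    # Check dataloaders/datasets/pascal.py for the format actually expected.
    samples = []
    with open(list_path) as f:
        for line in f:
            if not line.strip():
                continue
            image_rel, label_rel = line.split()
            samples.append((os.path.join(data_dir, image_rel),
                            os.path.join(data_dir, label_rel)))
    return samples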

You can find more information about the arguments in parser_params.py.

python parser_params.py --help

usage: parser_params.py [-h] [--resume RESUME] [--checkname CHECKNAME]
                        [--save_ckpt_steps SAVE_CKPT_STEPS]
                        [--max_ckpt_nums MAX_CKPT_NUMS]
                        [--model_dir MODEL_DIR] [--output_dir OUTPUT_DIR]
                        [--seg_model {deeplabv3,deeplabv3+,pspnet}]
                        [--backbone {resnet50,resnet101,resnet152,resnet50_beta,resnet101_beta,resnet152_beta}]
                        [--out_stride OUT_STRIDE] [--batch_size N]
                        [--accumulation_steps N] [--test_batch_size N]
                        [--dataset {pascal,coco,cityscapes,camvid}]
                        [--train_split {train,trainaug,trainval,val,test}]
                        [--data_dir DATA_DIR] [--use_sbd] [--workers N]
                        ...
                        [--rmi_pool_size RMI_POOL_SIZE]
                        [--rmi_pool_stride RMI_POOL_STRIDE]
                        [--rmi_radius RMI_RADIUS]
                        [--crf_iter_steps CRF_ITER_STEPS]
                        [--local_rank LOCAL_RANK] [--world_size WORLD_SIZE]
                        [--dist_backend DIST_BACKEND]
                        [--multiprocessing_distributed]

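Among these arguments, rmi_radius, rmi_pool_size, and rmi_pool_stride control how the pixel regions are built and how much the maps are downsampled before the loss is computed. As a rough, hedged sketch of what the loss then optimizes over those region points (a simplification of the paper's lower bound, not the repository's exact implementation):

import torch

def rmi_lower_bound_term(y_points, p_points, eps=5e-4):
    # y_points, p_points: (B, d, M) region points of ground truth / prediction,
    # e.g. d = channels * rmi_radius**2 after the construction sketched earlier.
    b, d, m = y_points.shape
    eye = eps * torch.eye(d, device=y_points.device)
    y = y_points - y_points.mean(dim=2, keepdim=True)
    p = p_points - p_points.mean(dim=2, keepdim=True)
    sigma_yy = y @ y.transpose(1, 2) / m + eye
    sigma_pp = p @ p.transpose(1, 2) / m + eye
    sigma_yp = y @ p.transpose(1, 2) / m
    # Conditional covariance Sigma_{Y|P} = Sigma_YY - Sigma_YP Sigma_PP^{-1} Sigma_PY.
    cond = sigma_yy - sigma_yp @ torch.inverse(sigma_pp) @ sigma_yp.transpose(1, 2)
    # Minimizing this log-determinant maximizes the mutual-information lower bound.
    return 0.5 * torch.logdet(cond + eye).mean() / d
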
After you set all the arguments properly, you can simply cd to RMI/script and run

./train.sh

to start training.

  • Monitor the training process with TensorBoard:

tensorboard --logdir=your_logdir --port=your_port

img_ten

  • GPU memory usage

Training a DeepLabv3 model with output_stride=16, crop_size=513, and batch_size=16 needs 4 GTX 1080 GPUs (8 GB each), 2 GTX TITAN X GPUs (12 GB each), or 1 TITAN RTX GPU (24 GB).
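
If your GPUs have less memory, the --accumulation_steps argument listed above suggests that gradients can be accumulated over several smaller batches. The snippet below is a generic PyTorch pattern for that idea, not the repository's actual training loop:

import torch

def train_epoch(model, criterion, optimizer, loader, accumulation_steps=4):
    # Accumulate gradients over several small batches so the effective batch
    # size stays large while per-step GPU memory stays small.
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader):
        loss = criterion(model(images), labels) / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()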

Evaluation and Inference

See script/eval.sh and script/inference.sh for detailed information.

You should also specify some variables in the scripts.

  • data_dir, where you save your dataset.

  • resume, where your checkpoints are located.

  • output_dir, where the output data will be saved.

Then run

./eval.sh

or

./inference.sh

Experiments

img_res01 img_res02

img_res03

Selected qualitative results on the PASCAL VOC 2012 val set. Segmentation results of DeepLabv3+ with RMI show richer details than DeepLabv3+ with cross-entropy (CE), e.g., the small bumps on the airplane wing, the branches of plants, and the limbs of cows and sheep.

Citations

If our paper and code are beneficial to your work, please cite:

@inproceedings{2019_zhao_rmi,
  author    = {Shuai Zhao and
               Yang Wang and
               Zheng Yang and
               Deng Cai},
  title     = {Region Mutual Information Loss for Semantic Segmentation},
  booktitle = {NeurIPS},
  year      = {2019},
}

If other related work in our code or paper also helps you, please cite the corresponding papers.

Acknowledgements

img_cad

More Repositories

  1. pixel_link (Python, 767 stars): Implementation of our paper 'PixelLink: Detecting Scene Text via Instance Segmentation' in AAAI2018
  2. nsg (C++, 584 stars): Navigating Spreading-out Graph For Approximate Nearest Neighbor Search
  3. MatlabFunc (MATLAB, 502 stars): Matlab codes for feature learning
  4. ttfnet (Python, 481 stars)
  5. efanna (C++, 280 stars): fast library for ANN search and KNN graph construction
  6. resa (Python, 175 stars): Implementation of our paper 'RESA: Recurrent Feature-Shift Aggregator for Lane Detection' in AAAI2021
  7. time_lstm (Python, 152 stars)
  8. MaxSquareLoss (Python, 109 stars): Code for "Domain Adaptation for Semantic Segmentation with Maximum Squares Loss" in PyTorch
  9. SSG (C++, 95 stars): code for satellite system graphs
  10. efanna_graph (C++, 79 stars): an Extremely Fast Approximate Nearest Neighbor graph construction Algorithm framework
  11. graph_level_drug_discovery (Python, 60 stars)
  12. CariFaceParsing (Python, 55 stars): Code for ICIP2019 paper: Weakly-supervised Caricature Face Parsing through Domain Adaptation
  13. AtSNE (Cuda, 54 stars): Anchor-t-SNE for large-scale and high-dimension vector visualization
  14. ALDA (Python, 49 stars): Code for "Adversarial-Learned Loss for Domain Adaptation" (AAAI2020) in PyTorch
  15. depthInpainting (C++, 48 stars): Depth Image Inpainting with Low Gradient Regularization
  16. AttentionZSL (Python, 44 stars): Codes for Paper "Attribute Attention for Semantic Disambiguation in Zero-Shot Learning"
  17. ReDR (Python, 41 stars): Code for ACL 2019 paper "Reinforced Dynamic Reasoning for Conversational Question Generation"
  18. hashingSearch (C++, 31 stars): Search with a hash index
  19. SRDet (Python, 30 stars): A simple, fast, efficient and end-to-end 3D object detector without NMS
  20. PTL (Python, 26 stars): Progressive Transfer Learning for Person Re-identification, published at IJCAI-2019
  21. TreeAttention (Python, 24 stars): A Better Way to Attend: Attention with Trees for Video Question Answering
  22. RPLSH (C++, 23 stars): Kmeans Quantization + Random Projection based Locality Sensitive Hashing
  23. videoqa (Python, 21 stars): Unifying the Video and Question Attentions for Open-Ended Video Question Answering
  24. DMP (Python, 17 stars): Code for ACL 2018 paper "Discourse Marker Augmented Network with Reinforcement Learning for Natural Language Inference"
  25. DREN (C++, 15 stars): DREN: Deep Rotation Equivariant Network
  26. Attention-GRU-3M (Python, 13 stars)
  27. AMI (Python, 7 stars)
  28. Sparse-Learning-with-Stochastic-Composite-Optimization (MATLAB, 7 stars): The implementation of our work "Sparse Learning with Stochastic Composite Optimization"
  29. TransAt (Python, 6 stars)
  30. diverse_image_synthesis (Python, 4 stars): PyTorch implementation of diverse conditional image synthesis
  31. DeAda (Python, 3 stars): Decouple Co-adaptation: Classifier Randomization for Person Re-identification, published in Neurocomputing
  32. AdaDB (Python, 2 stars)
  33. SIF (Python, 2 stars): SIF: Self-Inspirited Feature Learning for Person Re-Identification, published in IEEE TIP
  34. SIFS (C++, 1 star)
  35. SplitNet (Jupyter Notebook, 1 star)