TOIST: Task Oriented Instance Segmentation Transformer with Noun-Pronoun Distillation
This repository is an official implementation of TOIST:
TOIST: Task Oriented Instance Segmentation Transformer with Noun-Pronoun Distillation
Pengfei Li, Beiwen Tian, Yongliang Shi, Xiaoxue Chen, Hao Zhao, Guyue Zhou, Ya-Qin Zhang
In NeurIPS 2022
Introduction
Current referring expression comprehension algorithms can effectively detect or segment objects indicated by nouns, but how to understand verb reference is still under-explored. As such, we study the challenging problem of task oriented detection, which aims to find objects that best afford an action indicated by verbs like "sit comfortably on". Towards a finer localization that better serves downstream applications like robot interaction, we extend the problem into task oriented instance segmentation. A unique requirement of this task is to select preferred candidates among possible alternatives. Thus we resort to the transformer architecture, which naturally models pair-wise query relationships with attention, leading to the TOIST method. In order to leverage pre-trained noun referring expression comprehension models and the fact that we can access privileged noun ground truth during training, a novel noun-pronoun distillation framework is proposed. Noun prototypes are generated in an unsupervised manner and contextual pronoun features are trained to select prototypes. As such, the network remains noun-agnostic during inference. We evaluate TOIST on the large-scale task oriented dataset COCO-Tasks and achieve +10.9% higher
If you find our code or paper useful, please consider citing:
```bibtex
@article{li2022toist,
  title={TOIST: Task Oriented Instance Segmentation Transformer with Noun-Pronoun Distillation},
  author={Li, Pengfei and Tian, Beiwen and Shi, Yongliang and Chen, Xiaoxue and Zhao, Hao and Zhou, Guyue and Zhang, Ya-Qin},
  journal={arXiv preprint arXiv:2210.10775},
  year={2022}
}
```
This repository is a PyTorch implementation.
Datasets
Please follow the instructions on the official website to download the COCO-Tasks dataset.
You can organize the `data` folder as follows:
```
data/
├── id2name.json
├── images/
│   ├── train2014/
│   └── val2014/
└── coco-tasks/
    └── annotations/
        ├── task_1_train.json
        ├── task_1_test.json
        ...
        ├── task_14_train.json
        └── task_14_test.json
```
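Before training, it can help to confirm that the layout above is complete. The following is a minimal sketch, not part of the repository: the helper names `expected_paths` and `missing_paths` are hypothetical, and it simply enumerates the folders and the 14 train/test annotation files listed in the tree.

```python
from pathlib import Path
from typing import List


def expected_paths(root: Path, num_tasks: int = 14) -> List[Path]:
    """List the files and folders the layout above expects under `root`."""
    paths = [
        root / "id2name.json",
        root / "images" / "train2014",
        root / "images" / "val2014",
    ]
    ann_dir = root / "coco-tasks" / "annotations"
    # One train and one test annotation file per task (task_1 ... task_14).
    for task in range(1, num_tasks + 1):
        for split in ("train", "test"):
            paths.append(ann_dir / f"task_{task}_{split}.json")
    return paths


def missing_paths(root: Path, num_tasks: int = 14) -> List[Path]:
    """Return every expected path that does not exist yet."""
    return [p for p in expected_paths(root, num_tasks) if not p.exists()]
```

Running `missing_paths(Path("data"))` from the repository root should return an empty list once the dataset is in place.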
Then set the arguments `coco_path`, `refexp_ann_path` and `catid2name_path` in the file `configs/tdod.json` to the paths of `data/images/`, `data/coco-tasks/annotations/` and `data/id2name.json`, respectively.
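The three paths can also be set programmatically. This is a sketch under an assumption the README does not confirm: that `configs/tdod.json` is a flat JSON object with these three keys at the top level. Any other keys are left untouched, and `set_data_paths` is a hypothetical helper name.

```python
import json


def set_data_paths(config_file: str, data_root: str = "data") -> dict:
    """Point coco_path, refexp_ann_path and catid2name_path at `data_root`.

    Assumes a flat JSON config; trailing slashes mirror the README's paths.
    """
    root = data_root.rstrip("/")
    with open(config_file) as f:
        config = json.load(f)
    config["coco_path"] = f"{root}/images/"
    config["refexp_ann_path"] = f"{root}/coco-tasks/annotations/"
    config["catid2name_path"] = f"{root}/id2name.json"
    with open(config_file, "w") as f:
        json.dump(config, f, indent=4)
    return config
```

For example, `set_data_paths("configs/tdod.json", "/abs/path/to/data")` rewrites the file in place; an absolute `data_root` avoids surprises when launching from a different working directory.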
Installation
Make sure that you have all dependencies in place. The simplest way to do so is to use anaconda.
Make a new conda env and activate it:
```bash
conda create --name TOIST python=3.8
conda activate TOIST
```
Install the packages listed in `requirements.txt`:
```bash
pip install -r requirements.txt
```
Running
1. Plain TOIST detection
Training
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 6 \
--valid_batch_size 8 \
--load /path/to/pretrained_resnet101_checkpoint.pth \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1
```
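The training commands launch one process per visible GPU (`--nproc_per_node=6`). Assuming `--train_batch_size` is a per-process batch size, as in DETR/MDETR-style distributed training scripts (an assumption for this repository), the number of samples per optimizer step is the product below. The sketch helps when reproducing a run on a different number of GPUs; both helper names are hypothetical.

```python
def effective_batch_size(per_process_batch: int, nproc_per_node: int, nnodes: int = 1) -> int:
    """Samples per optimizer step when every process loads its own batch."""
    return per_process_batch * nproc_per_node * nnodes


def per_process_batch_for(target: int, nproc_per_node: int, nnodes: int = 1) -> int:
    """Per-process batch needed to keep `target` samples per step on another GPU count."""
    world_size = nproc_per_node * nnodes
    if target % world_size:
        raise ValueError(f"{target} is not divisible by {world_size} processes")
    return target // world_size


# Values taken from the training commands in this README.
detection = effective_batch_size(6, 6)      # plain detection: 6 per GPU x 6 GPUs
segmentation = effective_batch_size(2, 6)   # segmentation head: 2 per GPU x 6 GPUs
distillation = effective_batch_size(3, 6)   # distillation: 3 per GPU x 6 GPUs
```

Under that assumption, keeping the detection setting on 4 GPUs would mean `--nproc_per_node=4 --train_batch_size 9`, provided 9 samples fit in each GPU's memory. Whether the learning rates should also be rescaled is not covered here.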
To leverage the pre-trained noun referring expression comprehension model, download the checkpoint provided by MDETR and set `--load` to its path.
Evaluation
Set `--resume` to the path of the trained model to be evaluated.
```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=1 --use_env main.py \
--dataset_config configs/tdod.json \
--valid_batch_size 8 \
--num_workers 5 \
--resume /path/to/checkpoint \
--ema --eval \
--output-dir 'logs/test' \
--no_contrastive_align_loss
```
Verb-noun input
To train or evaluate the teacher TOIST model, which leverages privileged ground-truth knowledge by taking the verb-noun expression as text input, add `--verb_noun_input`, for example:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 6 \
--valid_batch_size 8 \
--load /path/to/pretrained_resnet101_checkpoint.pth \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1 \
--verb_noun_input
```
Running without pre-training
To train TOIST without the pre-trained noun referring expression comprehension model, omit `--load` and add `--without_pretrain`:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 6 \
--valid_batch_size 8 \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1 \
--without_pretrain
```
For evaluation, set `--resume` and add `--without_pretrain` to the evaluation command above.
2. Plain TOIST segmentation
After training the detection part of TOIST, use the following commands to train and evaluate the segmentation head of TOIST.
Training
Set `--frozen_weights` to the path of the trained detection model.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 2 \
--valid_batch_size 4 \
--frozen_weights /path/to/trained/detection/checkpoint \
--mask_model smallconv \
--no_aux_loss \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1 \
--no_contrastive_align_loss
```
Evaluation
Set `--resume` to the path of the trained model to be evaluated.
```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=1 --use_env main.py \
--dataset_config configs/tdod.json \
--valid_batch_size 4 \
--num_workers 5 \
--resume /path/to/checkpoint \
--ema --eval \
--output-dir 'logs/test' \
--mask_model smallconv \
--no_contrastive_align_loss
```
3. TOIST detection with noun-pronoun distillation
Training
To train TOIST with distillation, set `--load` to the path of the trained student model (taking verb-pronoun as text input) and `--load_noun` to the path of the trained teacher model (taking verb-noun as text input).
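The teacher sees the privileged noun while the student only sees a pronoun, so the student can be trained to imitate the teacher's preferences among candidates. The sketch below is a generic soft-label distillation loss, not TOIST's implementation: it assumes each model emits one preference score per candidate query, turns both score vectors into distributions with a temperature softmax, and returns `coef * KL(teacher || student)`. Here `coef` only stands in for `--softkd_coef` (50 in the command below); the repository's actual Preference Distillation loss may be formulated differently.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable log-softmax of logits divided by `temperature`."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def soft_kd_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                 temperature: float = 1.0, coef: float = 1.0) -> float:
    """coef * KL(teacher || student) over the candidate queries (generic sketch)."""
    if len(student_logits) != len(teacher_logits):
        raise ValueError("student and teacher must score the same candidates")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    # KL divergence: sum_i p_t(i) * (log p_t(i) - log p_s(i)), always >= 0.
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return coef * kl
```

For instance, a student that scores three candidates uniformly (`[1.0, 1.0, 1.0]`) against a teacher that prefers the first (`[4.0, 1.0, 0.5]`) gets a positive loss, while matching the teacher's scores up to a constant shift gives zero. Many KD formulations also multiply by `temperature ** 2`; that factor is omitted here.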
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 3 \
--valid_batch_size 8 \
--load /path/to/pronoun/detection/checkpoint \
--load_noun /path/to/noun/detection/checkpoint \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1 \
--distillation \
--softkd_loss \
--softkd_coef 50 \
--cluster \
--cluster_memory_size 1024 \
--cluster_num 3 \
--cluster_feature_loss 1e4
```
The parameters `--cluster`, `--cluster_memory_size`, `--cluster_num` and `--cluster_feature_loss` are used for Clustering Distillation. The parameters `--softkd_loss` and `--softkd_coef` are used for Preference Distillation.
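To make the clustering parameters concrete, here is a generic sketch of how a bounded memory of noun features (size `--cluster_memory_size`) can be clustered into `--cluster_num` prototypes with plain k-means, and how a feature can then be assigned to its nearest prototype. It is an illustration under stated assumptions, not TOIST's code: the names `NounMemory`, `kmeans` and `nearest_prototype` are hypothetical, and the memory uses first-in-first-out replacement only because `deque` makes it short. The pre-trained models table lists a first-in-first-out memory update as a separate ablation (Table 8), so the repository's default update rule probably differs.

```python
import random
from collections import deque
from typing import Deque, List, Sequence

Vector = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class NounMemory:
    """Bounded bank of noun features (hypothetical; FIFO replacement for brevity)."""

    def __init__(self, memory_size: int = 1024):
        self.bank: Deque[Vector] = deque(maxlen=memory_size)

    def push(self, feature: Sequence[float]) -> None:
        self.bank.append(list(feature))  # the oldest entry is dropped when full


def kmeans(points: List[Vector], k: int = 3, iters: int = 20, seed: int = 0) -> List[Vector]:
    """Lloyd's k-means; the k centroids serve as noun prototypes."""
    if len(points) < k:
        raise ValueError("need at least k points")
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        clusters: List[List[Vector]] = [[] for _ in range(k)]
        for p in points:
            clusters[nearest_prototype(p, centroids)].append(p)
        for c, members in enumerate(clusters):
            if members:  # keep the previous centroid if a cluster goes empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids


def nearest_prototype(feature: Sequence[float], prototypes: List[Vector]) -> int:
    """Index of the prototype closest to `feature`."""
    return min(range(len(prototypes)), key=lambda c: sq_dist(feature, prototypes[c]))
```

In the method described in the introduction, pronoun features are *trained* to select prototypes; the nearest-prototype lookup above is only the simplest stand-in for that selection step.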
Evaluation
Set `--resume` to the path of the trained model to be evaluated.
```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=1 --use_env main.py \
--dataset_config configs/tdod.json \
--valid_batch_size 4 \
--num_workers 5 \
--resume /path/to/checkpoint \
--ema --eval \
--output-dir 'logs/test' \
--cluster \
--cluster_memory_size 1024 \
--cluster_num 3 \
--no_contrastive_align_loss \
--distillation
```
The values of `--cluster_memory_size` and `--cluster_num` should match those used during training.
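A small guard can catch a mismatch before a long evaluation run. This sketch assumes you can recover the training arguments as a dictionary, for example from an `args` entry saved in the checkpoint as DETR-style scripts commonly do (treat that as an assumption for this repository), and compare them with `vars(args)` of the evaluation run. `check_cluster_args` is a hypothetical helper.

```python
from typing import Dict, Iterable

# Settings that must agree between training and evaluation.
CLUSTER_KEYS = ("cluster_memory_size", "cluster_num")


def check_cluster_args(train_args: Dict[str, object], eval_args: Dict[str, object],
                       keys: Iterable[str] = CLUSTER_KEYS) -> None:
    """Raise ValueError if any clustering setting differs between the two runs."""
    mismatches = {
        k: (train_args.get(k), eval_args.get(k))
        for k in keys
        if train_args.get(k) != eval_args.get(k)
    }
    if mismatches:
        details = ", ".join(f"{k}: train={t!r} eval={e!r}" for k, (t, e) in mismatches.items())
        raise ValueError(f"clustering arguments differ: {details}")
```

With the settings used in this README, `check_cluster_args({"cluster_memory_size": 1024, "cluster_num": 3}, {"cluster_memory_size": 1024, "cluster_num": 3})` passes silently, while evaluating with `cluster_num` 5 would raise.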
4. TOIST segmentation with noun-pronoun distillation
Training
Set `--frozen_weights` to the path of the trained detection model (trained with distillation).
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=6 --use_env main.py \
--dataset_config configs/tdod.json \
--train_batch_size 2 \
--valid_batch_size 4 \
--frozen_weights /path/to/trained/detection/with/distillation/checkpoint \
--mask_model smallconv \
--no_aux_loss \
--ema --text_encoder_lr 1e-5 --lr 5e-5 \
--num_workers 5 \
--output-dir 'logs/test' \
--eval_skip 1 \
--cluster \
--cluster_memory_size 1024 \
--cluster_num 3 \
--no_contrastive_align_loss
```
Evaluation
Set `--resume` to the path of the trained model to be evaluated.
```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port=23456 --nproc_per_node=1 --use_env main.py \
--dataset_config configs/tdod.json \
--valid_batch_size 4 \
--num_workers 5 \
--resume /path/to/checkpoint \
--ema --eval \
--output-dir 'logs/test' \
--cluster \
--cluster_memory_size 1024 \
--cluster_num 3 \
--mask_model smallconv \
--no_contrastive_align_loss
```
Pre-trained Models
We provide our pretrained models on Google Drive.
| Table/Figure No. | Row No. | Model Name | Checkpoint |
|---|---|---|---|
| Table 1 | 1 | verb-pronoun input | Google Drive |
| | 2 | verb-noun input | Google Drive |
| | 5 | noun-pronoun distillation | Google Drive |
| Figure 3 (a) | / | decoder w/o self attention | Google Drive |
| Figure 3 (b) | / | cluster number K=1 | Google Drive |
| | / | cluster number K=2 | Google Drive |
| | / | cluster number K=5 | Google Drive |
| | / | cluster number K=7 | Google Drive |
| | / | cluster number K=10 | Google Drive |
| Table 3 | 2 | CCR/CL/SBTL=F/F/T | Google Drive |
| | 3 | CCR/CL/SBTL=F/T/F | Google Drive |
| | 4 | CCR/CL/SBTL=F/T/T | Google Drive |
| | 5 | CCR/CL/SBTL=T/F/F | Google Drive |
| | 6 | CCR/CL/SBTL=T/F/T | Google Drive |
| | 7 | CCR/CL/SBTL=T/T/F | Google Drive |
| Table 5 | 1 | verb-pronoun input w/o pretraining | Google Drive |
| | 2 | verb-noun input w/o pretraining | Google Drive |
| | 3 | noun-pronoun distillation w/o pretraining | Google Drive |
| Table 6 | 2 | it | Google Drive |
| | 3 | them | Google Drive |
| | 4 | abcd | Google Drive |
| | 6 | it w/ distillation | Google Drive |
| | 7 | them w/ distillation | Google Drive |
| | 8 | abcd w/ distillation | Google Drive |
| Table 8 | 2 | first-in-first-out memory update | Google Drive |
License
TOIST is released under the MIT License.
Acknowledgment
We would like to thank the authors of COCO-Tasks, Microsoft COCO, GGNN, MDETR, DETR and Detectron2 for their open-source data and code.