• Stars
    star
    314
  • Rank 132,582 (Top 3 %)
  • Language
    Python
  • License
    MIT License
  • Created almost 3 years ago
  • Updated over 1 year ago

Reviews

There are no reviews yet. Be the first to send feedback to the community and the maintainers!

Repository Details

Official PyTorch implementation of "ML-Decoder: Scalable and Versatile Classification Head" (2021)

ML-Decoder: Scalable and Versatile Classification Head

PWC
PWC
PWC
PWC
PWC


Paper | Pretrained Models | Datasets

Official PyTorch Implementation

Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baruch, Asaf Noy
DAMO Academy, Alibaba Group

Abstract

In this paper, we introduce ML-Decoder, a new attention-based classification head. ML-Decoder predicts the existence of class labels via queries, and enables better utilization of spatial data compared to global average pooling. By redesigning the decoder architecture, and using a novel group-decoding scheme, ML-Decoder is highly efficient, and can scale well to thousands of classes. Compared to using a larger backbone, ML-Decoder consistently provides a better speed-accuracy trade-off. ML-Decoder is also versatile - it can be used as a drop-in replacement for various classification heads, and generalize to unseen classes when operated with word queries. Novel query augmentations further improve its generalization ability. Using ML-Decoder, we achieve state-of-the-art results on several classification tasks: on MS-COCO multi-label, we reach 91.4% mAP; on NUS-WIDE zero-shot, we reach 31.1% ZSL mAP; and on ImageNet single-label, we reach with vanilla ResNet50 backbone a new top score of 80.7%, without extra data or distillation.

ML-Decoder Implementation

ML-Decoder implementation is available here. It can be easily integrated into any backbone using this example code:

ml_decoder_head = MLDecoder(num_classes) # initilization

spatial_embeddings = self.backbone(input_image) # backbone generates spatial embeddings      
 
logits = ml_decoder_head(spatial_embeddings) # transfrom spatial embeddings to logits

Inference Code and Pretrained Models

See Model Zoo

Training Code

We share a full reproduction code for the article results.

Multi-label Training Code


A reproduction code for MS-COCO multi-label:

python train.py  \
--data=/home/datasets/coco2014/ \
--model_name=tresnet_l \
--image_size=448

Single-label Training Code

Our single-label training code uses the excellent timm repo. Reproduction code is currently from a fork, we will work toward a full merge to the main repo.

git clone https://github.com/mrT23/pytorch-image-models.git

This is the code for A2 configuration training, with ML-Decoder (--use-ml-decoder-head=1):

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train.py \
/data/imagenet/ \
--amp \
-b=256 \
--epochs=300 \
--drop-path=0.05 \
--opt=lamb \
--weight-decay=0.02 \
--sched='cosine' \
--lr=4e-3 \
--warmup-epochs=5 \
--model=resnet50 \
--aa=rand-m7-mstd0.5-inc1 \
--reprob=0.0 \
--remode='pixel' \
--mixup=0.1 \
--cutmix=1.0 \
--aug-repeats 3 \
--bce-target-thresh 0.2 \
--smoothing=0 \
--bce-loss \
--train-interpolation=bicubic \
--use-ml-decoder-head=1

ZSL Training Code


First download the following files to the root path of the dataset:

benchmark_81_v0.json
wordvec_array.pth
data.csv

Training code for NUS-WIDE ZSL:

python train_zsl_nus.py  \
--data=/home/datasets/nus_wide/ \
--image_size=224

New Top Results - Stanford-Cars and CIFAR 100

Using ML-Decoder classification head, we reached a top result of 96.41% on Stanford-Cars dataset, and 95.1% on CIFAR-100 dataset. We will add this result to a future version of the paper.

Citation

@misc{ridnik2021mldecoder,
      title={ML-Decoder: Scalable and Versatile Classification Head}, 
      author={Tal Ridnik and Gilad Sharir and Avi Ben-Cohen and Emanuel Ben-Baruch and Asaf Noy},
      year={2021},
      eprint={2111.12933},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

More Repositories

1

ASL

Official Pytorch Implementation of: "Asymmetric Loss For Multi-Label Classification"(ICCV, 2021) paper
Python
715
star
2

ImageNet21K

Official Pytorch Implementation of: "ImageNet-21K Pretraining for the Masses"(NeurIPS, 2021) paper
Python
713
star
3

TResNet

Official Pytorch Implementation of "TResNet: High-Performance GPU-Dedicated Architecture" (WACV 2021)
Python
465
star
4

STAM

Official implementation of "An Image is Worth 16x16 Words, What is a Video Worth?" (2021 paper)
Python
219
star
5

Solving_ImageNet

Official PyTorch implementation of the paper: "Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results" (2022)
Python
190
star
6

PartialLabelingCSL

Official implementation for the paper: "Multi-label Classification with Partial Annotations using Class-aware Selective Loss"
Python
127
star
7

AudioClassfication

Python
71
star
8

HardCoReNAS

Python
34
star
9

HeadSharingKD

Implementation of the paper "It's All in the Head: Representation Knowledge Distillation through Classifier Sharing"
Python
34
star
10

ZS_SDL

Official Pytorch Implementation of: "Semantic Diversity Learning for Zero-Shot Multi-label Classification"(ICCV, 2021) paper
Python
28
star
11

PETA

Official Pytorch Implementation of "PETA: Photo Albums Event Recognition using Transformers Attention" (2021)
Python
18
star
12

CobBO

Coordinate Backoff Bayesian Optimization
Python
8
star
13

alibaba-miil.github.io

Curated list of miil papers
7
star
14

knapsack_pruning

Python
3
star
15

BINAS

Constructing interpretable bilinear accuracy predictors to serve as an objective function for an IQCQP problem that represents NAS under latency constraints and solve it with efficient algorithms.
Python
3
star