CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances
Official PyTorch implementation of "CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances" (NeurIPS 2020) by Jihoon Tack*, Sangwoo Mo*, Jongheon Jeong, and Jinwoo Shin.
1. Requirements
Environments
The following packages are currently required:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- torchlars == 0.1.2
- pytorch-gradual-warmup-lr
- apex == 0.1
- diffdist == 0.1
Datasets
For CIFAR, please download the following datasets to ~/data.
For ImageNet-30, please download the following datasets to ~/data:
- ImageNet-30-train, ImageNet-30-test
- CUB-200, Stanford Dogs, Oxford Pets, Oxford flowers, Food-101, Places-365, Caltech-256, DTD
For Food-101, remove the hotdog class to avoid overlap with ImageNet-30.
2. Training
All code examples currently assume a distributed launch on 4 GPUs.
To run the code on a single GPU, remove -m torch.distributed.launch --nproc_per_node=4 from the command.
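The relationship between the multi-GPU and single-GPU commands can be sketched as a small helper. This is only an illustration: build_train_command is a hypothetical name, not part of the repository, and it merely assembles flags that already appear in this README (the CUDA_VISIBLE_DEVICES prefix is left out).

```python
def build_train_command(n_gpus, train_args):
    """Assemble a train.py command line (hypothetical helper, not in the repo).

    With a single GPU the distributed launcher is dropped, as described above.
    """
    cmd = ["python"]
    if n_gpus > 1:
        # Multi-GPU: go through the torch.distributed.launch module.
        cmd += ["-m", "torch.distributed.launch", f"--nproc_per_node={n_gpus}"]
    cmd += ["train.py"] + list(train_args)
    return cmd


args = ["--mode", "simclr_CSI", "--shift_trans_type", "rotation", "--batch_size", "32"]
multi_gpu_cmd = build_train_command(4, args)
single_gpu_cmd = build_train_command(1, args)
print(" ".join(multi_gpu_cmd))
print(" ".join(single_gpu_cmd))
```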
Unlabeled one-class & multi-class
To train the unlabeled one-class and multi-class models from the paper, run this command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index>
The --one_class_idx option selects the in-distribution class for one-class training. For multi-class training, set --one_class_idx to None. To run SimCLR, simply change --mode to simclr. The total batch size should be 512 = 4 (GPUs) * 32 (--batch_size option) * 4 (cardinality of the shifted transformation set).
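The batch-size arithmetic above can be checked with a few lines of Python. The function name is ours and is used for illustration only; the single-GPU case is a hypothetical configuration, not one tested in this repository.

```python
def total_batch_size(n_gpus, per_gpu_batch, n_shifts):
    """Effective batch size: GPUs x per-GPU batch x shifted copies per image."""
    return n_gpus * per_gpu_batch * n_shifts


# 4 GPUs, --batch_size 32, 4 shifted transformations
effective = total_batch_size(n_gpus=4, per_gpu_batch=32, n_shifts=4)
print(effective)  # 512

# Hypothetical single-GPU run: per-GPU batch needed to keep 512 in total
single_gpu_batch = 512 // (1 * 4)
print(single_gpu_batch)  # 128
```

Whether a per-GPU batch of 128 fits in memory depends on the hardware; this is arithmetic only.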
Labeled multi-class
To train the labeled multi-class model (confidence-calibrated classifier) from the paper, run these commands:
# Representation training
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700
# Linear layer training
python train.py --mode sup_CSI_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH>
To run SupCLR, simply change --mode to sup_simclr for representation training and to sup_linear for linear layer training. The total batch size should be the same as above. Currently, only rotation is supported as the shifted transformation.
3. Evaluation
We provide checkpoints of the CSI pre-trained models. Download the checkpoints from the following links:
- One-class CIFAR-10: ResNet-18
- Unlabeled (multi-class) CIFAR-10: ResNet-18
- Unlabeled (multi-class) ImageNet-30: ResNet-18
- Labeled (multi-class) CIFAR-10: ResNet-18
Unlabeled one-class & multi-class
To evaluate a model in the unlabeled one-class and multi-class out-of-distribution (OOD) detection settings, run this command:
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH>
The --one_class_idx option selects the in-distribution class for one-class evaluation. For multi-class evaluation, set --one_class_idx to None. The --resize_factor and --resize_fix options fix the cropping size of RandomResizedCrop(). For SimCLR evaluation, change --ood_score to simclr.
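To make the effect of a fixed cropping size concrete, the sketch below computes the crop that a RandomResizedCrop-style transform would take when its area fraction is pinned to a single value. It assumes that --resize_fix pins the crop's area fraction to --resize_factor instead of sampling it from a range, and that the input is a 32x32 image; both are assumptions made for illustration. crop_size is a hypothetical helper, not a function from this repository or from torchvision.

```python
import math


def crop_size(height, width, area_fraction, aspect_ratio=1.0):
    """Width and height (in pixels) of a crop covering `area_fraction` of the image.

    Follows the usual RandomResizedCrop geometry: pick a target area,
    then split it into width and height according to the aspect ratio.
    """
    target_area = area_fraction * height * width
    crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
    crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
    return crop_w, crop_h


# Assumed setting: 32x32 input, area fraction fixed at 0.54, square crop
square_crop = crop_size(32, 32, 0.54)
print(square_crop)  # (24, 24)
```

Under this assumption, the crop area no longer varies between evaluation samples; only the aspect ratio and the crop position can still change.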
Labeled multi-class
To evaluate a model on labeled multi-class accuracy, ECE, and OOD detection, run these commands:
# OOD AUROC
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
# Accuracy & ECE
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
These commands use marginalized inference. For single inference (also used for SupCLR), change --ood_score to baseline in the first command and --mode to test_acc in the second command.
4. Results
Our model achieves the following performance:
One-Class Out-of-Distribution Detection
Method | Dataset | AUROC (Mean) |
---|---|---|
SimCLR | CIFAR-10-OC | 87.9% |
Rot+Trans | CIFAR-10-OC | 90.0% |
CSI (ours) | CIFAR-10-OC | 94.3% |
We only show the CIFAR-10 one-class results in this repo. For other settings, please see our paper.
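For reference, AUROC can be read as the probability that a randomly chosen in-distribution sample receives a higher score than a randomly chosen OOD sample. The sketch below computes it directly from that definition on made-up scores. It is not the repository's evaluation code, and the "Mean" in the table is assumed here to be the average of such values over the ten one-class settings.

```python
def auroc(in_scores, ood_scores):
    """AUROC from pairwise comparisons; ties count as half a win."""
    wins = 0.0
    for s_in in in_scores:
        for s_out in ood_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


# Made-up scores for illustration (higher = more in-distribution)
in_scores = [0.9, 0.8, 0.7, 0.4]
ood_scores = [0.6, 0.3, 0.2]
value = auroc(in_scores, ood_scores)
print(round(value, 4))  # 0.9167 (11 of 12 pairs ranked correctly)
```

The pairwise loop is quadratic in the number of samples, which is fine for a sketch; on full test sets a rank-based routine is the practical choice.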
Unlabeled Multi-Class Out-of-Distribution Detection
Method | Dataset | OOD Dataset | AUROC (Mean) |
---|---|---|---|
Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% |
CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% |
We only show the CIFAR-10 to CIFAR-100 OOD detection results in this repo. For results on other OOD datasets, see our paper.
Labeled Multi-Class Result
Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) |
---|---|---|---|---|---|
SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% |
CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% |
CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% |
We only show results for CIFAR-10 with CIFAR-100 as the OOD dataset in this repo. For results on other datasets, please see our paper.
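The ECE column reports expected calibration error. A minimal sketch of the standard binned definition follows: predictions are grouped by confidence, and the gaps between per-bin accuracy and per-bin mean confidence are averaged with bin-size weights. The bin count and the toy data are assumptions; the repository's implementation may bin differently.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Binned ECE: sum over bins of (bin fraction) * |accuracy - mean confidence|."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Bins are [k/n_bins, (k+1)/n_bins); conf == 1.0 is clamped into the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, hit))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        ece += (len(members) / n) * abs(accuracy - avg_conf)
    return ece


# Toy example: four predictions with confidence and correctness (1 = correct)
confidences = [0.95, 0.95, 0.55, 0.55]
correct = [1, 1, 1, 0]
ece = expected_calibration_error(confidences, correct)
print(round(ece, 3))  # 0.05
```

In this toy case both occupied bins are off by 0.05 (accuracy 1.0 vs. confidence 0.95, and accuracy 0.5 vs. confidence 0.55), so the weighted average is 0.05, i.e. 5% in the units used by the table.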
5. New OOD dataset
We find that current benchmark datasets for OOD detection are visually far from in-distribution datasets (e.g., CIFAR).
To address this issue, we provide new datasets for OOD detection evaluation: LSUN_fix and ImageNet_fix. See the above figure for a visualization of the current benchmarks and our datasets.
To generate the OOD datasets, run the following code inside the ./datasets folder:
# ImageNet FIX generation code
python imagenet_fix_preprocess.py
# LSUN FIX generation code
python lsun_fix_preprocess.py
Citation
@inproceedings{tack2020csi,
title={CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances},
author={Jihoon Tack and Sangwoo Mo and Jongheon Jeong and Jinwoo Shin},
booktitle={Advances in Neural Information Processing Systems},
year={2020}
}