• Stars: 110
• Rank: 316,770 (top 7%)
• Language: Python
• License: MIT License
• Created: over 2 years ago
• Updated: 9 months ago


Repository Details

[ICML 2022] Contrastive Learning with Boosted Memorization

Keywords: Long-Tailed Recognition, Self-Supervised Learning, Memorization Effect

@inproceedings{zhou2022contrastive,
  title={Contrastive Learning with Boosted Memorization},
  author={Zhou, Zhihan and Yao, Jiangchao and Wang, Yan-Feng and Han, Bo and Zhang, Ya},
  booktitle={International Conference on Machine Learning},
  pages={27367--27377},
  year={2022},
  organization={PMLR}
}

Abstract: Self-supervised learning has achieved great success in the representation learning of visual and textual data. However, current methods are mainly validated on well-curated datasets, which do not exhibit the real-world long-tailed distribution. Recent attempts at self-supervised long-tailed learning rebalance from the loss perspective or the model perspective, resembling the paradigms in supervised long-tailed learning. Nevertheless, without the aid of labels, these explorations have not shown the expected promise, due to limitations in tail-sample discovery or heuristic structure design. Different from previous works, we explore this direction from an alternative perspective, i.e., the data perspective, and propose a novel Boosted Contrastive Learning (BCL) method. Specifically, BCL leverages the memorization effect of deep neural networks to automatically drive the information discrepancy between the sample views in contrastive learning, which more efficiently enhances long-tailed learning in the label-unaware context. Extensive experiments on a range of benchmark datasets demonstrate the effectiveness of BCL over several state-of-the-art methods.

Get Started

Environment

  • Python (3.7.10)
  • PyTorch (1.7.1)
  • torchvision (0.8.2)
  • CUDA
  • NumPy
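
A matching environment can be created roughly as follows (a sketch; the exact torch wheel depends on your platform and CUDA version):

conda create -n bcl python=3.7
conda activate bcl
pip install torch==1.7.1 torchvision==0.8.2 numpy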

File Structure

After the preparation work, the whole project should have the following structure:

./Boosted-Contrastive-Learning
├── README.md
├── data                            # datasets and augmentations
│   ├── memoboosted_cifar100.py
│   ├── cifar100.py
│   ├── augmentations.py
│   └── randaug.py
├── models                          # models and backbones
│   ├── simclr.py
│   ├── sdclr.py
│   ├── resnet.py
│   ├── resnet_prune_multibn.py
│   └── utils.py
├── losses                          # losses
│   └── nt_xent.py
├── split                           # data splits
│   ├── cifar100
│   └── cifar100_imbSub_with_subsets
├── eval_cifar.py                   # linear probing evaluation code
├── test.py                         # testing code
├── train.py                        # training code
├── train_sdclr.py                  # training code for SDCLR
└── utils.py                        # utils

Quick Preview

A code snippet of BCL is shown below.

# memoboosted CIFAR-100-LT: augmentation strength is modulated per sample
train_datasets = memoboosted_CIFAR100(train_idx_list, args, root=args.data_folder, train=True)

# initialize momentum loss trackers, one slot per training sample
shadow = torch.zeros(dataset_total_num).cuda()
momentum_loss = torch.zeros(args.epochs, dataset_total_num).cuda()

# each epoch returns updated per-sample momentum losses ...
shadow, momentum_loss = train(train_loader, model, optimizer, scheduler, epoch, log, shadow, momentum_loss, args=args)
# ... which re-weight each sample's augmentation strength for the next epoch
train_datasets.update_momentum_weight(momentum_loss, epoch)

During training, the momentum loss is tracked as an exponential moving average of each sample's loss:

if epoch > 1:
    # exponential moving average of the per-sample loss
    new_average = (1.0 - args.momentum_loss_beta) * loss[batch_idx].clone().detach() + args.momentum_loss_beta * shadow[index[batch_idx]]
else:
    # first epoch: initialize the average with the raw loss
    new_average = loss[batch_idx].clone().detach()

shadow[index[batch_idx]] = new_average
momentum_loss[epoch - 1, index[batch_idx]] = new_average
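
The dataset then maps these momentum losses to per-sample augmentation strengths. A minimal sketch of what update_momentum_weight could look like (the min-max normalization here is an illustrative assumption, not the repository's exact code):

def update_momentum_weight(self, momentum_loss, epoch):
    # normalize the current epoch's momentum losses to [0, 1]; samples with
    # larger momentum loss (typically tail samples, memorized later) receive
    # stronger augmentation when views are generated
    loss_epoch = momentum_loss[epoch - 1].cpu().numpy()
    span = loss_epoch.max() - loss_epoch.min()
    self.momentum_weight = (loss_epoch - loss_epoch.min()) / (span + 1e-12)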

Training

To train a model on CIFAR-100-LT, simply run:

  • SimCLR
python train.py SimCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • BCL-I
python train.py BCL_I --bcl --rand_k 1 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • SDCLR
python train_sdclr.py SDCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 
  • BCL-D
python train_sdclr.py BCL_D --bcl --rand_k 2 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy 

Pretrained checkpoints will be saved in 'checkpoints/'.

Evaluating

To evaluate a pretrained model, simply run:

  • SimCLR, BCL-I
python test.py --checkpoint ${checkpoint_pretrain} --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}
  • SDCLR, BCL-D
python test.py --checkpoint ${checkpoint_pretrain} --prune --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}

The code will output the results of full-shot/100-shot/50-shot linear probing evaluation.
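
Conceptually, few-shot linear probing amounts to fitting a linear classifier on frozen encoder features; a minimal sketch (the function and its defaults are illustrative, not the repository's eval_cifar.py API):

import torch
import torch.nn as nn
import torch.nn.functional as F

def linear_probe(encoder, train_loader, num_classes=100, epochs=30, lr=30.0):
    # freeze the pretrained encoder; only the linear head is trained
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad = False
    head = nn.Linear(512, num_classes).cuda()  # 512 = ResNet-18 feature dim
    optimizer = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            with torch.no_grad():
                feats = encoder(x)  # frozen representation
            loss = F.cross_entropy(head(feats), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head

For the 100-shot/50-shot settings, train_loader is restricted to 100 or 50 labeled samples per class.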

Results and Pretrained Models

We provide demo full-shot/100-shot/50-shot results for models pretrained on 'cifar100_split1_D_i.npy', with the corresponding checkpoint weights. All numbers are linear probing accuracy (%).

Method   Full-shot   100-shot   50-shot   Model
SimCLR   50.7        46.3       42.4      ResNet-18
SDCLR    55.0        49.7       45.6      ResNet-18
BCL-I    55.7        50.1       45.8      ResNet-18
BCL-D    58.7        52.6       48.7      ResNet-18

After downloading the checkpoints, you can run the evaluation following the instructions in the Evaluating section.

Extensions

Steps to Implement Your Own Model

  • Add your model to ./models and load it in train.py.
  • Implement the functions (./losses) specific to your model and call them in train.py; a standalone sketch of the NT-Xent contrastive loss follows this list.
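
For reference, a minimal self-contained NT-Xent loss in the SimCLR style (the repository ships its own version in losses/nt_xent.py; this sketch is illustrative):

import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.2):
    # z1, z2: projections of the two augmented views, each of shape (N, d)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, d)
    sim = z @ z.t() / temperature                        # pairwise similarities
    n = z1.size(0)
    # mask out self-similarity on the diagonal
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool, device=z.device), float('-inf'))
    # the positive for view i is view i+N, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)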

Steps to Implement Other Datasets

  • Create long-tailed splits of the datasets and add them to ./split.
  • Implement the dataset (e.g. memoboosted_cifar100.py); a skeleton of the memo-boosted pattern is sketched below.
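
A skeletal memo-boosted dataset might look like the following (all names and the augmentation interface are illustrative assumptions; see data/memoboosted_cifar100.py for the real implementation):

import numpy as np
from torch.utils.data import Dataset

class MemoBoostedDataset(Dataset):
    # Hypothetical skeleton: each sample carries its own augmentation
    # strength, refreshed once per epoch from the tracked momentum loss.
    def __init__(self, images, strength_aug):
        self.images = images
        self.strength_aug = strength_aug              # e.g. RandAugment with tunable magnitude
        self.momentum_weight = np.zeros(len(images))  # per-sample strength in [0, 1]

    def update_momentum_weight(self, momentum_loss, epoch):
        # normalize as in the sketch in the Quick Preview section
        loss_epoch = momentum_loss[epoch - 1].cpu().numpy()
        self.momentum_weight = (loss_epoch - loss_epoch.min()) / (loss_epoch.ptp() + 1e-12)

    def __getitem__(self, idx):
        img, w = self.images[idx], self.momentum_weight[idx]
        # two views whose augmentation magnitude scales with the momentum weight
        return self.strength_aug(img, w), self.strength_aug(img, w), idx

    def __len__(self):
        return len(self.images)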

Acknowledgement

We borrow some code from SDCLR, RandAugment, and W-MSE.

More Repositories

1. MING — MING (明医): a Chinese medical consultation large language model (Python, 812 stars)
2. RegAD — [ECCV2022 Oral] Registration based Few-Shot Anomaly Detection (Python, 268 stars)
3. FACT (Python, 155 stars)
4. Where2comm (Python, 147 stars)
5. MedKLIP — Official code for MedKLIP: Medical Knowledge Enhanced Language-Image Pre-Training in Radiology, which leverages medical-specific knowledge to enhance language-image pre-training, significantly advancing zero-shot classification and grounding on unseen diseases (Python, 134 stars)
6. LED — [CVPR2023] Leapfrog Diffusion Model for Stochastic Trajectory Prediction (Jupyter Notebook, 130 stars)
7. MVFA-AD — [CVPR2024 Highlight] Adapting Visual-Language Models for Generalizable Anomaly Detection in Medical Images (Python, 129 stars)
8. MemoNet — [CVPR2022] Remember Intentions: Retrospective-Memory-based Trajectory Prediction (Python, 118 stars)
9. EqMotion — [CVPR2023] EqMotion: Equivariant Multi-agent Motion Prediction with Invariant Interaction Reasoning (Python, 112 stars)
10. GroupNet — [CVPR22] GroupNet: Multiscale Hypergraph Neural Networks for Trajectory Prediction with Relational Reasoning (Python, 108 stars)
11. TBP-Former (78 stars)
12. CoCa3D (Python, 75 stars)
13. GenMedicalEval (69 stars)
14. CoBEVFlow — [NeurIPS 2023] Asynchrony-Robust Collaborative Perception via Bird's Eye View Flow (Python, 65 stars)
15. FedDisco (Python, 60 stars)
16. RECORDS-LTPLL — [ICLR 2023] PyTorch implementation of "Long-Tailed Partial Label Learning via Dynamic Rebalancing" (Python, 55 stars)
17. ECGAD — [MICCAI2023 Early Accept] Multi-scale Cross-restoration Framework for Electrocardiogram Anomaly Detection (Python, 48 stars)
18. FedDG-GA — [CVPR 2023] Federated Domain Generalization with Generalization Adjustment (Python, 37 stars)
19. SyncNet — [ECCV2022] Latency-Aware Collaborative Perception (Python, 33 stars)
20. SPGSN — [ECCV 2022] Source code for "Skeleton-Parted Graph Scattering Networks for 3D Human Motion Prediction" (Python, 29 stars)
21. AuxFormer — [ICCV2023] Auxiliary Tasks Benefit 3D Skeleton-based Human Motion Prediction (Python, 25 stars)
22. pFedGraph (Python, 23 stars)
23. BE-SSL — Code for "Boundary-Enhanced Self-Supervised Learning for Brain Structure Segmentation" (Python, 23 stars)
24. JRTransformer — [ICCV2023] Joint-Relation Transformer for Multi-Person Motion Prediction (Python, 22 stars)
25. Geometric-Harmonization — [NeurIPS 2023 Spotlight] Combating Representation Learning Disparity with Geometric Harmonization (Python, 19 stars)
26. Collaborative-Uncertainty (Python, 19 stars)
27. GPFL-GRACE — [MICCAI 2023] GRACE: Enhancing Federated Learning for Medical Imaging with Generalized and Personalized Gradient Correction (Python, 15 stars)
28. LoRKD (Python, 12 stars)
29. K-Diag (Python, 10 stars)
30. MoLA (Python, 10 stars)
31. CoFormer (Python, 10 stars)
32. FedGELA — [NeurIPS 2023] Federated Learning with Bilateral Curation for Partially Class-Disjoint Data (Python, 10 stars)
33. FedLESAM — [ICML 2024 Spotlight] Implementation of "Locally Estimated Global Perturbations are Better than Local Perturbations for Federated Sharpness-aware Minimization" (Python, 9 stars)
34. FedSkip — Official code for "FedSkip: Combating Statistical Heterogeneity with Federated Skip Aggregation" (Python, 7 stars)
35. OC_LT — [ICLR 2024] Official code base for "Long-Tailed Diffusion Models With Oriented Calibration" (Python, 6 stars)
36. DISAM — [ICLR 2024] Implementation of "Domain-Inspired Sharpness-Aware Minimization Under Domain Shifts" (Python, 6 stars)
37. CaT — [ICCV2021] CaT: Weakly Supervised Object Detection with Category Transfer (5 stars)
38. ECISQA — [NeurIPS 2023] Emergent Communication in Interactive Sketch Question Answering (Jupyter Notebook, 5 stars)
39. FreeAlign (Python, 5 stars)
40. FedMR — [TMLR 2023] Federated Learning under Partially Class-Disjoint Data via Manifold Reshaping (Python, 4 stars)
41. GSC (Python, 4 stars)
42. SSM — [TMM 2022] Self-Supervised Masking for Unsupervised Anomaly Detection and Localization (Python, 4 stars)
43. ITES (Python, 1 star)
44. NMMP (Python, 1 star)
45. mediabrain-sjtu.github.io (TeX, 1 star)