  • Stars: 265
  • Rank: 153,802 (Top 4%)
  • Language: Python
  • License: MIT License
  • Created about 4 years ago
  • Updated over 1 year ago

Repository Details

NeurIPS 2020, Debiased Contrastive Learning

Debiased Contrastive Learning

A prominent technique for self-supervised representation learning has been to contrast semantically similar and dissimilar pairs of samples. Without access to labels, dissimilar (negative) points are typically taken to be randomly sampled datapoints, implicitly accepting that these points may, in reality, actually have the same label. Perhaps unsurprisingly, we observe that sampling negative examples from truly different labels improves performance, in a synthetic setting where labels are available. Motivated by this observation, we develop a debiased contrastive objective that corrects for the sampling of same-label datapoints, even without knowledge of the true labels.

Debiased Contrastive Learning NeurIPS 2020 [paper]
Ching-Yao Chuang, Joshua Robinson, Lin Yen-Chen, Antonio Torralba, and Stefanie Jegelka

Prerequisites

  • Python 3.7
  • PyTorch 1.3.1
  • PIL
  • OpenCV

Contrastive Representation Learning

We can train the standard (biased) or the debiased (M=1) version of SimCLR on the STL10 dataset with main.py.

flags:

  • --debiased: use the debiased objective (True) or the standard objective (False)
  • --tau_plus: class prior probability, i.e. the assumed probability that a randomly sampled negative shares the anchor's label
  • --batch_size: batch size for SimCLR

For instance, run the following command to train a debiased encoder.

python main.py --tau_plus 0.1

*Due to the implementation of nn.DataParallel(), training with at most 2 GPUs gives the best result.
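
To make the effect of --debiased and --tau_plus concrete, here is a minimal sketch of the per-anchor loss for the M=1 case, written in plain Python rather than PyTorch. It is an illustration, not the code in main.py: the function name contrastive_loss, the temperature value 0.5, and the similarity values are assumptions chosen for the example. The debiased branch subtracts the expected contribution of same-label negatives from the negative term and clamps the result from below so the logarithm stays well defined.

```python
import math


def contrastive_loss(pos_sim, neg_sims, tau_plus=0.1, temperature=0.5, debiased=True):
    """Loss for one anchor (hypothetical helper, not taken from main.py).

    pos_sim:  cosine similarity between the anchor and its positive view
    neg_sims: cosine similarities between the anchor and N sampled negatives
    tau_plus: assumed class prior (probability a negative shares the label)
    """
    n = len(neg_sims)
    pos = math.exp(pos_sim / temperature)
    neg = [math.exp(s / temperature) for s in neg_sims]

    if debiased:
        # Remove the estimated share of false negatives, then rescale.
        ng = (sum(neg) - tau_plus * n * pos) / (1.0 - tau_plus)
        # Clamp at the theoretical minimum of the negative term.
        ng = max(ng, n * math.exp(-1.0 / temperature))
    else:
        # Standard (biased) objective: plain sum over negatives.
        ng = sum(neg)

    return -math.log(pos / (pos + ng))


# Illustrative similarities (assumed values, not measured results).
example_negatives = [0.1, -0.2, 0.3]
biased_loss = contrastive_loss(0.9, example_negatives, debiased=False)
debiased_loss = contrastive_loss(0.9, example_negatives, tau_plus=0.1)
print(biased_loss, debiased_loss)
```

With tau_plus set to 0.0 the debiased branch reduces to the standard objective, which matches the "Biased tau_plus = 0.0" naming used for the pretrained models.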

Linear evaluation

The model is evaluated by training a linear classifier after fixing the learned embedding.

path flags:

  • --model_path: specify the path to saved model

python linear.py --model_path results/model_400.pth
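
The idea behind linear evaluation can be sketched without PyTorch: the encoder is applied once and never updated, and only a linear softmax classifier is fit on the resulting features. Everything in the block below is a toy assumption for illustration (the encode function standing in for a frozen encoder, the eight 2-D points, the learning rate and epoch count); it is not what linear.py does internally.

```python
import math


def encode(x):
    # Hypothetical frozen encoder: a fixed feature map with no trainable parts.
    return [x[0], x[1], x[0] * x[1]]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scores(W, b, f):
    return [sum(w * v for w, v in zip(W[c], f)) + b[c] for c in range(len(W))]


def train_linear(features, labels, num_classes, epochs=100, lr=0.1):
    """Fit only the linear layer (W, b) with SGD on cross-entropy."""
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        for f, y in zip(features, labels):
            probs = softmax(scores(W, b, f))
            for c in range(num_classes):
                grad = probs[c] - (1.0 if c == y else 0.0)
                for j in range(dim):
                    W[c][j] -= lr * grad * f[j]
                b[c] -= lr * grad
    return W, b


def predict(W, b, f):
    s = scores(W, b, f)
    return s.index(max(s))


# Toy data: label 1 when both coordinates share a sign, else 0.
points = [(1, 1), (-1, -1), (1, -1), (-1, 1), (2, 2), (-2, -2), (2, -2), (-2, 2)]
labels = [1, 1, 0, 0, 1, 1, 0, 0]

# Features are computed once; the encoder is never touched during training.
features = [encode(p) for p in points]
W, b = train_linear(features, labels, num_classes=2)
accuracy = sum(predict(W, b, f) == y for f, y in zip(features, labels)) / len(labels)
print("train accuracy:", accuracy)
```

Because the classifier is linear, its accuracy reflects how linearly separable the fixed features are, which is why this protocol is used to compare encoders.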

Pretrained Models

Objective  tau_plus  Arch      Latent Dim  Batch Size  Accuracy (%)  Download
Biased     0.0       ResNet50  128         256         80.15         model
Debiased   0.05      ResNet50  128         256         81.85         model
Debiased   0.1       ResNet50  128         256         84.26         model

Citation

If you find this repo useful for your research, please consider citing the paper

@article{chuang2020debiased,
  title={Debiased contrastive learning},
  author={Chuang, Ching-Yao and Robinson, Joshua and Lin, Yen-Chen and Torralba, Antonio and Jegelka, Stefanie},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

For any questions, please contact Ching-Yao Chuang ([email protected]).

Acknowledgements

Part of this code is inspired by leftthomas/SimCLR.

More Repositories

  1. ggnn.pytorch (Python, 444 stars): A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)
  2. awesome-vqa (424 stars): Visual Q&A reading list
  3. pytorch-REINFORCE (Python, 256 stars): PyTorch Implementation of REINFORCE for both discrete & continuous control
  4. VQA-tensorflow (Python, 99 stars): Tensorflow Implementation of Deeper LSTM+ normalized CNN for Visual Question Answering
  5. RINCE (Python, 79 stars): CVPR 2022, Robust Contrastive Learning against Noisy Views
  6. Tensorboard2Seaborn (Python, 66 stars): Plot Tensorflow Summary Event in a Beautiful Way 🌈
  7. VQG-tensorflow (Python, 66 stars): Visual Question Generation in Tensorflow
  8. fair-mixup (Python, 55 stars): ICLR 2021, Fair Mixup: Fairness via Interpolation
  9. photo-editing-tensorflow (Python, 54 stars): Photo Optimizing Adversarial Net with Policy Gradient Method
  10. InfoOT (Python, 40 stars): [ICML2023] InfoOT: Information Maximizing Optimal Transport
  11. TMD (Python, 32 stars): NeurIPS 2022: Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks
  12. kV-Margin (Python, 26 stars): NeurIPS 2021, Code for Measuring Generalization with Optimal Transport
  13. debias_vl (Python, 25 stars): Code for Debiasing Vision-Language Models via Biased Prompts
  14. san-torch (Lua, 25 stars): Torch implementation for Stacked Attention Networks
  15. estimating-generalization (Python, 21 stars): ICML 2020, Estimating Generalization under Distribution Shifts via Domain-Invariant Representations
  16. VisionWorks (MATLAB, 9 stars): Basic Computer Vision Problem & Work
  17. mdm (Python, 2 stars): Code for "The Role of Embedding Complexity in Domain-invariant Representations"
  18. FTP-Proxy-Rate-Control (C, 2 stars): Simple FTP-proxy with rate control.