[ICLR 2023, Oral] SimPer: Simple Self-Supervised Learning of Periodic Targets

This repository contains the implementation code for paper:
SimPer: Simple Self-Supervised Learning of Periodic Targets
Yuzhe Yang, Xin Liu, Jiang Wu, Silviu Borac, Dina Katabi, Ming-Zher Poh, Daniel McDuff
11th International Conference on Learning Representations (ICLR 2023), Notable-Top-5% & Oral
[Project Page] [Paper] [Video] [Blog Post]

If you find this code or idea useful, please consider citing our work:

@inproceedings{yang2023simper,
  title={SimPer: Simple Self-Supervised Learning of Periodic Targets},
  author={Yang, Yuzhe and Liu, Xin and Wu, Jiang and Borac, Silviu and Katabi, Dina and Poh, Ming-Zher and McDuff, Daniel},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=EKpMeEV0hOo}
}


SimPer learns robust periodic representations with high frequency resolution.

Updates

  • [07/2023] SimPer is featured on the Google AI Blog.
  • [07/2023] We provide a hands-on tutorial of SimPer. Check it out!
  • [06/2023] Check out the Oral talk video (15 mins) for our paper.
  • [02/2023] Paper accepted to ICLR 2023 as Notable-Top-5% & Oral Presentation.
  • [10/2022] arXiv version posted. The code is currently being cleaned up. Please stay tuned for updates.

Periodic SSL: Brief Introduction for SimPer

From human physiology to environmental evolution, important processes in nature often exhibit meaningful and strong periodic or quasi-periodic changes. Due to their inherent label scarcity, learning useful representations for periodic tasks with limited or no supervision is of great benefit. Yet, existing self-supervised learning (SSL) methods overlook the intrinsic periodicity in data, and fail to learn representations that capture periodic or frequency attributes.

We present SimPer, a simple contrastive SSL regime for learning periodic information in data. To exploit the periodic inductive bias, SimPer introduces customized periodicity-invariant and periodicity-variant augmentations, periodic feature similarity measures, and a generalized contrastive loss for learning efficient and robust periodic representations.

We benchmark SimPer on common real-world tasks in human behavior analysis, environmental sensing, and healthcare domains. Further analysis also highlights its intriguing properties including better data efficiency, robustness to spurious correlations, and generalization to distribution shifts.

Apply SimPer on Customized Datasets

To apply SimPer to a customized dataset, you will need to define the following key components. (See the SimPer tutorial for a worked example on the RotatingDigits dataset.)

#1: Periodicity-Variant and Invariant Augmentations (see src/augmentation.py)

For (periodicity-)invariant augmentations, one could refer to SOTA contrastive methods (e.g., SimCLR). For periodicity-variant augmentations, we propose speed / frequency augmentation:

import tensorflow as tf
import tensorflow_probability as tfp

def arbitrary_speed_subsample(frames, speed, max_frame_len, img_size, channels, **kwargs):
    ...

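    # Sample positions 0, speed, 2 * speed, ... (one position per input frame).
    x_ref = tf.range(0, speed * (len(frames) - 0.5), speed, dtype=tf.float32)
    # Repeat the positions for every pixel and channel, so that each one is
    # interpolated over time independently.
    x_ref = tf.stack([x_ref] * (img_size * img_size * channels))
    # Linearly interpolate every per-pixel time series at the new positions.
    new_frames = tfp.math.batch_interp_regular_1d_grid(
        x=x_ref,
        x_ref_min=[0] * (img_size * img_size * channels),
        x_ref_max=[len(frames)] * (img_size * img_size * channels),
        y_ref=tf.transpose(tf.reshape(frames, [len(frames), -1]))
    )
    # Restore the original frame layout and truncate to max_frame_len frames.
    sequence = tf.reshape(
        tf.transpose(new_frames), frames.shape.as_list()
    )[:tf.cast(max_frame_len, tf.int32)]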

    ...
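
To make the effect of the speed augmentation concrete, here is a minimal NumPy sketch on a 1-D signal rather than on video frames. It is not the repository's implementation: `change_speed_1d` and `dominant_frequency` are hypothetical helper names, the 30 Hz sampling rate is an assumption, and positions that fall past the end of the signal are simply dropped.

```python
import numpy as np

def change_speed_1d(signal, speed):
    """Resample a 1-D signal at positions 0, speed, 2*speed, ... (linear interpolation)."""
    signal = np.asarray(signal, dtype=np.float64)
    n = len(signal)
    positions = np.arange(n) * speed
    # Keep only the positions that fall inside the original signal.
    positions = positions[positions <= n - 1]
    return np.interp(positions, np.arange(n), signal)

def dominant_frequency(signal, fs):
    """Frequency (in Hz) of the largest peak in the mean-removed magnitude spectrum."""
    spectrum = np.abs(np.fft.rfft(signal - np.mean(signal)))
    freqs = np.fft.rfftfreq(len(signal), d=1.0 / fs)
    return freqs[np.argmax(spectrum)]

fs = 30.0                        # assumed sampling rate (frames per second)
t = np.arange(600) / fs          # 20 seconds of samples
x = np.sin(2 * np.pi * 1.0 * t)  # 1 Hz periodic signal

x_fast = change_speed_1d(x, speed=2.0)
print(dominant_frequency(x, fs), dominant_frequency(x_fast, fs))
```

Playing the signal back at twice the speed should move its dominant frequency from about 1 Hz to about 2 Hz, which is why this augmentation changes the periodic target (periodicity-variant) instead of preserving it.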

#2: Periodic Feature Similarity (see src/simper.py)

We provide practical instantiations to capture the periodic feature similarity, e.g., maximum cross-correlation:

import tensorflow as tf

@tf.function
def _max_cross_corr(feats_1, feats_2):
    # Use a common dtype and remove the mean of each feature sequence.
    feats_2 = tf.cast(feats_2, feats_1.dtype)
    feats_1 = feats_1 - tf.math.reduce_mean(feats_1, axis=-1, keepdims=True)
    feats_2 = feats_2 - tf.math.reduce_mean(feats_2, axis=-1, keepdims=True)

    # Zero-pad to twice the longer length so that the FFT-based (circular)
    # correlation does not wrap around. Inputs are expected to be rank 2.
    min_N = min(feats_1.shape[-1], feats_2.shape[-1])
    padded_N = max(feats_1.shape[-1], feats_2.shape[-1]) * 2
    feats_1_pad = tf.pad(feats_1, tf.constant([[0, 0], [0, padded_N - feats_1.shape[-1]]]))
    feats_2_pad = tf.pad(feats_2, tf.constant([[0, 0], [0, padded_N - feats_2.shape[-1]]]))

    # Cross-correlation in the frequency domain, normalized by the product of
    # the standard deviations (a zero product is replaced by 1).
    X = tf.signal.rfft(feats_1_pad) * tf.math.conj(tf.signal.rfft(feats_2_pad))
    power_norm = tf.cast(tf.math.reduce_std(feats_1, axis=-1, keepdims=True) *
                         tf.math.reduce_std(feats_2, axis=-1, keepdims=True), X.dtype)
    power_norm = tf.where(tf.equal(power_norm, 0), tf.ones_like(power_norm), power_norm)
    X = X / power_norm

    # Back to the time domain; the similarity is the peak over all lags.
    cc = tf.signal.irfft(X) / (min_N - 1)
    max_cc = tf.math.reduce_max(cc, axis=-1)
    return max_cc
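
The same idea can be checked quickly in NumPy on single 1-D sequences. The sketch below is a simplified stand-in for the TensorFlow function above, not a replacement for it: `max_cross_corr` is an illustrative name, and the test signals and the 30 Hz sampling rate are assumptions.

```python
import numpy as np

def max_cross_corr(a, b):
    """Peak of the normalized cross-correlation between two 1-D sequences."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a = a - a.mean()
    b = b - b.mean()
    # Zero-pad to twice the longer length so the FFT-based circular
    # correlation does not wrap around.
    n_fft = 2 * max(len(a), len(b))
    spec = np.fft.rfft(a, n=n_fft) * np.conj(np.fft.rfft(b, n=n_fft))
    cc = np.fft.irfft(spec, n=n_fft)
    norm = a.std() * b.std() * (min(len(a), len(b)) - 1)
    if norm == 0:
        return 0.0
    return float(cc.max() / norm)

t = np.arange(300) / 30.0                        # 10 seconds at an assumed 30 Hz
ref = np.sin(2 * np.pi * 1.0 * t)
shifted = np.sin(2 * np.pi * 1.0 * (t + 0.2))    # same frequency, shifted in time
other = np.sin(2 * np.pi * 1.7 * t)              # different frequency

print(max_cross_corr(ref, shifted), max_cross_corr(ref, other))
```

Because the maximum is taken over all lags, two sequences with the same frequency but different phase should still score close to 1, while sequences with different frequencies should score much lower.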

#3: Generalized InfoNCE Loss over Continuous Targets (see src/simper.py)

First, define a label distance for continuous targets:

import tensorflow as tf

def label_distance(labels_1, labels_2, dist_fn='l1', label_temperature=0.1):
    if dist_fn == 'l1':
        dist_mat = - tf.math.abs(labels_1[:, :, None] - labels_2[:, None, :])
    elif dist_fn == 'l2':
        ...

    return tf.nn.softmax(dist_mat / label_temperature, axis=-1)
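
As a small illustration of what these soft targets look like, the following NumPy sketch applies the same L1-plus-softmax recipe to the views of a single sample. The function name `soft_label_targets` and the speed values are hypothetical; the batched TensorFlow version above remains the reference.

```python
import numpy as np

def soft_label_targets(labels_1, labels_2, temperature=0.1):
    """Row-wise softmax over negative L1 label distances."""
    labels_1 = np.asarray(labels_1, dtype=np.float64)
    labels_2 = np.asarray(labels_2, dtype=np.float64)
    logits = -np.abs(labels_1[:, None] - labels_2[None, :]) / temperature
    logits -= logits.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)

speeds = np.array([0.5, 1.0, 1.5, 2.0])   # hypothetical speed labels of four views
targets = soft_label_targets(speeds, speeds)
print(np.round(targets, 3))
```

Each row is a probability distribution over the other views: views with closer labels receive more weight, and a lower temperature makes the distribution sharper.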

Then calculate a weighted loss over all augmented pairs (soft regression variant):

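import tensorflow as tf

criterion = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
loss = 0.0
for features, labels in zip(all_features, all_labels):
    feat_dist = ...   # pairwise periodic feature similarity, used as logits
    label_dist = ...  # soft targets, e.g., from label_distance above
    loss += criterion(y_pred=feat_dist, y_true=label_dist)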
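
The cross-entropy step can be sketched without TensorFlow to show how the soft targets shape the loss. This is only an illustration under assumed inputs: `soft_cross_entropy` is a hypothetical helper, and the similarity logits are made-up values rather than outputs of a trained model.

```python
import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def soft_cross_entropy(sim_logits, soft_targets):
    """Mean row-wise cross-entropy between softmax(sim_logits) and soft_targets."""
    return float(-(soft_targets * log_softmax(sim_logits)).sum(axis=-1).mean())

# Soft targets over three views (rows sum to 1), as produced by a label distance.
soft_targets = np.array([[0.9, 0.1, 0.0],
                         [0.1, 0.8, 0.1],
                         [0.0, 0.1, 0.9]])

# Made-up feature-similarity logits: one set agrees with the targets,
# the other ranks the views in the reverse order.
aligned = 5.0 * np.eye(3)
misaligned = aligned[::-1]

print(soft_cross_entropy(aligned, soft_targets),
      soft_cross_entropy(misaligned, soft_targets))
```

The loss should be lower when feature similarity ranks the views the same way the label distance does, which is the behavior the weighted loss rewards.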

Contact

If you have any questions, feel free to contact us by email or through GitHub issues. Enjoy!
