BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning

PyTorch implementation of "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" by J.B. Grill et al.

Link to paper

This repository includes a practical implementation of BYOL with:

  • Distributed Data Parallel training
  • Benchmarks on vision datasets (CIFAR-10 / STL-10)
  • Support for PyTorch <= 1.5.0

Open BYOL in Google Colab Notebook

Results

These are the top-1 accuracies of linear classifiers trained on the (frozen) representations learned by BYOL:

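| Method              | Batch size | Image size | ResNet   | Projection output dim. | Pre-training epochs | Optimizer | STL-10 | CIFAR-10 |
|---------------------|------------|------------|----------|------------------------|---------------------|-----------|--------|----------|
| BYOL + linear eval. | 192        | 224x224    | ResNet18 | 256                    | 100                 | Adam      | -      | 0.832    |
| Logistic Regression | -          | -          | -        | -                      | -                   | -         | 0.358  | 0.389    |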
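
For reference, top-1 accuracy is the fraction of samples whose highest-scoring class matches the true label. The sketch below shows that computation in plain Python; the function name `top1_accuracy` and the toy scores are illustrative assumptions, not code or outputs from this repository.

```python
from typing import Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(scores, labels):
        # argmax over the class scores of one sample
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


# Toy example with 3 classes and 4 samples (made-up numbers, not BYOL outputs).
toy_scores = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 0.2, 1.5],   # predicts class 2
    [0.0, 0.9, 0.1],   # predicts class 1
    [1.2, 1.1, 0.0],   # predicts class 0
]
toy_labels = [0, 2, 1, 1]
toy_accuracy = top1_accuracy(toy_scores, toy_labels)  # 3 of 4 correct
```

A reported value such as 0.832 therefore means that 83.2% of the test images were assigned their correct class.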

Installation

git clone https://github.com/spijkervet/byol --recurse-submodules -j8
cd byol
pip3 install -r requirements.txt
python3 main.py

Usage

Using a pre-trained model

The following commands will train a logistic regression model on a pre-trained ResNet18, yielding a top-1 accuracy of 83.2% on CIFAR-10.

curl https://github.com/Spijkervet/BYOL/releases/download/1.0/resnet18-CIFAR10-final.pt -L -O
rm -f features.p
python3 logistic_regression.py --model_path resnet18-CIFAR10-final.pt

Pre-training

To run pre-training using BYOL with the default arguments (1 node, 1 GPU), use:

python3 main.py

This is equivalent to:

python3 main.py --nodes 1 --gpus 1

The pre-trained models are saved every n epochs (set with --checkpoint_epochs, default 10) in *.pt files; the final model is saved as model-final.pt.
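
To make the checkpoint interval concrete, here is a minimal sketch of which epochs would produce a checkpoint. It assumes epochs are counted from 1 and that a checkpoint is written whenever the epoch number is a multiple of the interval; the actual rule in main.py may differ, and the helper name `checkpoint_epochs` is hypothetical.

```python
from typing import List


def checkpoint_epochs(num_epochs: int, every: int) -> List[int]:
    """Epochs at which a checkpoint would be written (assumed rule: epoch % every == 0)."""
    if num_epochs < 1 or every < 1:
        raise ValueError("num_epochs and every must be positive")
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % every == 0]


# Defaults listed under Arguments: --num_epochs 100, --checkpoint_epochs 10
schedule = checkpoint_epochs(num_epochs=100, every=10)
```

With the defaults, this assumed rule yields ten checkpoints, at epochs 10, 20, and so on up to 100.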

Finetuning

Finetuning a model ('linear evaluation') on top of the pre-trained, frozen ResNet model can be done using:

python3 logistic_regression.py --model_path=./model-final.pt

Here, model-final.pt is the file containing the pre-trained network from the pre-training stage.

Multi-GPU / Multi-node training

Use python3 main.py --gpus 2 to train on, for example, 2 GPUs on a single node, and python3 main.py --gpus 2 --nodes 2 to train on 2 nodes with 2 GPUs each. See https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html for an excellent explanation of distributed training in PyTorch.
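
The relationship between --nodes, --gpus and --nr can be sketched as follows. This follows the convention used in the linked tutorial (one process per GPU, world size = nodes x GPUs per node, global rank = node rank x GPUs per node + local GPU index); whether main.py computes these values in exactly this way is an assumption, and the function names are illustrative.

```python
def world_size(nodes: int, gpus: int) -> int:
    """Total number of training processes, assuming one process per GPU."""
    return nodes * gpus


def global_rank(nr: int, gpus: int, local_gpu: int) -> int:
    """Global rank of the process driving `local_gpu` on node number `nr`."""
    if not 0 <= local_gpu < gpus:
        raise ValueError("local_gpu must be in [0, gpus)")
    return nr * gpus + local_gpu


# Example: --gpus 2 --nodes 2, with --nr 0 on the first node and --nr 1 on the second.
total = world_size(nodes=2, gpus=2)
ranks = [global_rank(nr, 2, gpu) for nr in range(2) for gpu in range(2)]
```

Under this convention each node is started with its own --nr value, so every process ends up with a unique global rank.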

Arguments

--image_size, default=224, "Image size"
--learning_rate, default=3e-4, "Initial learning rate."
--batch_size, default=42, "Batch size for training."
--num_epochs, default=100, "Number of epochs to train for."
--checkpoint_epochs, default=10, "Number of epochs between checkpoints/summaries."
--dataset_dir, default="./datasets", "Directory where dataset is stored."
--num_workers, default=8, "Number of data loading workers (caution with nodes!)"
--nodes, default=1, "Number of nodes"
--gpus, default=1, "Number of GPUs per node"
--nr, default=0, "Ranking within the nodes"
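
As an illustration, the options above could be declared with Python's standard argparse module as shown below. The flag names, defaults and help strings come from the list above; the `type` of each option and the function name `build_parser` are assumptions, and the repository may define its arguments differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="BYOL pre-training options (illustrative)")
    parser.add_argument("--image_size", default=224, type=int, help="Image size")
    parser.add_argument("--learning_rate", default=3e-4, type=float, help="Initial learning rate.")
    parser.add_argument("--batch_size", default=42, type=int, help="Batch size for training.")
    parser.add_argument("--num_epochs", default=100, type=int, help="Number of epochs to train for.")
    parser.add_argument("--checkpoint_epochs", default=10, type=int,
                        help="Number of epochs between checkpoints/summaries.")
    parser.add_argument("--dataset_dir", default="./datasets", type=str,
                        help="Directory where dataset is stored.")
    parser.add_argument("--num_workers", default=8, type=int,
                        help="Number of data loading workers (caution with nodes!)")
    parser.add_argument("--nodes", default=1, type=int, help="Number of nodes")
    parser.add_argument("--gpus", default=1, type=int, help="Number of GPUs per node")
    parser.add_argument("--nr", default=0, type=int, help="Ranking within the nodes")
    return parser


# Parsing an empty argument list returns the documented defaults.
default_args = build_parser().parse_args([])
# Overrides are passed exactly as on the command line.
multi_args = build_parser().parse_args(["--gpus", "2", "--nodes", "2"])
```

Passing explicit lists to `parse_args` keeps the example independent of the real command line, which is convenient for quick experiments in a notebook.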
