iBOT 🤖: Image BERT Pre-Training with Online Tokenizer (ICLR 2022)

Image BERT Pre-Training with iBOT

Official PyTorch implementation and pre-trained models for paper iBOT: Image BERT Pre-Training with Online Tokenizer.

iBOT framework

iBOT is a novel self-supervised pre-training framework that performs masked image modeling with self-distillation. Models pre-trained with iBOT exhibit local semantic features, which help them transfer well to downstream tasks at both a global and a local scale. For example, with a vanilla ViT-B/16, iBOT achieves strong performance on COCO object detection (51.2 box AP and 44.2 mask AP) and ADE20K semantic segmentation (50.0 mIoU). iBOT can also extract semantically meaningful local parts, such as a dog's ear 🐶.
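As a rough illustration of "masked image modeling with self-distillation", the sketch below computes a simplified iBOT-style patch loss. The teacher's softmax output for each patch of the unmasked view serves as the target for the student's prediction at the same position in the masked view. The loss is averaged over masked patches only. The centering vector, the student temperature, and the toy logits are assumptions for illustration. The full objective also has a [CLS]-token term, multi-crop views, and an EMA teacher, and all of these are omitted here.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temp: float) -> List[float]:
    """Temperature-scaled log-softmax, computed stably via the log-sum-exp trick."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_norm for x in scaled]


def softmax(logits: Sequence[float], temp: float) -> List[float]:
    """Temperature-scaled softmax."""
    return [math.exp(v) for v in log_softmax(logits, temp)]


def masked_patch_loss(
    student_logits: List[List[float]],  # student head output per patch, computed on the masked view
    teacher_logits: List[List[float]],  # teacher head output per patch, computed on the unmasked view
    mask: List[bool],                   # True where the student's input patch was masked
    center: List[float],                # running mean of teacher outputs (assumed given here)
    student_temp: float = 0.1,          # assumed value
    teacher_temp: float = 0.07,         # matches --teacher_temp 0.07 in the commands
) -> float:
    """Cross-entropy from teacher targets to student predictions, averaged over masked patches."""
    losses = []
    for s, t, is_masked in zip(student_logits, teacher_logits, mask):
        if not is_masked:
            continue  # only masked patches contribute to the MIM term
        # Teacher target: subtract the center, then sharpen with a low temperature.
        target = softmax([ti - ci for ti, ci in zip(t, center)], teacher_temp)
        log_pred = log_softmax(s, student_temp)
        losses.append(-sum(q * lp for q, lp in zip(target, log_pred)))
    return sum(losses) / len(losses) if losses else 0.0


student = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
teacher = [[1.5, 0.0, -0.5], [0.0, 0.0, 0.0]]
print(masked_patch_loss(student, teacher, mask=[True, False], center=[0.0, 0.0, 0.0]))
```

Because unmasked patches are skipped, changing the teacher output for an unmasked patch leaves the loss unchanged.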

News 🎉

  • January 2022 - The paper is accepted by ICLR 2022.
  • Update - ViT-L/16 with ImageNet-1K pre-training achieves 81.0% linear probing accuracy. ViT-L/16 with ImageNet-22K pre-training achieves 87.8% fine-tuning accuracy at 512 resolution.
  • Update - Random masking with a relatively larger prediction ratio performs slightly better than block-wise masking. For example, ViT-B/16 achieves 84.1% fine-tuning accuracy and 51.5 box AP in object detection.
  • December 2021 - Released the code and pre-trained models.
  • November 2021 - Released the pre-print on arXiv.

Installation

See the installation instructions for details.

One-Line Command by Using run.sh

We provide run.sh, which completes the full pre-training + fine-tuning experiment cycle with a one-line command of the form ./run.sh TYPE JOB_NAME ARCH KEY GPUS [additional arguments].

Arguments

  • TYPE follows the naming rule dataset_task. For example, pre-training on ImageNet-1K has the TYPE imagenet_pretrain, and linear probing evaluation on ImageNet-1K has the TYPE imagenet_linear. Several task types can be chained in one command by joining them with +.
  • JOB_NAME is a custom job name that distinguishes different groups of experiments.
  • ARCH is the architecture of the pre-trained model.
  • KEY selects which pre-trained network is evaluated: teacher (generally better) or student. Both can be given, separated by a comma (student,teacher).
  • GPUS is the total number of GPUs requested. The number used on each node is clamped by MAX_GPUS (default 8).
  • Additional arguments can be appended directly after the required ones, for example --lr 0.001.
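To make these conventions concrete, here is a small sketch that does three things:

- It splits a TYPE string into (dataset, task) pairs.
- It splits a comma-separated KEY.
- It derives a per-node GPU count and a node count.

The helper name parse_run_args is hypothetical. Splitting at the first underscore is an assumption. The node arithmetic (total GPUs divided by the per-node cap) is our reading of the examples in this README (16 GPUs on 2 nodes, 40 GPUs on 5 nodes), not code taken from run.sh.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple

MAX_GPUS = 8  # default per-node cap stated in the argument list


@dataclass
class RunPlan:
    tasks: List[Tuple[str, str]]  # (dataset, task) pairs in the order given
    keys: List[str]               # networks to evaluate
    gpus_per_node: int
    nodes: int


def parse_run_args(type_: str, key: str, gpus: int, max_gpus: int = MAX_GPUS) -> RunPlan:
    """Hypothetical helper mirroring the TYPE / KEY / GPUS conventions (not run.sh itself)."""
    if gpus < 1:
        raise ValueError("GPUS must be positive")
    tasks = []
    for part in type_.split("+"):            # task types are chained with '+'
        dataset, task = part.split("_", 1)   # dataset_task, split at the first underscore
        tasks.append((dataset, task))
    keys = key.split(",")                    # e.g. "student,teacher"
    unknown = [k for k in keys if k not in ("student", "teacher")]
    if unknown:
        raise ValueError(f"unknown KEY value(s): {unknown}")
    gpus_per_node = min(gpus, max_gpus)      # clamp by MAX_GPUS
    nodes = math.ceil(gpus / gpus_per_node)  # assumed: remaining GPUs go to extra nodes
    return RunPlan(tasks, keys, gpus_per_node, nodes)


plan = parse_run_args("imagenet_pretrain+imagenet_knn+imagenet_linear", "student,teacher", 16)
print(plan)
```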

For example, the following commands run pre-training distributed across 2 nodes and then automatically evaluate both the student and teacher models on the k-NN and linear probing benchmarks:

TOTAL_NODES=2 NODE_ID=0 ./run.sh imagenet_pretrain+imagenet_knn+imagenet_linear $JOB_NAME vit_small student,teacher 16  # the first node
TOTAL_NODES=2 NODE_ID=1 ./run.sh imagenet_pretrain+imagenet_knn+imagenet_linear $JOB_NAME vit_small student,teacher 16  # the second node

Training

To see the full documentation of the iBOT pre-training options, run:

python main_ibot.py --help

iBOT Pre-Training with ViTs

To start iBOT pre-training with a Vision Transformer (ViT), run the following command. JOB_NAME is a custom argument that distinguishes experiments; checkpoints are automatically saved into a separate folder for each job.

./run.sh imagenet_pretrain $JOB_NAME vit_{small,base,large} teacher {16,24,64}

The exact arguments to reproduce the models presented in our paper can be found in the args column of the pre-trained model tables. We also provide the pre-training logs to aid reproducibility.

For example, the following command runs iBOT with a ViT-S/16 network on two nodes with 8 GPUs each for 800 epochs. The resulting checkpoint should reach 75.2% k-NN accuracy, 77.9% linear probing accuracy, and 82.3% fine-tuning accuracy.

./run.sh imagenet_pretrain $JOB_NAME vit_small teacher 16 \
  --teacher_temp 0.07 \
  --warmup_teacher_temp_epochs 30 \
  --norm_last_layer false \
  --epochs 800 \
  --batch_size_per_gpu 64 \
  --shared_head true \
  --out_dim 8192 \
  --local_crops_number 10 \
  --global_crops_scale 0.25 1 \
  --local_crops_scale 0.05 0.25 \
  --pred_ratio 0 0.3 \
  --pred_ratio_var 0 0.2
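The --pred_ratio and --pred_ratio_var flags control how many patch tokens are masked for prediction. The sketch below shows one plausible reading of them, which is our assumption and not a description of the repository's data loader:

- Each listed ratio is jittered uniformly by its variance.
- One candidate ratio is picked at random per image.
- That fraction of patches is masked at random positions (random masking, as opposed to block-wise masking).

The function names are hypothetical. The epoch check mimics the --pred_start_epoch flag used in the Swin example.

```python
import random
from typing import List, Sequence


def sample_pred_ratio(pred_ratio: Sequence[float], pred_ratio_var: Sequence[float],
                      rng: random.Random, epoch: int = 0, pred_start_epoch: int = 0) -> float:
    """Pick one masking ratio for an image (assumed semantics, not the repo's loader)."""
    if epoch < pred_start_epoch:
        return 0.0  # assumed: no masked prediction before --pred_start_epoch
    candidates = []
    for mean, var in zip(pred_ratio, pred_ratio_var):
        if var > mean:
            raise ValueError("each pred_ratio_var must not exceed its pred_ratio")
        # Jitter each listed ratio uniformly within [mean - var, mean + var].
        candidates.append(rng.uniform(mean - var, mean + var) if var > 0 else mean)
    return rng.choice(candidates)  # one candidate per image, chosen uniformly


def random_patch_mask(num_patches: int, ratio: float, rng: random.Random) -> List[bool]:
    """Random masking: mask round(ratio * num_patches) patches at uniformly random positions."""
    num_masked = round(ratio * num_patches)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


rng = random.Random(0)
# --pred_ratio 0 0.3 --pred_ratio_var 0 0.2, on an assumed 14 x 14 patch grid (224 / 16).
ratio = sample_pred_ratio([0.0, 0.3], [0.0, 0.2], rng)
mask = random_patch_mask(14 * 14, ratio, rng)
print(f"ratio={ratio:.3f}, masked={sum(mask)} of {len(mask)}")
```

Under this reading, roughly half of the images would be left unmasked (ratio 0), and the rest would have between 10% and 50% of their patches masked.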

iBOT Pre-Training with Swins

The code also supports iBOT pre-training with Swin Transformers (Swin). In the paper, we only conduct experiments on Swin-T with different window sizes:

./run.sh imagenet_pretrain $JOB_NAME swin_tiny teacher {16,40} \
  --patch_size 4 \
  --window_size {7,14}
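The arithmetic below gives some intuition for what --patch_size 4 and --window_size {7,14} change: it counts tokens and attention windows in each Swin stage. It rests on three assumptions that are not read from this repository:

- The input is 224×224.
- The model uses the standard four-stage Swin-T layout, in which patch merging halves the token grid between stages.
- As in the Swin reference implementation, a window larger than a stage's grid shrinks to the grid size.

```python
from typing import List, Tuple


def swin_stage_windows(image_size: int = 224, patch_size: int = 4, window_size: int = 7,
                       num_stages: int = 4) -> List[Tuple[int, int, int]]:
    """Return (grid_side, effective_window, windows_per_image) for each stage (assumed layout)."""
    stages = []
    side = image_size // patch_size  # e.g. 224 / 4 = 56 tokens per side in stage 1
    for _ in range(num_stages):
        win = min(window_size, side)  # a window larger than the grid is shrunk to it
        if side % win:
            # This sketch simply rejects grids that do not divide evenly.
            raise ValueError(f"grid {side} is not divisible by window {win}")
        stages.append((side, win, (side // win) ** 2))
        side //= 2  # patch merging halves the grid between stages
    return stages


for w in (7, 14):
    print(f"window_size={w}:", swin_stage_windows(window_size=w))
```

With window size 14, each attention window covers four times as many tokens as with window size 7. The third stage then already runs global attention over its whole 14×14 grid.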

For example, the following command runs iBOT with a Swin-T/14 network on five nodes with 8 GPUs each for 300 epochs. The resulting checkpoint should reach 76.2% k-NN accuracy and 79.3% linear probing accuracy.

./run.sh imagenet_pretrain $JOB_NAME swin_tiny teacher 40 \
  --teacher_temp 0.07 \
  --warmup_teacher_temp_epochs 30 \
  --norm_last_layer false \
  --epochs 300 \
  --batch_size_per_gpu 26 \
  --shared_head true \
  --out_dim 8192 \
  --local_crops_number 10 \
  --global_crops_scale 0.25 1 \
  --local_crops_scale 0.05 0.25 \
  --pred_ratio 0 0.3 \
  --pred_ratio_var 0 0.2 \
  --pred_start_epoch 50 \
  --patch_size 4 \
  --window_size 14 
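The two reference commands use different --batch_size_per_gpu values, yet their global batch sizes are nearly equal. Assuming plain data-parallel training without gradient accumulation, the effective batch size is nodes × GPUs per node × per-GPU batch size. The hypothetical helper below checks the arithmetic for the ViT-S/16 and Swin-T/14 examples.

```python
def effective_batch_size(nodes: int, gpus_per_node: int, batch_size_per_gpu: int) -> int:
    """Images per optimization step under plain data parallelism (no gradient accumulation assumed)."""
    return nodes * gpus_per_node * batch_size_per_gpu


vit_s = effective_batch_size(nodes=2, gpus_per_node=8, batch_size_per_gpu=64)   # ViT-S/16 example
swin_t = effective_batch_size(nodes=5, gpus_per_node=8, batch_size_per_gpu=26)  # Swin-T/14 example
print(vit_s, swin_t)  # 1024 1040
```

These figures count images, not crops. With multi-crop augmentation, each image is also cut into several views (for example, --local_crops_number 10).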

Pre-Trained Models

You can download either only the pre-trained backbone weights used for downstream tasks, or the full checkpoint (full ckpt), which contains the backbone and projection head weights of both the student and teacher networks. For the backbone, (s) denotes that the student network is selected and (t) that the teacher network is selected. PS denotes the prediction shape: Block for block-wise masking and Rand for random masking. k-NN, Lin., and Fin. are k-NN, linear probing, and fine-tuning accuracy; "-" means not reported.

Arch.      Params  PS     k-NN   Lin.   Fin.   Download
ViT-S/16   21M     Block  75.2%  77.9%  82.3%  backbone (t) | full ckpt | args | logs
Swin-T/7   28M     Block  75.3%  78.6%  -      backbone (t) | full ckpt | args | logs
Swin-T/14  28M     Block  76.2%  79.3%  -      backbone (t) | full ckpt | args | logs
ViT-B/16   85M     Block  77.1%  79.5%  84.0%  backbone (t) | full ckpt | args | logs
ViT-B/16   85M     Rand   77.3%  79.8%  84.1%  backbone (t) | full ckpt | args | logs
ViT-L/16   307M    Block  78.0%  81.0%  84.8%  backbone (t) | full ckpt | args | logs
ViT-L/16   307M    Rand   77.7%  81.3%  85.0%  backbone (t) | full ckpt | args | logs

We also provide ViT-{B,L}/16 models pre-trained on the ImageNet-22K dataset.

Arch.     Params  PS     k-NN   Lin.   Fin. (256)  Fin. (384)  Fin. (512)  Download
ViT-B/16  85M     Block  71.1%  79.0%  84.4%       -           -           backbone (s) | full ckpt | args | logs
ViT-L/16  307M    Block  72.9%  82.3%  86.6%       87.5%       87.8%       backbone (s) | full ckpt | args | logs

To extract the backbone from a full checkpoint yourself, run the following command, where KEY is either student or teacher:

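WEIGHT_FILE=$OUTPUT_DIR/checkpoint_$KEY.pth

# PRETRAINED: the full checkpoint to read; WEIGHT_FILE: where the extracted backbone weights are saved
python extract_backbone_weights.py \
  --checkpoint_key $KEY \
  $PRETRAINED \
  $WEIGHT_FILE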
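Conceptually, extracting the backbone means two steps: select one network's state dict from the full checkpoint, then drop the projection-head entries. The sketch below works on a plain dictionary so it runs without PyTorch. The function name is hypothetical. The assumed key layout follows common DINO-style conventions and is not confirmed against extract_backbone_weights.py:

- "module." comes from DistributedDataParallel.
- "backbone." and "head." come from a wrapper module.

With real files, you would load the dictionary with torch.load and write the result with torch.save.

```python
from typing import Any, Dict


def extract_backbone(checkpoint: Dict[str, Dict[str, Any]], key: str) -> Dict[str, Any]:
    """Return backbone-only weights for `key` ("student" or "teacher"), with prefixes stripped."""
    if key not in ("student", "teacher"):
        raise ValueError("key must be 'student' or 'teacher'")
    backbone = {}
    for name, value in checkpoint[key].items():
        if name.startswith("module."):  # prefix added by DistributedDataParallel (assumed)
            name = name[len("module."):]
        if not name.startswith("backbone."):  # skip projection head weights (assumed "head." prefix)
            continue
        backbone[name[len("backbone."):]] = value
    return backbone


# Toy checkpoint with placeholder values instead of tensors.
ckpt = {
    "student": {"module.backbone.cls_token": 1, "module.head.mlp.0.weight": 2},
    "teacher": {"backbone.cls_token": 3, "head.mlp.0.weight": 4},
}
print(extract_backbone(ckpt, "teacher"))
```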

Downstream Evaluation

See Evaluating iBOT on Downstream Tasks for details.

Property Analysis

See Analyzing iBOT's Properties for the robustness test, visualization of self-attention maps, and extraction of sparse correspondence pairs between two images.

We also provide a Colab notebook 📑 where you can play around with the iBOT pre-trained models.

Extracting Semantic Patterns

The following command extracts the top-k local classes based on patch tokens, together with their corresponding patches and contexts. We identify very diverse behaviour, such as shared low-level textures and high-level semantics.

python3 -m torch.distributed.launch --nproc_per_node=8 \
    --master_port=${MASTER_PORT:-29500} \
    analysis/extract_pattern/extract_topk_cluster.py \
    --pretrained_path $PRETRAINED \
    --checkpoint {student,teacher} \
    --type patch \
    --topk 36 \
    --patch_window 5 \
    --show_pics 20 \
    --arch vit_small \
    --save_path memory_bank_patch.pth \
    --data_path data/imagenet/val
iBOT Local Part-Level Pattern Layout
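One simplified way to picture this extraction:

1. Each patch token is assigned to the output dimension (local class) with the highest probability.
2. For every class, the most confident patches are kept for display.

The sketch below implements that grouping on toy data. It is our reading of the description above. The function name and its inputs are hypothetical, and the script's exact use of --topk may differ.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

PatchId = Tuple[int, int]  # (image index, patch index), a hypothetical identifier


def topk_patches_per_class(
    probs: Dict[PatchId, Sequence[float]],  # per-patch probabilities over the head's output classes
    topk: int,
) -> Dict[int, List[Tuple[float, PatchId]]]:
    """Group patches by argmax class and keep the `topk` most confident patches of each class."""
    groups: Dict[int, List[Tuple[float, PatchId]]] = defaultdict(list)
    for pid, p in probs.items():
        cls = max(range(len(p)), key=p.__getitem__)  # argmax over classes
        groups[cls].append((p[cls], pid))
    # heapq.nlargest returns each class's members sorted by confidence, highest first.
    return {c: heapq.nlargest(topk, members) for c, members in groups.items()}


toy = {
    (0, 0): [0.7, 0.2, 0.1],
    (0, 1): [0.1, 0.8, 0.1],
    (1, 0): [0.6, 0.3, 0.1],
    (1, 1): [0.9, 0.05, 0.05],
}
print(topk_patches_per_class(toy, topk=2))
```

Applying the same grouping to [CLS] tokens instead of patch tokens gives the image-level clustering view.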

The script also supports extracting the pattern layout of the [CLS] token, which effectively performs clustering, i.e. unsupervised classification. This property is not induced by the MIM objective, since we also observe it in DINO.

python3 -m torch.distributed.launch --nproc_per_node=8 \
    --master_port=${MASTER_PORT:-29500} \
    analysis/extract_pattern/extract_topk_cluster.py \
    --pretrained_path $PRETRAINED \
    --checkpoint {student,teacher} \
    --type cls \
    --topk 36 \
    --show_pics 20 \
    --arch vit_small \
    --save_path memory_bank_cls.pth \
    --data_path data/imagenet/val
iBOT Global Pattern Layout

Acknowledgement

This repository is built using the DINO repository and the BEiT repository.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citing iBOT

If you find this repository useful, please consider giving a star and citation:

@article{zhou2021ibot,
  title={iBOT: Image BERT Pre-Training with Online Tokenizer},
  author={Zhou, Jinghao and Wei, Chen and Wang, Huiyu and Shen, Wei and Xie, Cihang and Yuille, Alan and Kong, Tao},
  journal={International Conference on Learning Representations (ICLR)},
  year={2022}
}
