Robust fine-tuning of zero-shot models

This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

TLDR: We fine-tune zero-shot models while preserving or improving OOD accuracy at no extra computational cost during fine-tuning or inference.

Abstract

Large pre-trained models such as CLIP or ALIGN offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they often reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models (WiSE-FT). Compared to standard fine-tuning, WiSE-FT provides large accuracy improvements out-of-distribution, while preserving high in-distribution accuracy. On ImageNet (in-distribution) and five derived distribution shifts, WiSE-FT improves out-of-distribution accuracy by 4 to 6 percentage points (pp) over prior work while increasing in-distribution accuracy by 1.6 pp. WiSE-FT achieves similarly large robustness improvements (2 to 23 pp) on a diverse set of six further distribution shifts, and in-distribution accuracy gains of 0.8 to 3.3 pp compared to standard fine-tuning on seven commonly used transfer learning datasets. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

[Figure 1]

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
# (alpha = 0 recovers the zero-shot weights, alpha = 1 the fine-tuned weights)
theta = {
    key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)
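
The interpolation step above can also be tried in isolation, without the repository's ImageClassifier class. The following is a minimal sketch, not the code in src/wise_ft.py: the function name interpolate_state_dicts and the toy weight dictionaries are hypothetical, and plain floats stand in for weight tensors. It assumes only that the values support scalar multiplication and addition, which also holds for tensors.

```python
def interpolate_state_dicts(theta_0, theta_1, alpha):
    """Return (1 - alpha) * theta_0 + alpha * theta_1, key by key.

    Hypothetical helper for illustration. theta_0 and theta_1 are
    dict-like mappings from parameter names to values that support
    scalar multiplication and addition (floats here).
    """
    # Restricting alpha to [0, 1] is a choice made for this sketch.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    # Both checkpoints must describe the same set of parameters.
    if set(theta_0.keys()) != set(theta_1.keys()):
        raise KeyError("checkpoints have different parameter names")
    return {
        key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
        for key in theta_0.keys()
    }


# Toy "checkpoints" with made-up values, standing in for state dicts.
zeroshot_weights = {"layer.weight": 2.0, "layer.bias": 0.0}
finetuned_weights = {"layer.weight": 4.0, "layer.bias": 1.0}

mixed = interpolate_state_dicts(zeroshot_weights, finetuned_weights, alpha=0.5)
print(mixed)
```

With alpha=0.5 every parameter is the plain average of the two checkpoints; alpha=0 and alpha=1 return the zero-shot and fine-tuned values respectively.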

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zero-shot and fine-tuned models are available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned or all weights are fine-tuned (end-to-end).
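
The --alpha values in the commands above sweep the mixing coefficient from 0 to 1 in steps of 0.1. As a rough illustration of what such a sweep does, the sketch below interpolates a toy pair of weights at each alpha and scores the result. Everything here is hypothetical: sweep_alphas and toy_score are made-up names, the one-parameter "model" is a stand-in, and the scoring function is not the repository's evaluation.

```python
def sweep_alphas(theta_0, theta_1, alphas, score_fn):
    """Interpolate at each alpha and return a list of (alpha, score)."""
    results = []
    for alpha in alphas:
        theta = {
            key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
            for key in theta_0
        }
        results.append((alpha, score_fn(theta)))
    return results


# Same grid as the sample commands: 0, 0.1, ..., 1.0
alphas = [i / 10 for i in range(11)]

# Toy one-parameter "checkpoints" (made-up values).
theta_0 = {"w": 0.0}  # stands in for the zero-shot weights
theta_1 = {"w": 1.0}  # stands in for the fine-tuned weights


def toy_score(theta):
    # Hypothetical score that peaks at w = 0.6; a real sweep would
    # evaluate accuracy on the chosen datasets instead.
    return -(theta["w"] - 0.6) ** 2


results = sweep_alphas(theta_0, theta_1, alphas, toy_score)
best_alpha, best_score = max(results, key=lambda pair: pair[1])
print(best_alpha)
```

In this toy setup the best score falls at an intermediate alpha rather than at either endpoint, which is the kind of behavior the sweep is meant to reveal.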

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

[Plots: ImageNet-Sketch, ImageNet-A, ImageNet-R, ImageNetV2, ObjectNet]

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}
