  • Stars: 218
  • Rank: 181,805 (Top 4 %)
  • Language: Jupyter Notebook
  • License: MIT License
  • Created about 4 years ago
  • Updated almost 2 years ago

Repository Details

Wrapper for a PyTorch classifier which allows it to output prediction sets. The sets are theoretically guaranteed to contain the true class with high probability (via conformal prediction).

Paper

Uncertainty Sets for Image Classifiers using Conformal Prediction

@article{angelopoulos2020sets,
  title={Uncertainty Sets for Image Classifiers using Conformal Prediction},
  author={Angelopoulos, Anastasios N and Bates, Stephen and Malik, Jitendra and Jordan, Michael I},
  journal={arXiv preprint arXiv:2009.14193},
  year={2020}
}

Basic Overview

This codebase modifies any PyTorch classifier to output a predictive set which provably contains the true class with a probability you specify. It uses a method called Regularized Adaptive Prediction Sets (RAPS), which we introduce in our accompanying paper. The procedure is as simple and fast as Platt scaling, but provides a formal finite-sample coverage guarantee for every model and dataset, provided the calibration and test data are exchangeable (for example, i.i.d.).
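To make the procedure concrete, here is a minimal NumPy sketch of a non-randomized, RAPS-style score with split-conformal calibration. It is our simplified reading of the method, not the code in conformal.py: it assumes the model's softmax outputs are already available as an (n, K) array, and the function names (raps_scores, conformal_threshold, predict_sets) are hypothetical. A label's score is the probability mass of all classes ranked at or above it plus a penalty lam * max(0, rank - k_reg) that discourages large sets; the threshold is the ⌈(n+1)(1-alpha)⌉-th smallest calibration score.

```python
import numpy as np


def raps_scores(probs, labels, lam, k_reg):
    """Non-randomized RAPS-style score of each (x, y) pair (sketch).

    probs: (n, K) softmax outputs; labels: (n,) integer class indices.
    score = mass of classes ranked at or above y + lam * max(0, rank(y) - k_reg),
    where rank is 1-indexed.
    """
    order = np.argsort(-probs, axis=1)                    # classes, most likely first
    sorted_p = np.take_along_axis(probs, order, axis=1)
    cum = np.cumsum(sorted_p, axis=1)                     # mass at or above each rank
    rank0 = np.argmax(order == labels[:, None], axis=1)   # 0-indexed rank of true label
    penalty = lam * np.maximum(0, rank0 + 1 - k_reg)
    return cum[np.arange(len(labels)), rank0] + penalty


def conformal_threshold(scores, alpha):
    """The ceil((n + 1)(1 - alpha))-th smallest calibration score."""
    n = len(scores)
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        return np.inf                                     # too few points: return all classes
    return float(np.sort(scores)[k - 1])


def predict_sets(probs, tau, lam, k_reg):
    """Classes whose score (same formula as above) is at most tau."""
    order = np.argsort(-probs, axis=1)
    sorted_p = np.take_along_axis(probs, order, axis=1)
    ranks = np.arange(1, probs.shape[1] + 1)
    reg = np.cumsum(sorted_p, axis=1) + lam * np.maximum(0, ranks - k_reg)
    # reg never decreases with rank, so each set is a prefix of `order`.
    return [order[i, reg[i] <= tau].tolist() for i in range(len(probs))]
```

Calibrate with tau = conformal_threshold(raps_scores(p_cal, y_cal, lam, k_reg), alpha) and then call predict_sets(p_test, tau, lam, k_reg). Because the score never decreases with rank, each set is a prefix of the classes sorted by probability: a confident image crosses tau after one or two classes, while an ambiguous one needs many.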

Figure (set-valued classifier): Prediction set examples on ImageNet. We show three examples of the class fox squirrel, along with the 95% prediction sets generated by our method, to illustrate how set size changes with the difficulty of a test-time image.

Google Colab

We have written a Colab which allows you to explore RAPS and conformal classification. You don't have to install anything to run the Colab. The notebook will lead you through constructing predictive sets from a pretrained model. You can also visualize examples from ImageNet along with their corresponding RAPS sets and play with the regularization parameters.

You can access the Colab by clicking the shield below.

Open In Colab

Usage

If you'd like to use our code in your own projects or reproduce our experiments, we provide the tools below. Note that although our codebase isn't a package, it's easy to use it like one, and we do so in the Colab notebook above.

Clone the repository, install the dependencies from its root directory, and run our example by executing:

git clone https://github.com/aangelopoulos/conformal-classification
cd conformal-classification
conda env create -f environment.yml
conda activate conformal
python example.py 'path/to/imagenet/val/'

Look inside example.py for a minimal example that modifies a pretrained classifier to output 90% prediction sets.

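If you'd like to use our codebase on your own model, first place this at the top of your file:

import torch
from conformal import *
from utils import *

Then create a holdout set for conformal calibration using lines like:

from torch.utils.data import random_split
calib, val = random_split(mydataset, [num_calib, total - num_calib])

Finally, wrap the calibration split in a DataLoader and create the model:

calib_loader = torch.utils.data.DataLoader(calib, batch_size=128, shuffle=True)  # batch size is illustrative
model = ConformalModel(model, calib_loader, alpha=0.1, lamda_criterion='size')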

The ConformalModel object takes a boolean flag, randomized. When randomized=False, the sets are not randomized at test time. This leads to conservative coverage, but deterministic behavior.

The ConformalModel object takes a second boolean flag, allow_zero_sets. When allow_zero_sets=False, sets of size zero are disallowed at test time. This leads to conservative coverage, but no zero-size sets.
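The two flags only change how a calibrated threshold tau is turned into a set at test time. Below is a hedged sketch of that step for a single softmax vector, based on our reading of the behavior described above rather than copied from conformal.py; the function name raps_test_set and its arguments are hypothetical. The non-randomized set keeps every class whose regularized cumulative score is at most tau, plus the first class that crosses tau, which is what makes it conservative. The randomized set then drops that boundary class with the probability needed to bring coverage close to the target rather than above it (this assumes tau was calibrated with correspondingly randomized scores, a step not shown here). Only the randomized set can end up empty, so allow_zero_sets matters only when randomized=True.

```python
import numpy as np


def raps_test_set(probs, tau, lam, k_reg, randomized=True,
                  allow_zero_sets=False, rng=None):
    """Prediction set for one softmax vector (sketch, not conformal.py).

    probs: (K,) strictly positive probabilities; tau: calibrated threshold.
    Returns class indices, most likely first.
    """
    rng = np.random.default_rng() if rng is None else rng
    order = np.argsort(-probs)                       # classes, most likely first
    sorted_p = probs[order]
    ranks = np.arange(1, len(probs) + 1)
    penalty = lam * np.maximum(0, ranks - k_reg)     # lam * (rank - k_reg)^+
    full = np.cumsum(sorted_p) + penalty             # regularized cumulative score
    # Non-randomized: all classes with full <= tau, plus the first class that
    # crosses tau (capped at K). This is the conservative set.
    size = min(int((full <= tau).sum()) + 1, len(probs))
    if randomized:
        last = size - 1                              # 0-indexed boundary class
        # Keep the boundary class only if u <= v, i.e. if
        # (mass above it) + u * p + penalty <= tau for a uniform draw u.
        v = (tau - (full[last] - sorted_p[last])) / sorted_p[last]
        if rng.uniform() >= v:
            size -= 1
    if size == 0 and not allow_zero_sets:
        size = 1                                     # fall back to the top-1 class
    return order[:size].tolist()
```

For example, with probabilities [0.6, 0.3, 0.1], lam=0 and tau=0.75, the non-randomized set is always [0, 1], while the randomized set is [0, 1] or [0] with roughly equal frequency.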

See the discussion below for picking alpha, kreg, and lamda manually.

Reproducing Our Results

The output of example.py should be:

Begin Platt scaling.
Computing logits for model (only happens once).
100%|███████████████████████████████████████| 79/79 [02:24<00:00,  1.83s/it]
Optimal T=1.1976691484451294
Model calibrated and conformalized! Now evaluate over remaining data.
N: 40000 | Time: 1.686 (2.396) | Cvg@1: 0.766 (0.782) | Cvg@5: 0.969 (0.941) | Cvg@RAPS: 0.891 (0.914) | Size@RAPS: 2.953 (2.982)
Complete!

The values in parentheses are running averages. The preceding values are only for the most recent batch. The timing values will be different on your system, but the rest of the numbers should be exactly the same. The progress bar may print over many lines if your terminal window is small.
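The per-batch numbers and running averages in this log require only a little bookkeeping. The sketch below is an assumption about how such a log line could be computed, not the code in utils.py: Cvg@1 and Cvg@5 are taken as the fractions of images whose true class is among the top-1 and top-5 predictions, Cvg@RAPS as the fraction whose true class is in the prediction set, Size@RAPS as the mean set size, and each running average weights a batch by its number of images. The names RunningAverage and batch_metrics are hypothetical.

```python
import numpy as np


class RunningAverage:
    """Latest value plus a sample-weighted running average (hypothetical)."""

    def __init__(self):
        self.val, self.sum, self.count = 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val                  # value for the most recent batch
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


def batch_metrics(probs, labels, sets):
    """Cvg@1, Cvg@5, set coverage and mean set size for one batch.

    probs: (B, K) softmax outputs; labels: (B,) ints; sets: list of lists.
    """
    top5 = np.argsort(-probs, axis=1)[:, :5]
    cvg1 = float(np.mean(top5[:, 0] == labels))
    cvg5 = float(np.mean((top5 == labels[:, None]).any(axis=1)))
    cvg_set = float(np.mean([y in s for y, s in zip(labels, sets)]))
    size = float(np.mean([len(s) for s in sets]))
    return cvg1, cvg5, cvg_set, size
```

For each batch, call meter.update(value, n=len(labels)) on one meter per metric and format it as f"{meter.val:.3f} ({meter.avg:.3f})" to obtain the "latest (running average)" layout shown above.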

The expected outputs of the experiments are stored in experiments/outputs, and they are identical to the results reported in our paper. You can reproduce them by running the Python scripts in ./experiments/ after installing our dependencies. For Table 2, we used the matched-frequencies version of ImageNet-V2.

Picking alpha, kreg, and lamda

alpha is the maximum proportion of errors you are willing to tolerate. The target coverage is therefore 1-alpha. A smaller alpha will usually lead to larger sets, since the desired coverage is more stringent.
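The link between alpha and set size is visible in the calibration step of split conformal prediction, which thresholds at the ⌈(n+1)(1-alpha)⌉-th smallest of the n calibration scores. The short sketch below (the helper name conformal_rank is hypothetical) shows that a smaller alpha selects a higher order statistic, hence a larger threshold and larger sets, and that a very small calibration set cannot support a very small alpha at all. For continuous scores without ties, the resulting coverage lies between 1-alpha and 1-alpha+1/(n+1).

```python
import math


def conformal_rank(n, alpha):
    """1-indexed rank of the calibration score used as the threshold.

    Returns None when n is too small for this alpha, in which case a
    valid set must contain every class.
    """
    k = math.ceil((n + 1) * (1 - alpha))
    return k if k <= n else None


for alpha in (0.2, 0.1, 0.05):
    print(alpha, conformal_rank(1000, alpha))
```

With n = 1000 calibration images, the loop selects ranks 801, 901 and 951 for alpha = 0.2, 0.1 and 0.05. With only 10 calibration images, alpha = 0.05 would require rank 11, so conformal_rank returns None.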

We have included two procedures for picking kreg and lamda automatically. If you want sets with small size, set lamda_criterion='size'. If you want sets that approximate conditional coverage, set lamda_criterion='adaptiveness'.
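For the size criterion, one natural reading is to set kreg from how deep the true class typically sits in the model's ranking, then grid-search lamda for the smallest average set size on held-out data. The sketch below follows that outline on a tuning split of softmax outputs. It is not the repository's implementation: the helper names and the lamda grid are assumptions, and the adaptiveness criterion is not shown. Tuning on data disjoint from the final calibration set keeps the coverage guarantee intact.

```python
import numpy as np


def _sorted(probs):
    order = np.argsort(-probs, axis=1)                   # classes, most likely first
    return order, np.take_along_axis(probs, order, axis=1)


def pick_kreg(probs, labels, alpha):
    """Hypothetical: the (1 - alpha) empirical quantile of the true label's 1-indexed rank."""
    order, _ = _sorted(probs)
    ranks = np.sort(np.argmax(order == labels[:, None], axis=1) + 1)
    k = int(np.ceil((1 - alpha) * len(ranks)))
    return int(ranks[min(max(k, 1), len(ranks)) - 1])


def mean_set_size(cal_p, cal_y, val_p, lam, k_reg, alpha):
    """Calibrate a non-randomized threshold on cal_*, return mean set size on val_p."""
    order, sp = _sorted(cal_p)
    K = cal_p.shape[1]
    pen = lam * np.maximum(0, np.arange(1, K + 1) - k_reg)
    reg = np.cumsum(sp, axis=1) + pen                    # regularized score at each rank
    r = np.argmax(order == cal_y[:, None], axis=1)       # 0-indexed rank of true label
    scores = np.sort(reg[np.arange(len(cal_y)), r])
    n = len(scores)
    k = int(np.ceil((n + 1) * (1 - alpha)))
    tau = scores[k - 1] if k <= n else np.inf
    _, vp = _sorted(val_p)
    return float(((np.cumsum(vp, axis=1) + pen) <= tau).sum(axis=1).mean())


def pick_lamda_size(probs, labels, alpha, k_reg, grid=(0.001, 0.01, 0.1, 0.2, 0.5)):
    """Hypothetical grid search: lamda with the smallest mean set size."""
    half = len(labels) // 2                              # first half calibrates, second half scores
    sizes = [mean_set_size(probs[:half], labels[:half], probs[half:], lam, k_reg, alpha)
             for lam in grid]
    return grid[int(np.argmin(sizes))]
```

A larger lamda penalizes every class past rank kreg more heavily, so the search trades a little probability mass for shorter sets; the grid above is only an illustrative choice.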

License

MIT License

More Repositories

  • conformal-prediction (Jupyter Notebook, 707 stars): Lightweight, useful implementation of conformal prediction on real data.
  • ppi_py (Python, 193 stars): A package for statistically rigorous scientific discovery using machine learning. Implements prediction-powered inference.
  • rcps (Python, 84 stars): Official codebase for "Distribution-Free, Risk-Controlling Prediction Sets".
  • conformal-time-series (Jupyter Notebook, 81 stars): Conformal prediction for time-series applications.
  • prediction-powered-inference (Jupyter Notebook, 69 stars): A statistical toolkit for scientific discovery using machine learning.
  • event_based_gaze_tracking (Python, 62 stars): Dataset release for Event-Based, Near-Eye Gaze Tracking Beyond 10,000 Hz.
  • ltt (Jupyter Notebook, 60 stars): Learn then Test: Calibrating Predictive Algorithms to Achieve Risk Control.
  • conformal-risk (Python, 55 stars): Conformal prediction for controlling monotonic risk functions. Simple accompanying PyTorch code for conformal risk control in computer vision and natural language processing.
  • im2im-uq (Python, 50 stars): Image-to-image regression with uncertainty quantification in PyTorch. Take any dataset and train a model to regress images to images with rigorous, distribution-free uncertainty quantification.
  • cfr-covid-19 (R, 17 stars): Implementation of https://arxiv.org/abs/2003.08592
  • private_prediction_sets (Python, 17 stars): Wrap around any model to output differentially private prediction sets with finite-sample validity on any dataset.
  • online-conformal-decaying (Jupyter Notebook, 3 stars)
  • conformal-triage (Jupyter Notebook, 2 stars)