• Stars
    169
  • Rank 224,453 (Top 5 %)
  • Language
    Python
  • License
    MIT License
  • Created almost 2 years ago
  • Updated 6 months ago

Repository Details

A fast, effective data attribution method for neural networks in PyTorch

TRAK: Attributing Model Behavior at Scale

[docs & tutorials] [blog post] [website]

In our paper, we introduce a new data attribution method called TRAK (Tracing with the Randomly-Projected After Kernel). Using TRAK, you can make accurate counterfactual predictions (e.g., answers to questions of the form "what would happen to this prediction if these examples are removed from the training set?"). Computing data attribution with TRAK is 2-3 orders of magnitude cheaper than comparably effective methods, e.g., see our evaluation on:

Main figure
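
One way to read the counterfactual question above is additively: if each training example carries a score for a given target, the predicted change in the model's output from removing a set of examples is roughly the negative sum of their scores. The sketch below illustrates that reading only. The function `predicted_output_change` and the toy scores are hypothetical and not part of TRAK's API, and real scores are probably better treated as a ranking signal than as exact output deltas.

```python
from typing import Iterable, Sequence


def predicted_output_change(scores_for_target: Sequence[float],
                            removed: Iterable[int]) -> float:
    """Estimate how the output on one target shifts if `removed` is dropped.

    Assumes an additive model: each training example contributes its score,
    so removing examples subtracts their scores from the prediction.
    """
    removed_set = set(removed)  # ignore accidental duplicate indices
    for idx in removed_set:
        if not 0 <= idx < len(scores_for_target):
            raise IndexError(f"training index {idx} is out of range")
    return -sum(scores_for_target[idx] for idx in removed_set)


# Toy scores of five training examples for a single target (made-up values).
toy_scores = [0.5, -0.25, 0.125, 0.0, 1.5]

# Removing the two most helpful examples lowers the predicted output.
print(predicted_output_change(toy_scores, removed=[0, 4]))  # -2.0
```

Removing an example with a negative score has the opposite effect: the estimate goes up, since that example was pushing the output down.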

Usage

Check our docs for more detailed examples and tutorials on how to use TRAK. Below, we provide a brief blueprint of using TRAK's API to compute attribution scores.

Make a TRAKer instance

from trak import TRAKer

model, checkpoints = ...
train_loader = ...

traker = TRAKer(model=model, task='image_classification', train_set_size=...)

Compute TRAK features on training data

for model_id, checkpoint in enumerate(checkpoints):
    traker.load_checkpoint(checkpoint, model_id=model_id)
    for batch in train_loader:
        # batch should be a tuple of inputs and labels
        traker.featurize(batch=batch, ...)
traker.finalize_features()
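
The "Randomly-Projected" part of TRAK's name and the JL (Johnson-Lindenstrauss) projection step mentioned under Installation both refer to compressing very high-dimensional vectors with a random matrix while roughly preserving norms and inner products. The following is a minimal, dependency-free sketch of that general idea, not TRAK's implementation: the names `make_projection` and `project`, the dimensions, and the stand-in vector are all assumptions chosen for illustration.

```python
import math
import random
from typing import List


def make_projection(in_dim: int, out_dim: int, seed: int = 0) -> List[List[float]]:
    """Build an out_dim x in_dim Gaussian matrix scaled by 1/sqrt(out_dim)."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(out_dim)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(in_dim)]
            for _ in range(out_dim)]


def project(matrix: List[List[float]], vec: List[float]) -> List[float]:
    """Multiply the projection matrix by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


# Stand-in for a flattened high-dimensional vector (for example, a gradient).
data_rng = random.Random(1)
vector = [data_rng.uniform(-1.0, 1.0) for _ in range(1000)]

projection = make_projection(in_dim=1000, out_dim=256)
feature = project(projection, vector)

# The squared norm is approximately preserved after projection.
ratio = dot(feature, feature) / dot(vector, vector)
print(len(feature), round(ratio, 2))
```

The printed ratio should be close to 1. Larger output dimensions tighten the approximation at the cost of more storage per example.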

Compute TRAK scores for target examples

targets_loader = ...

for model_id, checkpoint in enumerate(checkpoints):
    # pass every argument by keyword so the call does not depend on argument order
    traker.start_scoring_checkpoint(exp_name='test',
                                    checkpoint=checkpoint,
                                    model_id=model_id,
                                    num_targets=...)
    for batch in targets_loader:
        traker.score(batch=batch, ...)

scores = traker.finalize_scores(exp_name='test')

Then, you can use the computed TRAK scores to analyze your model's behavior. For example, here are the most (positively and negatively) impactful examples for a ResNet18 model trained on ImageNet for three targets from the ImageNet validation set:

ImageNet Figure
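
As a rough illustration of how the most positively and negatively impactful training examples for a target might be picked out of a score matrix, here is a small sketch. It assumes the scores are laid out with one row per training example and one column per target. That layout, the helper name `top_and_bottom_k`, and the toy values are assumptions for illustration, so check the docs for the actual shape returned by `finalize_scores`.

```python
from typing import List, Sequence, Tuple


def top_and_bottom_k(scores: Sequence[Sequence[float]],
                     target_idx: int,
                     k: int) -> Tuple[List[int], List[int]]:
    """Return (most positive, most negative) training indices for one target.

    `scores[i][j]` is assumed to be the score of training example i
    for target j.
    """
    ranked = sorted(range(len(scores)),
                    key=lambda i: scores[i][target_idx],
                    reverse=True)
    most_positive = ranked[:k]
    most_negative = ranked[::-1][:k]  # lowest scores first
    return most_positive, most_negative


# Toy matrix: 4 training examples (rows) x 2 targets (columns), made-up values.
toy_scores = [
    [0.9, -0.1],
    [-0.4, 0.3],
    [0.2, 0.7],
    [-0.8, 0.0],
]

positive, negative = top_and_bottom_k(toy_scores, target_idx=0, k=2)
print(positive)  # [0, 2]
print(negative)  # [3, 1]
```

The returned indices can then be used to look up the corresponding training images or records for inspection.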

Check out the quickstart for a complete ready-to-run example notebook. You can also find several end-to-end examples in the examples/ directory.

Contributing

We welcome contributions to this project! Please see our contributing guidelines for more information.

Citation

If you use this code in your work, please cite using the following BibTeX entry:

@inproceedings{park2023trak,
  title = {TRAK: Attributing Model Behavior at Scale},
  author = {Sung Min Park and Kristian Georgiev and Andrew Ilyas and Guillaume Leclerc and Aleksander Madry},
  booktitle = {International Conference on Machine Learning (ICML)},
  year = {2023}
}

Installation

To install the version of our package which contains a fast, custom CUDA kernel for the JL projection step, use

pip install traker[fast]

You will need compatible versions of gcc and the CUDA toolkit to install it. See the installation FAQs for tips regarding this. To install the basic version of our package that requires no compilation, use

pip install traker
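
After installing, a quick check like the one below can report which variant ended up in the current environment. This is a sketch under two assumptions that are not stated above: that the package is importable as `trak` (as the usage snippet suggests) and that the fast variant provides a separate module named `fast_jl`. Adjust the names if your installation differs.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(module_name) is not None


def describe_trak_install() -> str:
    """Summarize which TRAK variant appears to be available."""
    if not is_installed("trak"):
        return "trak is not installed"
    # Assumption: the compiled CUDA kernel ships as a module named `fast_jl`.
    if is_installed("fast_jl"):
        return "trak is installed with the fast CUDA projection kernel"
    return "trak is installed (basic version, no compiled kernel found)"


print(describe_trak_install())
```

Because `find_spec` only locates the module, this check does not confirm that the compiled kernel matches your CUDA toolkit. Running one of the examples is a more reliable test.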

Questions?

Please send an email to [email protected]

Maintainers

Kristian Georgiev
Andrew Ilyas
Sung Min Park

More Repositories

1. robustness (Jupyter Notebook, 905 stars): A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
2. mnist_challenge (Python, 720 stars): A challenge to explore adversarial robustness of neural networks on MNIST.
3. cifar10_challenge (Python, 488 stars): A challenge to explore adversarial robustness of neural networks on CIFAR10.
4. photoguard (Jupyter Notebook, 419 stars): Raising the Cost of Malicious AI-Powered Image Editing
5. constructed-datasets (178 stars): Datasets for the paper "Adversarial Examples are not Bugs, They Are Features"
6. robust_representations (Jupyter Notebook, 158 stars): Code for "Learning Perceptually-Aligned Representations via Adversarial Robustness"
7. backgrounds_challenge (Python, 134 stars)
8. robustness_applications (Jupyter Notebook, 125 stars): Notebooks for reproducing the paper "Computer Vision with a Single (Robust) Classifier"
9. implementation-matters (Python, 104 stars)
10. EditingClassifiers (Python, 95 stars)
11. robust-features-code (Jupyter Notebook, 91 stars): Code for "Robustness May Be at Odds with Accuracy"
12. datamodels-data (Python, 64 stars): Data for "Datamodels: Predicting Predictions with Training Data"
13. blackbox-bandits (Python, 61 stars): Code for "Prior Convictions: Black-Box Adversarial Attacks with Bandits and Priors"
14. BREEDS-Benchmarks (Jupyter Notebook, 50 stars)
15. cox (Python, 50 stars): A lightweight experimental logging library
16. adversarial_spatial (Python, 49 stars): Investigating the robustness of state-of-the-art CNN architectures to simple spatial transformations.
17. modeldiff (Jupyter Notebook, 44 stars): ModelDiff: A Framework for Comparing Learning Algorithms
18. failure-directions (Jupyter Notebook, 42 stars): Distilling Model Failures as Directions in Latent Space
19. smoothed-vit (Python, 41 stars): Certified Patch Robustness via Smoothed Vision Transformers
20. label-consistent-backdoor-code (Python, 40 stars): Code for "Label-Consistent Backdoor Attacks"
21. dataset-interfaces (Jupyter Notebook, 39 stars): Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation
22. DebuggableDeepNetworks (Jupyter Notebook, 37 stars)
23. data-transfer (Python, 31 stars)
24. ImageNetMultiLabel (Jupyter Notebook, 28 stars): Fine-grained ImageNet annotations
25. relu_stable (Python, 26 stars)
26. spatial-pytorch (Jupyter Notebook, 26 stars): Codebase for "Exploring the Landscape of Spatial Robustness" (ICML'19, https://arxiv.org/abs/1712.02779).
27. dataset-replication-analysis (Jupyter Notebook, 25 stars)
28. backdoor_data_poisoning (Python, 25 stars)
29. glm_saga (Python, 23 stars): Minimal, standalone library for solving GLMs in PyTorch
30. AdvEx_Tutorial (Jupyter Notebook, 14 stars)
31. rethinking-backdoor-attacks (Python, 14 stars)
32. bias-transfer (Python, 13 stars)
33. robustness_lib (Python, 12 stars)
34. journey-TRAK (Python, 12 stars): Code for the paper "The Journey, Not the Destination: How Data Guides Diffusion Models"
35. datamodels (Python, 12 stars)
36. rla (Python, 12 stars): Residue Level Alignment
37. copriors (Python, 8 stars): Combining Diverse Feature Priors
38. missingness (Jupyter Notebook, 5 stars): Code for our ICLR 2022 paper "Missingness Bias in Model Debugging"
39. fast_l1 (Jupyter Notebook, 3 stars)
40. pytorch-lightning-imagenet (Python, 3 stars)
41. post--adv-discussion (HTML, 2 stars)
42. AIaaS_Supply_Chains (2 stars): Dataset and overview
43. pytorch-example-imagenet (Python, 1 star)
44. mnist_challenge_models (1 star)
45. robust_model_colab (JavaScript, 1 star)