  • Language
    Python
  • License
    BSD 3-Clause "New...
  • Created about 5 years ago
  • Updated 3 months ago

Model interpretability and understanding for PyTorch

Captum is a model interpretability and understanding library for PyTorch. Captum means comprehension in Latin and contains general purpose implementations of integrated gradients, saliency maps, smoothgrad, vargrad and others for PyTorch models. It has quick integration for models built with domain-specific libraries such as torchvision, torchtext, and others.

Captum is currently in beta and under active development!

About Captum

With the increase in model complexity and the resulting lack of transparency, model interpretability methods have become increasingly important. Model understanding is both an active area of research and an area of focus for practical applications across industries using machine learning. Captum provides state-of-the-art algorithms, including Integrated Gradients, to give researchers and developers an easy way to understand which features contribute to a model’s output.

For model developers, Captum can be used to improve and troubleshoot models by identifying the features that contribute to a model’s output, which helps in designing better models and diagnosing unexpected outputs.

Captum helps ML researchers more easily implement interpretability algorithms that can interact with PyTorch models. Captum also allows researchers to quickly benchmark their work against other existing algorithms available in the library.

Overview of Attribution Algorithms

Target Audience

The primary audiences for Captum are model developers who want to improve their models and understand which features are important, and interpretability researchers focused on identifying algorithms that can better interpret many types of models.

Captum can also be used by application engineers who are using trained models in production. Captum provides easier troubleshooting through improved model interpretability, and the potential for delivering better explanations to end users on why they’re seeing a specific piece of content, such as a movie recommendation.

Installation

Installation Requirements

  • Python >= 3.6
  • PyTorch >= 1.6

Installing the latest release

The latest release of Captum is easily installed either via Anaconda (recommended) or via pip.

With conda

You can install captum from any of the following supported conda channels:

  • channel: pytorch

    conda install captum -c pytorch
  • channel: conda-forge

    conda install captum -c conda-forge

With pip

pip install captum
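
After installing, a quick check of the environment against the requirements listed above can save some debugging. The following is a minimal sketch using only the standard library. The helper name `missing_requirements` is hypothetical and not part of Captum, and it only checks that the `torch` and `captum` packages can be located, not which versions are installed.

```python
import sys
from importlib import util


def missing_requirements():
    """Return a list of human-readable problems; an empty list means the basics are in place."""
    problems = []
    # The stated minimum interpreter version is Python 3.6.
    if sys.version_info < (3, 6):
        problems.append("Python >= 3.6 is required")
    # find_spec returns None when a top-level package cannot be located.
    for package in ("torch", "captum"):
        if util.find_spec(package) is None:
            problems.append(package + " is not installed")
    return problems


print(missing_requirements())
```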

Manual / Dev install

If you'd like to try our bleeding edge features (and don't mind potentially running into the occasional bug here or there), you can install the latest master directly from GitHub. For a basic install, run:

git clone https://github.com/pytorch/captum.git
cd captum
pip install -e .

To customize the installation, you can also run the following variants of the above:

  • pip install -e .[insights]: Also installs all packages necessary for running Captum Insights.
  • pip install -e .[dev]: Also installs all tools necessary for development (testing, linting, docs building; see Contributing below).
  • pip install -e .[tutorials]: Also installs all packages necessary for running the tutorial notebooks.

To execute unit tests from a manual install, run:

# running a single unit test
python -m unittest -v tests.attr.test_saliency
# running all unit tests
pytest -ra

Getting Started

Captum helps you interpret and understand predictions of PyTorch models by exploring features that contribute to a prediction the model makes. It also helps understand which neurons and layers are important for model predictions.

Let's apply some of those algorithms to a toy model we have created for demonstration purposes. For simplicity, we will use the following architecture, but users are welcome to use any PyTorch model of their choice.

import numpy as np

import torch
import torch.nn as nn

from captum.attr import (
    GradientShap,
    DeepLift,
    DeepLiftShap,
    IntegratedGradients,
    LayerConductance,
    NeuronConductance,
    NoiseTunnel,
)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1,3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))
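
Because the weights are fixed, the toy model's output can be checked by hand. The sketch below re-implements the forward pass in plain Python, with the weight values written out from the `torch.arange(...).view(...)` calls above. The helper names `linear` and `toy_forward` are illustrative only and not part of Captum or PyTorch.

```python
# Weights and biases written out from ToyModel.__init__.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
B1 = [0.0, 0.0, 0.0]                                           # lin1.bias
W2 = [[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]                     # lin2.weight
B2 = [1.0, 1.0]                                                # lin2.bias


def linear(weight, bias, x):
    """Compute weight @ x + bias for a single input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def toy_forward(x):
    """lin1 -> ReLU -> lin2, mirroring ToyModel.forward for one example."""
    hidden = [max(z, 0.0) for z in linear(W1, B1, x)]
    return linear(W2, B2, hidden)


print(toy_forward([1.0, 1.0, 1.0]))  # [-8.0, 19.0]
```

For the input `[1, 1, 1]`, the first layer produces `[-9, 0, 9]`, the ReLU keeps only the last value, and the second layer maps that to `[-8, 19]`.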

Let's create an instance of our model and set it to eval mode.

model = ToyModel()
model.eval()

Next, we need to define simple input and baseline tensors. Baselines belong to the input space and often carry no predictive signal. A zero tensor can serve as a baseline for many tasks. Some interpretability algorithms, such as IntegratedGradients, DeepLift and GradientShap, are designed to attribute the change between the input and the baseline to a predicted class or a value that the neural network outputs.

We will apply model interpretability algorithms on the network mentioned above in order to understand the importance of individual neurons/layers and the parts of the input that play an important role in the final prediction.

To make computations deterministic, let's fix random seeds.

torch.manual_seed(123)
np.random.seed(123)

Let's define our input and baseline tensors. Baselines are used in some interpretability algorithms such as IntegratedGradients, DeepLift, GradientShap, NeuronConductance, LayerConductance, InternalInfluence and NeuronIntegratedGradients.

input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

Next, we will use the IntegratedGradients algorithm to assign attribution scores to each input feature with respect to the first target output.

ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)

Output:

IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                         [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([2.3842e-07, -4.7684e-07])

The algorithm outputs an attribution score for each input element and a convergence delta. The lower the absolute value of the convergence delta, the better the approximation. If we do not need the delta, we can simply omit the return_convergence_delta argument. The absolute value of each returned delta can be interpreted as the approximation error for the corresponding input sample, so it serves as a proxy for how accurate the integral approximation is for the given inputs and baselines. If the approximation error is large, we can increase the number of integral approximation steps by setting n_steps to a larger value. Not all algorithms return an approximation error. Those that do compute it based on the completeness property of the algorithm.

A positive attribution score means that the input in that particular position contributed positively to the final prediction, and a negative score means the opposite. The magnitude of the attribution score signifies the strength of the contribution. A zero attribution score means no contribution from that particular feature.
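
To make the relationship between n_steps and the convergence delta concrete, here is a minimal plain-Python sketch of integrated gradients on the first output of the toy model. It is not Captum's implementation: it uses a simple midpoint Riemann sum (Captum may use a different quadrature rule), the gradient is written out analytically, and the function names are hypothetical. The input and baseline values are chosen for illustration so that hidden units switch on and off along the path.

```python
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2_ROW0 = [-3.0, -2.0, -1.0]                                   # lin2.weight[0]
B2_0 = 1.0                                                     # lin2.bias[0]


def pre_activations(x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def output0(x):
    """First output of the toy model (target=0)."""
    return B2_0 + sum(w * max(z, 0.0) for w, z in zip(W2_ROW0, pre_activations(x)))


def grad_output0(x):
    """Analytic gradient of output0 with respect to the inputs."""
    gated = [w if z > 0 else 0.0 for w, z in zip(W2_ROW0, pre_activations(x))]
    return [sum(gated[j] * W1[j][i] for j in range(3)) for i in range(3)]


def integrated_gradients(x, baseline, n_steps):
    avg_grad = [0.0, 0.0, 0.0]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        avg_grad = [a + g / n_steps for a, g in zip(avg_grad, grad_output0(point))]
    attr = [(xi - b) * a for xi, b, a in zip(x, baseline, avg_grad)]
    # Completeness: attributions should sum to output0(x) - output0(baseline).
    delta = sum(attr) - (output0(x) - output0(baseline))
    return attr, delta


x = [1.0, 0.5, 0.2]
baseline = [-1.0, 0.3, 0.8]
_, delta_coarse = integrated_gradients(x, baseline, n_steps=2)
attr, delta_fine = integrated_gradients(x, baseline, n_steps=1000)
print(abs(delta_coarse), abs(delta_fine))
```

With only two steps the sum misses the points where hidden units switch off, so the delta is large. With 1000 steps the attributions sum to nearly the full output difference.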

Similarly, we can apply GradientShap, DeepLift and other attribution algorithms to the model.

GradientShap first chooses a random baseline from the baselines' distribution, then adds Gaussian noise with std=0.09 to each input example n_samples times. Afterwards, it chooses a random point between each example-baseline pair and computes the gradients with respect to the target class (in this case target=0). The resulting attribution is the mean of gradients * (inputs - baselines).

gs = GradientShap(model)

# We define a distribution of baselines and draw `n_samples` from that
# distribution in order to estimate the expectations of gradients across all baselines
baseline_dist = torch.randn(10, 3) * 0.001
attributions, delta = gs.attribute(input, stdevs=0.09, n_samples=4, baselines=baseline_dist,
                                   target=0, return_convergence_delta=True)
print('GradientShap Attributions:', attributions)
print('Convergence Delta:', delta)

Output

GradientShap Attributions: tensor([[-0.1542, -1.6229, -1.5835],
                                   [-0.3916, -0.2836, -4.6851]])
Convergence Delta: tensor([ 0.0000, -0.0005, -0.0029, -0.0084, -0.0087, -0.0405,  0.0000, -0.0084])
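
The sampling procedure described for GradientShap can be sketched in plain Python for a single input example. This is an illustration under stated assumptions, not Captum's code: the gradient of the toy model's first output is written out analytically, the function names are hypothetical, and the input, baselines and noise level are made up for the example.

```python
import random

W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2_ROW0 = [-3.0, -2.0, -1.0]                                   # lin2.weight[0]


def grad_output0(x):
    """Analytic gradient of the first toy-model output with respect to the inputs."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    gated = [w if z > 0 else 0.0 for w, z in zip(W2_ROW0, pre)]
    return [sum(gated[j] * W1[j][i] for j in range(3)) for i in range(3)]


def gradient_shap(x, baselines, n_samples, stdev, rng):
    total = [0.0] * len(x)
    for _ in range(n_samples):
        noisy = [xi + rng.gauss(0.0, stdev) for xi in x]   # noisy copy of the input
        base = rng.choice(baselines)                       # random baseline
        alpha = rng.random()                               # random point on the path
        point = [b + alpha * (n - b) for n, b in zip(noisy, base)]
        grad = grad_output0(point)
        total = [t + g * (n - b) for t, g, n, b in zip(total, grad, noisy, base)]
    # Mean of gradients * (inputs - baselines) over all samples.
    return [t / n_samples for t in total]


baselines = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.01]]
attr = gradient_shap([0.2, 0.5, 0.9], baselines, n_samples=50, stdev=0.01,
                     rng=random.Random(0))
print(attr)
```

For this input the same hidden units stay active along every sampled path, so the result stays close to gradient * (input - baseline), roughly `[0, -1.5, -5.4]`.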

Deltas are computed for each n_samples * input.shape[0] example. The user can, for instance, average them:

deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)

This yields the average delta per example.

Below is an example of how we can apply DeepLift and DeepLiftShap on the ToyModel described above. The current implementation of DeepLift supports only the Rescale rule. For more details on alternative implementations, please see the DeepLift paper.

dl = DeepLift(model)
attributions, delta = dl.attribute(input, baseline, target=0, return_convergence_delta=True)
print('DeepLift Attributions:', attributions)
print('Convergence Delta:', delta)

Output

DeepLift Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                               [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([0., 0.])

DeepLift assigns attribution scores to inputs that are similar to those of IntegratedGradients, but with lower execution time. Another important thing to remember about DeepLift is that it currently doesn't support all non-linear activation types. For more details on the limitations of the current implementation, please see the DeepLift paper.

Similar to integrated gradients, DeepLift returns a convergence delta score per input example. The approximation error is then the absolute value of the convergence deltas and can serve as a proxy of how accurate the algorithm's approximation is.
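
As a rough illustration of the Rescale rule mentioned above, the sketch below applies it by hand to the toy model's single ReLU layer: each hidden unit gets a multiplier equal to the change in its output divided by the change in its input between baseline and input. This is a simplified plain-Python sketch with hypothetical function names, not Captum's implementation. The fallback to the ordinary ReLU gradient when the input difference is close to zero is an assumption made here for numerical safety.

```python
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2_ROW0 = [-3.0, -2.0, -1.0]                                   # lin2.weight[0]
B2_0 = 1.0                                                     # lin2.bias[0]


def pre_activations(x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def output0(x):
    return B2_0 + sum(w * max(z, 0.0) for w, z in zip(W2_ROW0, pre_activations(x)))


def deeplift_rescale(x, baseline):
    z_x, z_b = pre_activations(x), pre_activations(baseline)
    multipliers = []
    for zx, zb in zip(z_x, z_b):
        if abs(zx - zb) > 1e-10:
            # Rescale rule: (change in ReLU output) / (change in ReLU input).
            multipliers.append((max(zx, 0.0) - max(zb, 0.0)) / (zx - zb))
        else:
            # Assumed fallback: plain ReLU gradient when the inputs coincide.
            multipliers.append(1.0 if zx > 0 else 0.0)
    attr = []
    for i in range(len(x)):
        m_i = sum(W2_ROW0[j] * multipliers[j] * W1[j][i] for j in range(len(W1)))
        attr.append(m_i * (x[i] - baseline[i]))
    delta = sum(attr) - (output0(x) - output0(baseline))
    return attr, delta


attr, delta = deeplift_rescale([1.0, 0.5, 0.2], [-1.0, 0.3, 0.8])
print(attr, delta)
```

Because the multipliers are exact ratios rather than a numerical integral, the attributions sum to the output difference up to floating-point error, and a single pass is enough.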

Now let's look into DeepLiftShap. Similar to GradientShap, DeepLiftShap uses a baseline distribution. In the example below, we use the same baseline distribution as for GradientShap.

dl = DeepLiftShap(model)
attributions, delta = dl.attribute(input, baseline_dist, target=0, return_convergence_delta=True)
print('DeepLiftShap Attributions:', attributions)
print('Convergence Delta:', delta)

Output

DeepLiftShap Attributions: tensor([[-5.9169e-01, -1.5491e+00, -1.0076e+00],
                                   [-4.7101e-03, -2.2300e-01, -5.1926e+00]], grad_fn=<MeanBackward1>)
Convergence Delta: tensor([-4.6120e-03, -1.6267e-03, -5.1045e-04, -1.4184e-03, -6.8886e-03,
                           -2.2224e-02,  0.0000e+00, -2.8790e-02, -4.1285e-03, -2.7295e-02,
                           -3.2349e-03, -1.6265e-03, -4.7684e-07, -1.4191e-03, -6.8889e-03,
                           -2.2224e-02,  0.0000e+00, -2.4792e-02, -4.1289e-03, -2.7296e-02])

DeepLiftShap uses DeepLift to compute an attribution score for each input-baseline pair and averages the scores for each input across all baselines.

It computes a delta for each input example-baseline pair, resulting in input.shape[0] * baseline_dist.shape[0] delta values.

As with GradientShap, we can average them per example to obtain example-wise deltas:

deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)

To smooth and improve the quality of the attributions, we can run IntegratedGradients and other attribution methods through a NoiseTunnel. NoiseTunnel allows us to use the SmoothGrad, SmoothGrad_Sq and VarGrad techniques to smooth the attributions by aggregating them over multiple noisy samples generated by adding Gaussian noise to the input.
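
The three aggregation modes can be sketched in plain Python as follows. This is an illustration of the idea rather than Captum's code: `noise_tunnel` and `toy_attr` are hypothetical names, and `toy_attr` is a made-up stand-in for a real attribution method.

```python
import random


def noise_tunnel(attr_fn, x, nt_type, stdev, nt_samples, rng):
    """Aggregate attributions of attr_fn over nt_samples noisy copies of x."""
    samples = []
    for _ in range(nt_samples):
        noisy = [xi + rng.gauss(0.0, stdev) for xi in x]
        samples.append(attr_fn(noisy))
    n = float(nt_samples)
    mean = [sum(col) / n for col in zip(*samples)]
    mean_sq = [sum(v * v for v in col) / n for col in zip(*samples)]
    if nt_type == "smoothgrad":       # mean of the attributions
        return mean
    if nt_type == "smoothgrad_sq":    # mean of the squared attributions
        return mean_sq
    if nt_type == "vargrad":          # variance of the attributions
        return [sq - m * m for sq, m in zip(mean_sq, mean)]
    raise ValueError("unknown nt_type: " + nt_type)


def toy_attr(point):
    # Stand-in attribution: gradient * input for f(p) = sum(p_i ** 2).
    return [2.0 * p * p for p in point]


x = [1.0, 2.0]
sg = noise_tunnel(toy_attr, x, "smoothgrad", 0.02, 8, random.Random(0))
sq = noise_tunnel(toy_attr, x, "smoothgrad_sq", 0.02, 8, random.Random(0))
vg = noise_tunnel(toy_attr, x, "vargrad", 0.02, 8, random.Random(0))
print(sg, sq, vg)
```

Seeding each call identically makes the three modes see the same noisy samples, so the variance equals the mean of squares minus the squared mean.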

Here is an example of how we can use NoiseTunnel with IntegratedGradients.

ig = IntegratedGradients(model)
nt = NoiseTunnel(ig)
attributions, delta = nt.attribute(input, nt_type='smoothgrad', stdevs=0.02, nt_samples=4,
      baselines=baseline, target=0, return_convergence_delta=True)
print('IG + SmoothGrad Attributions:', attributions)
print('Convergence Delta:', delta)

Output

IG + SmoothGrad Attributions: tensor([[-0.4574, -1.5493, -1.0893],
                                      [ 0.0000, -0.2647, -5.1619]])
Convergence Delta: tensor([ 0.0000e+00,  2.3842e-07,  0.0000e+00, -2.3842e-07,  0.0000e+00,
        -4.7684e-07,  0.0000e+00, -4.7684e-07])

The number of elements in the delta tensor is equal to nt_samples * input.shape[0]. In order to get an example-wise delta, we can, for example, average them:

deltas_per_example = torch.mean(delta.reshape(input.shape[0], -1), dim=1)
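
To see what this reshape-and-average does, here is the same computation in plain Python, applied to the eight delta values printed above. It assumes, as the reshape in the text implies, that the deltas are laid out example by example, with all noisy samples of the first example first. The helper name is hypothetical.

```python
# Delta values copied from the IG + SmoothGrad output above (2 examples x 4 samples).
delta = [0.0000e+00, 2.3842e-07, 0.0000e+00, -2.3842e-07,
         0.0000e+00, -4.7684e-07, 0.0000e+00, -4.7684e-07]


def mean_delta_per_example(delta, n_examples):
    """Split the flat delta list into one row per example and average each row."""
    per_example = len(delta) // n_examples
    rows = [delta[i * per_example:(i + 1) * per_example] for i in range(n_examples)]
    return [sum(row) / len(row) for row in rows]


deltas_per_example = mean_delta_per_example(delta, n_examples=2)
print(deltas_per_example)
```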

Let's look into the internals of our network and understand which layers and neurons are important for the predictions.

We will start with NeuronConductance. NeuronConductance helps us identify the input features that are important for a particular neuron in a given layer. It decomposes the computation of integrated gradients via the chain rule, defining the importance of a neuron as the path integral of the derivative of the output with respect to the neuron, multiplied by the derivatives of the neuron with respect to the inputs of the model.

In this case, we choose to analyze the neuron at index 1 of the first linear layer (lin1).

nc = NeuronConductance(model, model.lin1)
attributions = nc.attribute(input, neuron_selector=1, target=0)
print('Neuron Attributions:', attributions)

Output

Neuron Attributions: tensor([[ 0.0000,  0.0000,  0.0000],
                             [ 1.3358,  0.0000, -1.6811]])

Layer conductance shows the importance of each neuron in a layer for a given input. It is an extension of path integrated gradients to hidden layers and also satisfies the completeness property.

It doesn't attribute the contribution scores to the input features but shows the importance of each neuron in the selected layer.

lc = LayerConductance(model, model.lin1)
attributions, delta = lc.attribute(input, baselines=baseline, target=0, return_convergence_delta=True)
print('Layer Attributions:', attributions)
print('Convergence Delta:', delta)

Output

Layer Attributions: tensor([[ 0.0000,  0.0000, -3.0856],
                            [ 0.0000, -0.3488, -4.9638]], grad_fn=<SumBackward1>)
Convergence Delta: tensor([0.0630, 0.1084])

Similar to other attribution algorithms that return a convergence delta, LayerConductance returns the deltas for each example. The approximation error is then the absolute value of the convergence deltas and can serve as a proxy for how accurate the integral approximation is for the given inputs and baselines.
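
The idea behind layer conductance and its completeness check can be sketched in plain Python for the toy model. The sketch treats the output of lin1 as the layer, accumulates for each of its neurons the gradient of the first model output with respect to that neuron times the neuron's change along the straight path from baseline to input, and compares the total with the output difference. It uses a midpoint Riemann sum and hypothetical function names, and it is not Captum's implementation. The input and baseline values are made up for illustration.

```python
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight
W2_ROW0 = [-3.0, -2.0, -1.0]                                   # lin2.weight[0]
B2_0 = 1.0                                                     # lin2.bias[0]


def pre_activations(x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]


def output0(x):
    return B2_0 + sum(w * max(z, 0.0) for w, z in zip(W2_ROW0, pre_activations(x)))


def layer_conductance(x, baseline, n_steps):
    z_x, z_b = pre_activations(x), pre_activations(baseline)
    cond = [0.0] * len(W1)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        for j in range(len(W1)):
            # lin1 is linear, so its output interpolates linearly along the path.
            z = z_b[j] + alpha * (z_x[j] - z_b[j])
            gate = 1.0 if z > 0 else 0.0  # derivative of ReLU at this point
            cond[j] += W2_ROW0[j] * gate * (z_x[j] - z_b[j]) / n_steps
    # Completeness: neuron scores should sum to output0(x) - output0(baseline).
    delta = sum(cond) - (output0(x) - output0(baseline))
    return cond, delta


cond, delta = layer_conductance([1.0, 0.5, 0.2], [-1.0, 0.3, 0.8], n_steps=1000)
print(cond, delta)
```

Each entry of `cond` belongs to one neuron of the layer rather than to an input feature, which is the distinction drawn in the text above.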

More details on the list of supported algorithms and how to apply Captum on different types of models can be found in our tutorials.

Captum Insights

Captum provides a web interface called Insights for easy visualization and access to a number of our interpretability algorithms.

To analyze a sample model on CIFAR10 via Captum Insights, run

python -m captum.insights.example

and navigate to the URL specified in the output.

To build Insights you will need Node >= 8.x and Yarn >= 1.5.

To build and launch from a checkout in a conda environment run

conda install -c conda-forge yarn
BUILD_INSIGHTS=1 python setup.py develop
python captum/insights/example.py

Captum Insights Jupyter Widget

Captum Insights also has a Jupyter widget providing the same user interface as the web app. To install and enable the widget, run

jupyter nbextension install --py --symlink --sys-prefix captum.insights.attr_vis.widget
jupyter nbextension enable captum.insights.attr_vis.widget --py --sys-prefix

To build the widget from a checkout in a conda environment run

conda install -c conda-forge yarn
BUILD_INSIGHTS=1 python setup.py develop

FAQ

If you have questions about using Captum methods, please check this FAQ, which addresses many common issues.

Contributing

See the CONTRIBUTING file for how to help out.

Talks and Papers

NeurIPS 2019: The slides of our presentation can be found here

KDD 2020: The slides of our presentation from KDD 2020 tutorial can be found here. You can watch the recorded talk here

GTC 2020: Opening Up the Black Box: Model Understanding with Captum and PyTorch. You can watch the recorded talk here

XAI Summit 2020: Using Captum and Fiddler to Improve Model Understanding with Explainable AI. You can watch the recorded talk here

PyTorch Developer Day 2020 Model Interpretability. You can watch the recorded talk here

NAACL 2021 Tutorial on Fine-grained Interpretation and Causation Analysis in Deep NLP Models. You can watch the recorded talk here

ICLR 2021 workshop on Responsible AI:

  • Paper on the Captum Library
  • Paper on Investigating Sanity Checks for Saliency Maps

References of Algorithms

More details about the attribution algorithms mentioned above and their pros and cons can be found on our website.

License

Captum is BSD licensed, as found in the LICENSE file.
