  • Stars: 742
  • Rank: 58,986 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created over 3 years ago
  • Updated about 2 years ago

Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • Attention Flow: TBD, work in progress.

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods,
  • Removing the lowest attentions.
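For orientation, plain Attention Rollout can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not necessarily how this repository implements it: the function name `attention_rollout` is hypothetical, each layer's attention is assumed to be already fused across heads into one (tokens, tokens) matrix, token 0 is assumed to be the class token, and the equal weighting of attention and identity is one common choice.

```python
import numpy as np

def attention_rollout(layer_attentions):
    """Roll attention out across layers.

    layer_attentions: list of (tokens, tokens) arrays, one per layer,
    already fused across heads. Token 0 is assumed to be the class token.
    """
    num_tokens = layer_attentions[0].shape[-1]
    result = np.eye(num_tokens)
    for attention in layer_attentions:
        # Add the identity to account for the residual connection,
        # then re-normalize so every row sums to 1 again.
        a = (attention + np.eye(num_tokens)) / 2.0
        a = a / a.sum(axis=-1, keepdims=True)
        # Compose this layer with the rollout of the previous layers.
        result = a @ result
    # Row 0 holds how much the class token attends to each token;
    # drop the class token itself to keep only the image patches.
    mask = result[0, 1:]
    return mask / mask.max()
```

The returned vector has one value per image patch and would typically be reshaped to the patch grid and resized to the image for display.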

Usage

  • From code
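import torch
from vit_grad_rollout import VITAttentionGradRollout

# Load the pretrained 'Tiny' DeiT model from torch hub.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
# input_tensor is a preprocessed image batch, e.g. of shape (1, 3, 224, 224).
mask = grad_rollout(input_tensor, category_index=243)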
  • From the command line:
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category_index>

If category_index isn't specified, Attention Rollout will be used, otherwise Gradient Attention Rollout will be used.
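The command-line behaviour described above could be wired up roughly as follows. This is a sketch, not the actual contents of vit_explain.py: the flag names come from the usage line, while the default values, the helper names `build_parser` and `choose_method`, and the file name `dog.png` are assumptions made for illustration.

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="Explain a ViT prediction")
    parser.add_argument("--image_path", type=str, required=True)
    # Assumed default; the usage line only lists the allowed values.
    parser.add_argument("--head_fusion", type=str, default="max",
                        choices=["mean", "min", "max"])
    # Assumed default; the usage line only says the value lies between 0 and 1.
    parser.add_argument("--discard_ratio", type=float, default=0.9)
    # Leaving this unset selects plain Attention Rollout.
    parser.add_argument("--category_index", type=int, default=None)
    return parser

def choose_method(args):
    """Return the name of the method the parsed arguments select."""
    if args.category_index is None:
        return "attention_rollout"
    return "gradient_attention_rollout"

args = build_parser().parse_args(
    ["--image_path", "dog.png", "--category_index", "243"])
print(choose_method(args))
```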

Note that by default, this uses the 'Tiny' DeiT model from "Training data-efficient image transformers & distillation through attention", hosted on torch hub.

Where did the Transformer pay attention in this image?

[Images: input image | vanilla Attention Rollout | with discard_ratio + max fusion]

Gradient Attention Rollout for class specific explainability

The attention that flows through the transformer passes along information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).
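One way to read the paragraph above, as a NumPy sketch for a single layer: weight each head's attention by the gradient of the target class score with respect to that attention, zero out the negative values, and average over heads. The function name `grad_weighted_attention` and the (heads, tokens, tokens) layout are assumptions, as is applying the mask before the average rather than after it. In practice the gradients would come from backpropagating the class score through the model.

```python
import numpy as np

def grad_weighted_attention(attention, gradient):
    """Fuse one layer's heads into a class-specific attention matrix.

    attention, gradient: arrays of shape (heads, tokens, tokens), where
    gradient is d(target class score) / d(attention).
    """
    weighted = attention * gradient
    # Mask out negative contributions (assumed to happen before averaging).
    weighted = np.clip(weighted, 0.0, None)
    # Average across the head axis.
    return weighted.mean(axis=0)
```

The fused matrix from each layer can then be combined across layers in the same way as in plain Attention Rollout.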

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio <value between 0 and 1>

Removes noise by keeping the strongest attentions.
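A minimal sketch of this filtering step, assuming a fused (tokens, tokens) attention matrix for one layer and that "lowest" is decided over the whole flattened matrix. The function name `discard_lowest` is hypothetical, and the repository may treat special entries such as the class token differently.

```python
import numpy as np

def discard_lowest(attention, discard_ratio):
    """Zero out the lowest `discard_ratio` fraction of attention values."""
    # flatten() returns a copy, so the input array is left untouched.
    flat = attention.flatten()
    num_discard = int(flat.size * discard_ratio)
    if num_discard > 0:
        # Indices of the num_discard smallest entries.
        lowest = np.argsort(flat)[:num_discard]
        flat[lowest] = 0.0
    return flat.reshape(attention.shape)
```

With discard_ratio=0.9, for example, only the strongest 10% of the attention values in each layer would survive.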

Results for different values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads, but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion <mean, min or max>
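The three fusion options amount to different reductions over the head axis. A small NumPy sketch, assuming one layer's attention is stored as a (heads, tokens, tokens) array; the function name `fuse_heads` is hypothetical.

```python
import numpy as np

def fuse_heads(attention, head_fusion="mean"):
    """Collapse the head axis of a (heads, tokens, tokens) array."""
    if head_fusion == "mean":
        return attention.mean(axis=0)
    if head_fusion == "max":
        return attention.max(axis=0)
    if head_fusion == "min":
        return attention.min(axis=0)
    raise ValueError(f"Unknown head_fusion: {head_fusion!r}")
```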

[Images: input image | mean fusion | min fusion]

Requirements

pip install timm
