Language: Python
License: MIT License
Created over 7 years ago, updated almost 3 years ago


PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation


The paper: https://arxiv.org/abs/1704.03296

What makes the deep learning network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':

[Images: Dog, Cat]

A perturbation of the dog that caused the dog category score to vanish:

[Image: Perturbed]

What makes the deep learning network think the image label is 'flute, transverse flute':

[Image: Flute]


Usage: python explain.py <path_to_image>

This is a PyTorch implementation of "Interpretable Explanations of Black Boxes by Meaningful Perturbation" by Ruth Fong and Andrea Vedaldi, with some deviations.

This uses a pretrained VGG19 from torchvision. Its weights are downloaded the first time it is used.

This learns a mask of pixels that explains the output of a black box. The mask is learned by posing an optimization problem and solving directly for the mask values.

This differs from visualization techniques such as Grad-CAM, which use heuristics like high positive gradient values as an indication of relevance to the network score.

In our case the black box is the VGG19 model, but any differentiable model can be used.


How it works

[Image: Equation]

Taken from the paper https://arxiv.org/abs/1704.03296

The goal is to solve for a mask that explains why the network outputs a given score for a certain category.

We create a low-resolution (28x28) mask and use it to perturb the image that is fed to a deep learning network.

The perturbation combines the regular image, a blurred version of it, and the up-sampled mask.

Wherever the mask contains low values, the input image will become more blurry.
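A minimal sketch of this perturbation step, assuming bilinear up-sampling via `torch.nn.functional.interpolate` and a mask convention where 1 keeps the original pixel and 0 selects the blurred one (matching the sentence above). The function name `blend` and the toy tensors are illustrative, not taken from `explain.py`.

```python
import torch
import torch.nn.functional as F


def blend(image, blurred, mask):
    """Blend an image with its blurred copy using a low-resolution mask.

    image, blurred: (N, 3, H, W) tensors; mask: (N, 1, 28, 28) with values in [0, 1].
    A mask value of 1 keeps the original pixel; 0 replaces it with the blurred pixel.
    """
    # Up-sample the coarse mask to the image resolution with bilinear interpolation.
    upsampled = F.interpolate(mask, size=image.shape[-2:], mode="bilinear", align_corners=False)
    # The single mask channel broadcasts over the 3 color channels.
    return upsampled * image + (1 - upsampled) * blurred


# Example: "delete" the top half of the image by zeroing the top rows of the mask.
image = torch.rand(1, 3, 224, 224)
blurred = torch.zeros_like(image)  # stand-in for a real blurred copy
mask = torch.ones(1, 1, 28, 28)
mask[..., :14, :] = 0.0
out = blend(image, blurred, mask)
```

Because the mask is only 28x28, each mask value controls an 8x8 block of the 224x224 input, with a short bilinear ramp at the boundary between kept and blurred regions.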

We want the mask to have the following properties:

  1. When the mask is used to blend the input image with its blurred version, the score of the target category should drop significantly. The evidence for the category should be removed!
  2. The mask should be sparse. Ideally it should remove the smallest possible region that still drops the category score. This translates to an L1(1 - mask) term in the cost function.
  3. The mask should be smooth. This translates to a total variation regularization term in the cost function.
  4. The mask shouldn't overfit the network. Since the network activations can contain a lot of noise, it is easy for the mask to learn random values that make the score drop without being visually coherent. In addition to the other terms, this is addressed by solving for a lower-resolution 28x28 mask.
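Written out, properties 1 to 4 suggest a cost of roughly the following form. This is a sketch, not a transcription of the paper's equation or of `explain.py`: the notation (x for the input, \tilde{x} for its blurred copy, U for up-sampling, f_c for the category score, and the weights \lambda_1, \lambda_2, \beta) is introduced here, and the blend follows the convention that low mask values mean more blur.

```latex
\begin{aligned}
\Phi(x, m) &= U(m) \odot x + \bigl(1 - U(m)\bigr) \odot \tilde{x} \\
m^{*} &= \operatorname*{arg\,min}_{m \in [0,1]^{28 \times 28}}
  \underbrace{\lambda_1 \lVert 1 - m \rVert_1}_{\text{sparsity}}
  + \underbrace{\lambda_2 \sum_{i,j} \Bigl( \lvert m_{i+1,j} - m_{i,j} \rvert^{\beta}
      + \lvert m_{i,j+1} - m_{i,j} \rvert^{\beta} \Bigr)}_{\text{total variation}}
  + \underbrace{f_c\bigl(\Phi(x, m)\bigr)}_{\text{category score}}
\end{aligned}
```

Property 4 is not a separate term: it is enforced by optimizing over a 28x28 mask and up-sampling it with U.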

Deviations from the paper

The paper uses a Gaussian kernel whose sigma is modulated by the value of the mask. This is computationally costly: the mask values are updated during the iterations, so we need a different kernel for every mask pixel at every iteration.

Initially I tried approximating this by first filtering the image with a filter bank of Gaussian kernels with varying sigmas. Then, during optimization, each input pixel would use its quantized mask value to select the corresponding filter-bank output pixel (high mask value -> lower channel).

This was done using PyTorch's gather/index_select functions. But it turns out that gather and index_select in PyTorch are not differentiable with respect to the indices.
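The problem can be reproduced in a few lines. The setup below is a hypothetical toy (a 4-level filter bank over an 8x8 single-channel image), not the author's original code; it only shows that once the mask is quantized into integer indices for `torch.gather`, no gradient reaches the mask.

```python
import torch

torch.manual_seed(0)

# Hypothetical filter bank: 4 versions of a 1-channel 8x8 image, one per blur level.
bank = torch.rand(4, 1, 8, 8, requires_grad=True)
# The mask we would like to learn, one value per pixel in [0, 1).
mask = torch.rand(1, 1, 8, 8, requires_grad=True)

# Quantize the mask into a channel index 0..3. Casting to an integer dtype
# cuts the autograd graph, because integer tensors cannot carry gradients.
index = (mask * 3).round().long()

# For every pixel, pick the filter-bank output selected by its index.
selected = torch.gather(bank, 0, index)
selected.sum().backward()

print(bank.grad is not None)  # gradients do reach the filter-bank values
print(mask.grad is None)      # but nothing flows back to the mask
```

Even without the integer cast, rounding has a zero derivative almost everywhere, so a quantized selection gives the optimizer no signal about how to change the mask.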

Instead, we compute a single perturbed image once, and then blend the original and perturbed images using:

input_image = mask * image + (1 - mask) * perturbed_image

And it works well in practice.

The perturbed image here is the average of the Gaussian-blurred and median-blurred images, but many other combinations could be used (try it out and find something better!).
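A sketch of such a perturbed image in plain PyTorch, so it runs without OpenCV. The kernel size of 11 and Gaussian sigma of 5 are assumed values, and the helper names `gaussian_blur`, `median_blur` and `perturbed_image` are hypothetical; the repository's own preprocessing may differ.

```python
import torch
import torch.nn.functional as F


def gaussian_blur(img, kernel_size=11, sigma=5.0):
    """Separable Gaussian blur of an (N, C, H, W) tensor with reflect padding."""
    coords = torch.arange(kernel_size, dtype=img.dtype) - (kernel_size - 1) / 2
    kernel_1d = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    kernel_1d = kernel_1d / kernel_1d.sum()  # normalize so brightness is preserved
    channels = img.shape[1]
    pad = kernel_size // 2
    # groups=channels blurs each color channel independently.
    kx = kernel_1d.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
    ky = kernel_1d.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
    out = F.conv2d(F.pad(img, (pad, pad, 0, 0), mode="reflect"), kx, groups=channels)  # horizontal pass
    out = F.conv2d(F.pad(out, (0, 0, pad, pad), mode="reflect"), ky, groups=channels)  # vertical pass
    return out


def median_blur(img, kernel_size=11):
    """Each output pixel is the median of its kernel_size x kernel_size window."""
    n, c, h, w = img.shape
    pad = kernel_size // 2
    padded = F.pad(img, (pad, pad, pad, pad), mode="reflect")
    # unfold gives (N, C * k * k, H * W); split out the window as its own axis.
    patches = F.unfold(padded, kernel_size).reshape(n, c, kernel_size * kernel_size, h, w)
    return patches.median(dim=2).values


def perturbed_image(img):
    """Average of a Gaussian-blurred and a median-blurred copy of img."""
    return (gaussian_blur(img) + median_blur(img)) / 2
```

The median filter removes fine texture while keeping strong edges, and the Gaussian smooths everything; averaging the two is one reasonable trade-off, as the text notes.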

Gaussian noise with a sigma of 0.2 is also added to the preprocessed image at each iteration, inspired by Google's SmoothGrad.
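Putting the pieces together, a minimal end-to-end sketch of the optimization might look like this. Only the 28x28 mask, the L1(1 - mask) and total variation terms, and the noise sigma of 0.2 come from the text; the Adam optimizer, the all-ones initial mask, the softmax probability as the score, and the values of `lr`, `l1_coeff`, `tv_coeff`, `beta` and `iterations` are assumptions. A tiny stand-in classifier replaces VGG19 so the example runs without downloading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tv_norm(mask, beta=3.0):
    """Total variation of a (1, 1, h, w) mask: mean |vertical diff|^beta + mean |horizontal diff|^beta."""
    dy = (mask[..., 1:, :] - mask[..., :-1, :]).abs().pow(beta).mean()
    dx = (mask[..., :, 1:] - mask[..., :, :-1]).abs().pow(beta).mean()
    return dx + dy


def explain(model, image, blurred, category, iterations=300, lr=0.1,
            l1_coeff=0.01, tv_coeff=0.2, noise_sigma=0.2):
    model.eval().requires_grad_(False)  # only the mask is optimized
    # Start from an all-ones mask, i.e. nothing is perturbed yet.
    mask = torch.ones(1, 1, 28, 28, requires_grad=True)
    optimizer = torch.optim.Adam([mask], lr=lr)
    for _ in range(iterations):
        up = F.interpolate(mask, size=image.shape[-2:], mode="bilinear", align_corners=False)
        perturbed = up * image + (1 - up) * blurred
        # Fresh Gaussian noise every iteration (SmoothGrad-style).
        noisy = perturbed + noise_sigma * torch.randn_like(perturbed)
        score = F.softmax(model(noisy), dim=1)[0, category]
        loss = l1_coeff * (1 - mask).abs().mean() + tv_coeff * tv_norm(mask) + score
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep the mask a valid blending weight in [0, 1].
        with torch.no_grad():
            mask.clamp_(0, 1)
    return mask.detach()


# Toy usage with a small stand-in classifier instead of VGG19.
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
image = torch.rand(1, 3, 64, 64)
blurred = torch.zeros_like(image)  # stand-in for the perturbed image
mask = explain(model, image, blurred, category=3, iterations=50)
```

With a real network one would pass a pretrained VGG19 and a normalized 224x224 image instead of the toy model and random tensor; low values in the returned mask then mark the regions whose removal most reduces the category score.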
