  • Stars: 927 (rank 47,325, top 1.0%)
  • Language: Jupyter Notebook
  • License: Apache License 2.0
  • Created: almost 7 years ago
  • Updated: about 1 month ago


Repository Details

Framework-agnostic implementation for state-of-the-art saliency methods (XRAI, BlurIG, SmoothGrad, and more).

Saliency Library

Updates

🔴   Now framework-agnostic! (Example core notebook)  🔴

🔗   For further explanation of the methods and more examples of the resulting maps, see our Github Pages website  🔗

If upgrading from an older version, update old imports to import saliency.tf1 as saliency. We provide wrappers to make the framework-agnostic version compatible with TF1 models. (Example TF1 notebook)

🔴   Added Performance Information Curve (PIC) - a human-independent metric for evaluating the quality of saliency methods. (Example notebook)  🔴

Saliency Methods

This repository contains code for the following saliency techniques:

  • Vanilla Gradients
  • SmoothGrad
  • Integrated Gradients
  • Guided Integrated Gradients*
  • XRAI*
  • Blur IG*
  • Grad-CAM
  • Occlusion
  • Guided Backprop

*Developed by PAIR.

This list is by no means comprehensive. We are accepting pull requests to add new methods!

Evaluation of Saliency Methods

The repository provides an implementation of the Performance Information Curve (PIC), a human-independent metric for evaluating the quality of saliency methods (paper, poster, code, notebook).

Download

# To install the core subpackage:
pip install saliency

# To install core and tf1 subpackages:
pip install saliency[tf1]

or for the development version:

git clone https://github.com/pair-code/saliency
cd saliency
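
To use the cloned development version you still need to install it into your environment; an editable pip install is one common approach (this exact command is an assumption and is not spelled out in this README):

# Editable install of the cloned source (assumes the repository's standard Python packaging).
pip install -e .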

Usage

The saliency library has two subpackages:

  • core uses a generic call_model_function which can be used with any ML framework.
  • tf1 accepts input/output tensors directly, and sets up the necessary graph operations for each method.
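
Which subpackage you import determines which of the two APIs you use; both import forms appear in the examples below:

# Framework-agnostic API: methods take a call_model_function, so any ML framework works.
import saliency.core as saliency

# TF1 API: methods take a graph/session and input/output tensors directly.
# import saliency.tf1 as saliency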

Core

Each saliency mask class extends from the CoreSaliency base class. This class contains the following methods:

  • GetMask(x_value, call_model_function, call_model_args=None): Returns a saliency mask with the shape of the non-batched x_value, computed by the saliency technique.
  • GetSmoothedMask(x_value, call_model_function, call_model_args=None, stdev_spread=.15, nsamples=25, magnitude=True): Returns a saliency mask with the shape of the non-batched x_value, smoothed with the SmoothGrad technique.
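
For instance, a minimal sketch of calling both methods on a single (non-batched) image, assuming an image array and a call_model_function as described in the next section:

# Plain mask vs. SmoothGrad-averaged mask from the same method object.
ig = saliency.IntegratedGradients()
mask = ig.GetMask(image, call_model_function, call_model_args=None)
smoothed_mask = ig.GetSmoothedMask(
    image, call_model_function, call_model_args=None,
    stdev_spread=0.15, nsamples=25, magnitude=True)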

The visualization module contains two methods for saliency visualization:

  • VisualizeImageGrayscale(image_3d, percentile): Marginalizes across the absolute value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between 0 and 1.
  • VisualizeImageDiverging(image_3d, percentile): Marginalizes across the value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between -1 and 1, where zero remains unchanged.

If the sign of the value given by the saliency mask is not important, use VisualizeImageGrayscale; otherwise use VisualizeImageDiverging. See the SmoothGrad paper for more details on which visualization method to use.
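
As an illustration, a short sketch of both visualizations applied to a computed mask (mask_3d stands for any 3D mask returned by GetMask); percentile=99 is just a typical clipping choice, not a documented default:

# Unsigned importance: grayscale visualization of per-pixel magnitude.
grayscale_vis = saliency.VisualizeImageGrayscale(mask_3d, percentile=99)

# Signed attribution (e.g. Integrated Gradients): diverging visualization
# that preserves the sign of each contribution.
diverging_vis = saliency.VisualizeImageDiverging(mask_3d, percentile=99)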

call_model_function

call_model_function is how we pass inputs to a given model and receive the outputs necessary to compute saliency masks. Its expected signature and output format are documented in the CoreSaliency base class, as well as separately for each method.
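
A minimal sketch of that contract, modeled on the gradient-based example below (the expected_keys argument tells the function which outputs the chosen method needs; INPUT_OUTPUT_GRADIENTS is the key used in this README, and other keys exist for methods that need convolution-layer values, such as Grad-CAM, so consult the CoreSaliency documentation for the exact names):

def call_model_function(x_value_batched, call_model_args=None, expected_keys=None):
    """Runs the model on a batch and returns the outputs the saliency method asked for.

    x_value_batched: a batch of inputs constructed by the saliency method.
    call_model_args: extra arguments forwarded unchanged from GetMask/GetSmoothedMask,
        e.g. the model object or a target class index.
    expected_keys: the keys the saliency method expects in the returned dict.
    """
    # For purely gradient-based methods it is enough to return input gradients
    # under saliency.INPUT_OUTPUT_GRADIENTS, as in the TF2 example below.
    grads = compute_input_gradients(x_value_batched, call_model_args)  # hypothetical helper
    return {saliency.INPUT_OUTPUT_GRADIENTS: grads}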

Examples

This example iPython notebook showing these techniques is a good starting place.

Here is a condensed example of using IG+SmoothGrad with TensorFlow 2:

import numpy as np
import saliency.core as saliency
import tensorflow as tf

...

# call_model_function construction here: return the gradients of the target
# output with respect to the input batch. `model` and `target_class_idx`
# come from the model setup elided above.
def call_model_function(x_value_batched, call_model_args=None, expected_keys=None):
    x_value_batched = tf.convert_to_tensor(x_value_batched, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x_value_batched)
        output = model(x_value_batched)[:, target_class_idx]
    grads = np.array(tape.gradient(output, x_value_batched))
    return {saliency.INPUT_OUTPUT_GRADIENTS: grads}

...

# Load data.
image = GetImagePNG(...)

# Compute IG+SmoothGrad.
ig_saliency = saliency.IntegratedGradients()
smoothgrad_ig = ig_saliency.GetSmoothedMask(
    image, call_model_function, call_model_args=None)

# Compute a 2D tensor for visualization.
grayscale_visualization = saliency.VisualizeImageGrayscale(
    smoothgrad_ig)

TF1

Each saliency mask class extends from the TF1Saliency base class. This class contains the following methods:

  • __init__(graph, session, y, x): Constructor of the SaliencyMask. This can modify the graph, or sometimes create a new graph. Often this will add nodes to the graph, so this shouldn't be called continuously. y is the output tensor to compute saliency masks with respect to; x is the input tensor with the outermost dimension being batch size.
  • GetMask(x_value, feed_dict): Returns a saliency mask with the shape of the non-batched x_value, computed by the saliency technique.
  • GetSmoothedMask(x_value, feed_dict): Returns a saliency mask with the shape of the non-batched x_value, smoothed with the SmoothGrad technique.

The visualization module contains two visualization methods:

  • VisualizeImageGrayscale(image_3d, percentile): Marginalizes across the absolute value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between 0 and 1.
  • VisualizeImageDiverging(image_3d, percentile): Marginalizes across the value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between -1 and 1, where zero remains unchanged.

If the sign of the value given by the saliency mask is not important, use VisualizeImageGrayscale; otherwise use VisualizeImageDiverging. See the SmoothGrad paper for more details on which visualization method to use.

Examples

This example iPython notebook showing these techniques is a good starting place.

Here is a condensed example of using Guided Backprop with SmoothGrad in TensorFlow 1:

from saliency.tf1 import GuidedBackprop
from saliency.tf1 import VisualizeImageGrayscale
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and placeholders need eager execution disabled.

...
# TensorFlow graph construction here.
y = logits[5]
x = tf.placeholder(...)
...

# Compute guided backprop.
# NOTE: This creates another graph that gets cached, try to avoid creating many
# of these.
guided_backprop_saliency = GuidedBackprop(graph, session, y, x)

...
# Load data.
image = GetImagePNG(...)
...

# Compute the Guided Backprop mask, smoothed with SmoothGrad.
smoothgrad_guided_backprop = guided_backprop_saliency.GetSmoothedMask(
    image, feed_dict={...})

# Compute a 2D tensor for visualization.
grayscale_visualization = VisualizeImageGrayscale(
    smoothgrad_guided_backprop)

Conclusion/Disclaimer

If you have any questions or suggestions for improvements to this library, please contact the owners of the PAIR-code/saliency repository.

This is not an official Google product.

More Repositories

1. facets (Jupyter Notebook, 7,308 stars) - Visualizations for machine learning datasets
2. lit (TypeScript, 3,386 stars) - The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
3. what-if-tool (HTML, 881 stars) - Source code/webpage/demos for the What-If Tool
4. umap-js (JavaScript, 344 stars) - JavaScript implementation of UMAP
5. knowyourdata (CSS, 273 stars) - A tool to help researchers and product teams understand datasets with the goal of improving data quality, and mitigating fairness and bias issues.
6. wordcraft (TypeScript, 208 stars) - ✨✍️ Wordcraft is an AI-powered text editor with an emphasis on short story writing
7. federated-learning (TypeScript, 159 stars) - Federated learning experiment using TensorFlow.js
8. scatter-gl (TypeScript, 157 stars) - Interactive 3D / 2D webgl-accelerated scatter plot point renderer
9. datacardsplaybook (TypeScript, 157 stars) - The Data Cards Playbook helps dataset producers and publishers adopt a people-centered approach to transparency in dataset documentation.
10. understanding-umap (JavaScript, 150 stars) - Understanding the theory behind UMAP
11. interpretability (JavaScript, 109 stars) - PAIR.withgoogle.com and friend's work on interpretability methods
12. ai-explorables (Jupyter Notebook, 51 stars) - https://pair.withgoogle.com/explorables/
13. cococo (TypeScript, 45 stars) - 𝄡 Collaborative Convolutional Counterpoint
14. cam-scroller (JavaScript, 33 stars) - Cam Scroller is an open-source Chrome extension that uses your webcam and deeplearn.js to enable scrolling through webpages using custom gestures that you define.
15. font-explorer (Vue, 31 stars) - Font latent space explorer using tensorflow.js
16. clinical-vis (HTML, 25 stars) - A javascript medical record visualization (https://arxiv.org/abs/1810.05798)
17. megaplot (TypeScript, 19 stars)
18. pair-code.github.io (HTML, 17 stars)
19. depth-maps-art-and-illusions (TypeScript, 16 stars)
20. thehardway (JavaScript, 11 stars) - Supplementary code repository to accompany Tic-Tac-Toe the Hard Way podcast
21. covid19_symptom_dataset (JavaScript, 11 stars)
22. recommendation-rudders (TypeScript, 10 stars)
23. jax-recommenders (Python, 8 stars)
24. farsight (TypeScript, 7 stars) - In situ interactive widgets for responsible AI 🌱
25. book-viz (Jupyter Notebook, 6 stars) - Visualizing multilevel structure in books with sentence embeddings.
26. waterfall-of-meaning (TypeScript, 4 stars)
27. tiny-transformers (Jupyter Notebook, 4 stars)
28. deeplearnjs-legacy-loader (Python, 3 stars) - Deprecated: Legacy TensorFlow model loader for deeplearn.js
29. colormap (JavaScript, 3 stars)
30. ml-vis-experiments (Jupyter Notebook, 1 star)
31. deeplearnjs-docs (TypeScript, 1 star)
32. auto-histograms (Python, 1 star)