  • Stars: 951
  • Rank: 48,061 (Top 1.0 %)
  • Language: Jupyter Notebook
  • License: Apache License 2.0
  • Created over 7 years ago
  • Updated 8 months ago


Repository Details

Framework-agnostic implementation for state-of-the-art saliency methods (XRAI, BlurIG, SmoothGrad, and more).

Saliency Library

Updates

🔴 Now framework-agnostic! (Example core notebook) 🔴

🔗 For further explanation of the methods and more examples of the resulting maps, see our GitHub Pages website 🔗

If upgrading from an older version, update old imports to import saliency.tf1 as saliency. We provide wrappers to make the framework-agnostic version compatible with TF1 models. (Example TF1 notebook)

🔴 Added Performance Information Curve (PIC), a human-independent metric for evaluating the quality of saliency methods. (Example notebook) 🔴

Saliency Methods

This repository contains code for the following saliency techniques:

*Developed by PAIR.

This list is by no means comprehensive. We are accepting pull requests to add new methods!

Evaluation of Saliency Methods

The repository provides an implementation of the Performance Information Curve (PIC), a human-independent metric for evaluating the quality of saliency methods (paper, poster, code, notebook).
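As a rough illustration of the last step of a curve-based metric such as PIC, the sketch below computes the area under a performance-versus-information curve with the trapezoidal rule. It assumes the curve points (normalized information on the x-axis, normalized model performance on the y-axis) have already been computed; how PIC derives those points is not shown here, and `curve_auc` is a hypothetical helper, not part of the library.

```python
from typing import Sequence


def curve_auc(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Trapezoidal area under a curve given by the points (xs[i], ys[i]).

    Hypothetical helper: assumes xs is sorted in ascending order and that
    both axes are already normalized to [0, 1].
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    area = 0.0
    for i in range(1, len(xs)):
        width = xs[i] - xs[i - 1]
        if width < 0:
            raise ValueError("xs must be sorted in ascending order")
        # Average height of the segment times its width.
        area += width * (ys[i] + ys[i - 1]) / 2.0
    return area


# A curve that rises early (performance recovered from little information)
# encloses more area than one that rises late.
early = curve_auc([0.0, 0.2, 0.5, 1.0], [0.0, 0.8, 0.9, 1.0])
late = curve_auc([0.0, 0.2, 0.5, 1.0], [0.0, 0.1, 0.3, 1.0])
```

Under this reading, a saliency method whose top-ranked regions let the model recover its prediction sooner scores a larger area.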

Download

# To install the core subpackage:
pip install saliency

# To install core and tf1 subpackages:
pip install saliency[tf1]

or for the development version:

git clone https://github.com/pair-code/saliency
cd saliency

Usage

The saliency library has two subpackages:

  • core uses a generic call_model_function which can be used with any ML framework.
  • tf1 accepts input/output tensors directly, and sets up the necessary graph operations for each method.

Core

Each saliency mask class extends the CoreSaliency base class, which provides the following methods:

  • GetMask(x_value, call_model_function, call_model_args=None): Returns the saliency mask computed by the technique, with the same shape as the non-batched x_value.
  • GetSmoothedMask(x_value, call_model_function, call_model_args=None, stdev_spread=.15, nsamples=25, magnitude=True): Returns a mask smoothed with the SmoothGrad technique, with the same shape as the non-batched x_value.
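The SmoothGrad smoothing behind GetSmoothedMask can be sketched in NumPy as follows. It mirrors the documented parameters (stdev_spread, nsamples, magnitude) under two assumptions: the noise standard deviation is stdev_spread times the input's value range, and magnitude=True averages squared masks. The names `smoothed_mask` and `get_mask` are illustrative, not the library's API.

```python
import numpy as np


def smoothed_mask(x_value, get_mask, stdev_spread=0.15, nsamples=25,
                  magnitude=True, rng=None):
    """SmoothGrad sketch: average a saliency mask over noisy input copies.

    get_mask is any function mapping a non-batched input to a mask of the
    same shape (for example, a plain gradient).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Assumption: noise scale is a fraction of the input's value range.
    stdev = stdev_spread * (np.max(x_value) - np.min(x_value))
    total = np.zeros_like(x_value, dtype=np.float64)
    for _ in range(nsamples):
        noise = rng.normal(0.0, stdev, size=x_value.shape)
        grad = get_mask(x_value + noise)
        # magnitude=True accumulates squared values, discarding the sign.
        total += grad * grad if magnitude else grad
    return total / nsamples


# Hypothetical linear "model" f(x) = sum(w * x): its gradient is w everywhere,
# so the noise has no effect and the result can be checked by hand.
w = np.array([[1.0, -2.0], [0.5, 3.0]])
image = np.array([[0.1, 0.9], [0.4, 0.6]])
signed = smoothed_mask(image, lambda x: w, magnitude=False)
squared = smoothed_mask(image, lambda x: w, magnitude=True)
```

For real models the gradient changes with each noisy copy, and the averaging is what suppresses the visual noise in plain gradient maps.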

The visualization module contains two methods for saliency visualization:

  • VisualizeImageGrayscale(image_3d, percentile): Marginalizes across the absolute value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between 0 and 1.
  • VisualizeImageDiverging(image_3d, percentile): Marginalizes across the value of each channel to create a 2D single-channel image, and clips the image at the given percentile of the distribution. This method returns a 2D tensor normalized between -1 and 1, where zero remains unchanged.
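A NumPy sketch of the two visualizations, written from the descriptions above rather than copied from the library: channels are summed (absolute values for grayscale, signed values for diverging), the result is clipped at the given percentile, and then rescaled. The helper names, the percentile=99 default, and the diverging scaling (dividing by the percentile of absolute values so that zero stays zero) are assumptions.

```python
import numpy as np


def visualize_grayscale(image_3d, percentile=99):
    """Sum |values| over channels, clip at a percentile, rescale to [0, 1]."""
    image_2d = np.sum(np.abs(image_3d), axis=2)
    vmax = np.percentile(image_2d, percentile)
    vmin = np.min(image_2d)
    if vmax <= vmin:  # constant mask: nothing to highlight
        return np.zeros_like(image_2d)
    return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)


def visualize_diverging(image_3d, percentile=99):
    """Sum signed values over channels, rescale to [-1, 1]; zero stays zero."""
    image_2d = np.sum(image_3d, axis=2)
    span = np.percentile(np.abs(image_2d), percentile)
    if span == 0:
        return np.zeros_like(image_2d)
    return np.clip(image_2d / span, -1, 1)


# A 2x2 mask with two channels per pixel.
mask = np.array([[[1.0, -1.0], [0.0, 0.0]],
                 [[-3.0, -1.0], [2.0, 2.0]]])
gray = visualize_grayscale(mask, percentile=100)
diverging = visualize_diverging(mask, percentile=100)
```

Note how the top-left pixel differs between the two: its channels cancel in the signed sum (0 in the diverging map) but not in the absolute sum (0.5 in the grayscale map). With percentile below 100, the most extreme values saturate at the ends of the range.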

If the sign of the value given by the saliency mask is not important, then use VisualizeImageGrayscale, otherwise use VisualizeImageDiverging. See the SmoothGrad paper for more details on which visualization method to use.

call_model_function

call_model_function is how we pass inputs to a given model and receive the outputs needed to compute saliency masks. Its arguments and expected output format are documented in the CoreSaliency class description, as well as separately for each method.
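To make the contract concrete, here is a minimal, framework-free call_model_function for a hypothetical linear classifier in NumPy. It assumes the function receives a batch of inputs, that call_model_args carries the target class under a key called "class_idx" (a hypothetical name), and that gradients are returned under the INPUT_OUTPUT_GRADIENTS key. The local string constant stands in for saliency.INPUT_OUTPUT_GRADIENTS so the sketch runs without the library installed.

```python
import numpy as np

# Stand-in for saliency.INPUT_OUTPUT_GRADIENTS so this sketch runs without
# the library; use the library's constant in real code.
INPUT_OUTPUT_GRADIENTS = "INPUT_OUTPUT_GRADIENTS"

# Hypothetical linear classifier: logits = flatten(x) @ weights.
rng = np.random.default_rng(0)
weights = rng.normal(size=(2 * 2 * 3, 4))  # 2x2 RGB input, 4 classes


def call_model_function(x_value_batched, call_model_args, expected_keys):
    """Return d(logit of the target class) / d(input) for each batch element."""
    class_idx = call_model_args["class_idx"]  # hypothetical key name
    batch_size = x_value_batched.shape[0]
    # For a linear model, the gradient of one logit is that class's weight
    # column reshaped to the input shape, identical for every example.
    grad = weights[:, class_idx].reshape(x_value_batched.shape[1:])
    grads = np.broadcast_to(grad, (batch_size,) + grad.shape).copy()
    return {INPUT_OUTPUT_GRADIENTS: grads}


batch = np.zeros((5, 2, 2, 3))
out = call_model_function(batch, {"class_idx": 2}, [INPUT_OUTPUT_GRADIENTS])
```

A real implementation would inspect expected_keys to decide which outputs (gradients, layer values) to compute, and would run the framework's own autodiff instead of a closed-form gradient.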

Examples

This example IPython notebook, which demonstrates these techniques, is a good starting place.

Here is a condensed example of using IG+SmoothGrad with TensorFlow 2:

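import numpy as np
import saliency.core as saliency
import tensorflow as tf

...

# call_model_function construction here.
# `model` (a tf.keras.Model returning class logits) is assumed to be defined above.
class_idx_str = 'class_idx_str'

def call_model_function(x_value_batched, call_model_args, expected_keys):
    target_class_idx = call_model_args[class_idx_str]
    images = tf.convert_to_tensor(x_value_batched, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # Inputs are plain tensors, not variables, so they must be watched.
        tape.watch(images)
        output_layer = model(images)[:, target_class_idx]
    grads = np.array(tape.gradient(output_layer, images))
    return {saliency.INPUT_OUTPUT_GRADIENTS: grads}

...

# Load data.
image = GetImagePNG(...)
# Index of the class to explain, e.g. the model's top prediction.
prediction_class = ...

# Compute IG+SmoothGrad.
ig_saliency = saliency.IntegratedGradients()
smoothgrad_ig = ig_saliency.GetSmoothedMask(
    image, call_model_function,
    call_model_args={class_idx_str: prediction_class})

# Compute a 2D tensor for visualization.
grayscale_visualization = saliency.VisualizeImageGrayscale(
    smoothgrad_ig)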
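For intuition about what IntegratedGradients computes, here is a NumPy sketch of the standard Riemann-sum approximation: average the gradient at evenly spaced points on the straight line from a baseline to the input, then multiply by (input - baseline). The zero baseline, the x_steps=25 default, and the `grad_fn` callback are assumptions of this sketch, not a description of the library's internals.

```python
import numpy as np


def integrated_gradients(x_value, grad_fn, x_baseline=None, x_steps=25):
    """Riemann-sum approximation of Integrated Gradients.

    grad_fn maps an input to the gradient of the model output with respect
    to that input. The baseline defaults to all zeros (an assumption).
    """
    if x_baseline is None:
        x_baseline = np.zeros_like(x_value)
    x_diff = x_value - x_baseline
    total = np.zeros_like(x_value, dtype=np.float64)
    # Evaluate gradients at evenly spaced points on the straight-line path.
    for alpha in np.linspace(0.0, 1.0, x_steps):
        total += grad_fn(x_baseline + alpha * x_diff)
    return x_diff * total / x_steps


# Toy model f(x) = sum(x ** 2) with gradient 2x. Its attributions should sum
# to f(x) - f(baseline), the "completeness" property of IG.
x = np.array([1.0, -2.0, 3.0])
attributions = integrated_gradients(x, lambda z: 2.0 * z)
```

In the TF2 example above, GetSmoothedMask wraps this kind of computation in SmoothGrad's noise averaging, calling call_model_function for each gradient evaluation.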

TF1

Each saliency mask class extends the TF1Saliency base class, which provides the following methods:

  • __init__(graph, session, y, x): Constructor of the saliency mask. It may modify the graph or, for some methods, create a new one. Because it often adds nodes to the graph, avoid calling it repeatedly. y is the output tensor with respect to which saliency masks are computed, and x is the input tensor, whose outermost dimension is the batch size.
  • GetMask(x_value, feed_dict): Returns the saliency mask computed by the technique, with the same shape as the non-batched x_value.
  • GetSmoothedMask(x_value, feed_dict): Returns a mask smoothed with the SmoothGrad technique, with the same shape as the non-batched x_value.

The visualization module provides the same two methods described in the Core section, VisualizeImageGrayscale and VisualizeImageDiverging, with the same guidance on when to use each.

Examples

This example IPython notebook, which demonstrates these techniques, is a good starting place.

Another example, using GuidedBackprop with SmoothGrad in TensorFlow 1:

from saliency.tf1 import GuidedBackprop
from saliency.tf1 import VisualizeImageGrayscale
import tensorflow.compat.v1 as tf

...
# TensorFlow graph construction here.
x = tf.placeholder(...)
...
# Logit of class 5 for the first example in the batch.
y = logits[0][5]

# Compute guided backprop.
# NOTE: This creates another graph that gets cached; try to avoid creating
# many of these.
guided_backprop_saliency = GuidedBackprop(graph, session, y, x)

...
# Load data.
image = GetImagePNG(...)
...

# Smooth the guided backprop mask with SmoothGrad.
smoothgrad_guided_backprop = guided_backprop_saliency.GetSmoothedMask(
    image, feed_dict={...})

# Compute a 2D tensor for visualization.
grayscale_visualization = VisualizeImageGrayscale(
    smoothgrad_guided_backprop)
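Guided backpropagation differs from a plain gradient only in how gradients flow back through ReLU units: besides blocking gradients at inactive units, it also zeroes negative gradients. The NumPy sketch below shows this rule on a hypothetical two-layer ReLU network with a scalar output. It illustrates the technique only; it is not the saliency.tf1 implementation, which works on the TensorFlow graph.

```python
import numpy as np


def relu_net_grads(x, w1, w2):
    """Vanilla and guided-backprop gradients of y = w2 . relu(w1 @ x)."""
    pre = w1 @ x      # pre-activation of the hidden ReLU layer
    upstream = w2     # dy/dh for the linear output y = w2 . h
    # Vanilla gradient: a ReLU passes gradient only where its input was positive.
    vanilla = w1.T @ (upstream * (pre > 0))
    # Guided backprop: additionally drop negative upstream gradients.
    guided = w1.T @ (upstream * (pre > 0) * (upstream > 0))
    return vanilla, guided


x = np.array([1.0, 2.0])
w1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]])
w2 = np.array([1.0, -1.0, 2.0])
vanilla, guided = relu_net_grads(x, w1, w2)
```

Here the second hidden unit is active but contributes negatively to y, so it shapes the vanilla gradient and is removed from the guided one. In deeper networks the same rule is applied at every ReLU.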

Conclusion/Disclaimer

If you have any questions or suggestions for improvements to this library, please contact the owners of the PAIR-code/saliency repository.

This is not an official Google product.
