  • Stars: 127
  • Rank: 282,790 (Top 6%)
  • Language: Python
  • License: MIT License
  • Created: about 5 years ago
  • Updated: over 4 years ago



CXPlain


Causal Explanations (CXPlain) is a method for explaining the predictions of any machine-learning model. CXPlain uses explanation models trained with a causal objective to learn to explain machine-learning models and to quantify the uncertainty of its explanations. This repository contains a reference implementation for neural explanation models and several practical examples for different data modalities. Please see the manuscript at https://arxiv.org/abs/1910.12336 (NeurIPS 2019) for a description and experimental evaluation of CXPlain.
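To make the causal objective concrete, here is a minimal NumPy sketch (not the library's internals) of how the per-feature importance targets that an explanation model learns to match can be derived via masking, following the Granger-causal formulation in the paper. The function and argument names are hypothetical, and clipping negative contributions at zero is a simplifying assumption.

import numpy as np

def granger_causal_targets(predict_fn, loss_fn, x, y):
    # predict_fn: the explained model's prediction function.
    # loss_fn:    a per-sample loss, e.g. cross-entropy.
    # x, y:       one sample of shape (1, num_features) and its label.
    base_error = loss_fn(y, predict_fn(x))  # error with all features present
    deltas = np.zeros(x.shape[1])
    for i in range(x.shape[1]):
        x_masked = x.copy()
        x_masked[:, i] = 0.0  # zero masking: simulate removing feature i
        # Increase in error when feature i is removed, clipped at zero
        # (clipping is a simplifying assumption for a valid distribution).
        deltas[i] = max(loss_fn(y, predict_fn(x_masked)) - base_error, 0.0)
    # Normalise into a distribution over features; the explanation model
    # is then trained to match these targets with a KL-divergence objective.
    total = deltas.sum()
    return deltas / total if total > 0 else deltas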

Install

To install the latest release:

$ pip install cxplain

Use

A CXPlain model consists of four main components:

  • The model to be explained, which can be any type of machine-learning model, including black-box models such as neural networks and ensembles.
  • The model builder, which defines the architecture of the explanation model used to explain the explained model.
  • The masking operation, which defines how CXPlain internally simulates the removal of input features from the set of available features.
  • The loss function, which defines how CXPlain measures the change in prediction error incurred by removing an input feature (a regression-loss sketch follows this list).
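The loss should match the task of the explained model. As a hedged sketch, a regression model might be paired with a squared-error loss instead of the categorical cross-entropy used in the classification example below; explained_regression_model is a hypothetical placeholder and the builder hyperparameters are illustrative.

from tensorflow.python.keras.losses import mean_squared_error
from cxplain import MLPModelBuilder, ZeroMasking, CXPlain

explained_regression_model = ...  # hypothetical regression model to explain

# Measure the increase in squared error, rather than a classification
# loss, caused by removing each input feature.
regression_explainer = CXPlain(explained_regression_model,
                               MLPModelBuilder(num_layers=2, num_units=32),
                               ZeroMasking(),
                               mean_squared_error)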

After configuring these four components, you can fit a CXPlain instance to the same training data that was used to train your original model. The CXPlain instance can then explain any prediction of your explained model - even when no labels are available for that sample.

from tensorflow.python.keras.losses import categorical_crossentropy
from cxplain import MLPModelBuilder, ZeroMasking, CXPlain

x_train, y_train, x_test = ...  # Your dataset.
explained_model = ...           # The model you wish to explain.

# Define the explanation model to be used to explain your explained_model.
# Here, we use a neural explanation model with a
# multilayer perceptron (MLP) architecture.
model_builder = MLPModelBuilder(num_layers=2, num_units=64, batch_size=256, learning_rate=0.001)

# Define your masking operation, i.e. how CXPlain internally simulates
# the removal of input features. ZeroMasking is typically a sensible
# default choice for tabular and image data.
masking_operation = ZeroMasking()

# Define the loss used to measure the change in prediction error
# caused by removing each input feature.
loss = categorical_crossentropy

# Build and fit a CXPlain instance.
explainer = CXPlain(explained_model, model_builder, masking_operation, loss)
explainer.fit(x_train, y_train)

# Use the explainer to obtain explanations for the predictions of your explained_model.
attributions = explainer.explain(x_test)
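CXPlain can also quantify the uncertainty of its attributions by fitting a bootstrap ensemble of explanation models, as described in the paper. A minimal sketch along those lines; the num_models and confidence_level arguments are assumptions based on the reference implementation:

# Fit a bootstrap ensemble of explanation models to quantify the
# uncertainty of the attributions (num_models is assumed here).
explainer = CXPlain(explained_model, model_builder, masking_operation, loss,
                    num_models=10)
explainer.fit(x_train, y_train)

# Median attributions plus per-feature confidence intervals at the
# requested confidence level (assumed signature).
attributions, confidence = explainer.predict(x_test, confidence_level=0.80)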

Examples

More practical examples for various input data modalities, including images, text, and tabular data, and for both regression and classification tasks, are provided in the form of Jupyter notebooks in the examples/ directory:

  • MNIST
  • ImageNet
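As a rough, hedged sketch of the kind of setup the MNIST notebook walks through (the actual notebook may use architectures better suited to image data), flattened MNIST pixels can be fed through the same API as above:

import numpy as np
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.losses import categorical_crossentropy
from cxplain import MLPModelBuilder, ZeroMasking, CXPlain

# Illustrative setup: flatten 28x28 images into 784 pixel features.
(x_train, y_train), (x_test, _) = mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
y_train = np.eye(10)[y_train]  # one-hot encode the labels

explained_model = ...  # e.g. a digit classifier trained on (x_train, y_train)

explainer = CXPlain(explained_model,
                    MLPModelBuilder(num_layers=2, num_units=64),
                    ZeroMasking(),
                    categorical_crossentropy)
explainer.fit(x_train, y_train)

# One importance score per pixel for each test image.
attributions = explainer.explain(x_test)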

Cite

Please consider citing if you reference or use our methodology, code, or results in your work:

@inproceedings{schwab2019cxplain,
  title={{CXPlain: Causal Explanations for Model Interpretation under Uncertainty}},
  author={Schwab, Patrick and Karlen, Walter},
  booktitle={{Advances in Neural Information Processing Systems (NeurIPS)}},
  year={2019}
}

License

MIT License

Acknowledgements

This work was partially funded by the Swiss National Science Foundation (SNSF) project No. 167302 within the National Research Program (NRP) 75 "Big Data". We gratefully acknowledge the support of NVIDIA Corporation with the donation of the Titan Xp GPUs used for this research. Patrick Schwab is an affiliated PhD fellow at the Max Planck ETH Center for Learning Systems.

More Repositories

1. perfect_match (Python, 121 stars): ➕➕ Perfect Match is a simple method for learning representations for counterfactual inference with neural networks.

2. drnet (Python, 82 stars): 💉📈 Dose response networks (DRNets) are a method for learning to estimate individual dose-response curves for multiple parametric treatments from observational data using neural networks.

3. ame (Python, 40 stars): 🤖🤖 Attentive Mixtures of Experts (AMEs) are neural network models that learn to output both accurate predictions and estimates of feature importance for individual samples.

4. CovEWS (Python, 17 stars): The COVID-19 Early Warning System (CovEWS) is a real-time early warning system for assessing individual COVID-19 related mortality risk.

5. heart_rhythm_attentive_rnn (Python, 9 stars): ❤️📱 Heart rhythm classification from mobile event recorder data using attentive neural networks.

6. DSMT-Nets (Python, 9 stars): 👓📡 Distantly Supervised Multitask Networks (DSMT-Nets) are a deep-learning approach to semi-supervised learning that utilises distant supervision through many auxiliary tasks.

7. eth_dream_pd_subchallenge1 (Python, 8 stars): 🚶📱 A deep-learning approach to automatically extract digital biomarkers for Parkinson's disease from smartphone accelerometers.

8. ncore (Python, 3 stars)

9. CGE-Piano (C++, 2 stars): A piano scene for the computer graphics class project, using OpenGL, OpenAL, assimp and GLFW.

10. SWPUE3-Plugin (Java, 1 star): A simple plugin pattern demonstration.

11. SWPUE11-ChainOfResponsibility (Java, 1 star): A simple Chain of Responsibility pattern demonstration.

12. SWPUE5-Observer (C++, 1 star): A simple observer pattern demonstration.

13. SWPUE7-Adapter (Java, 1 star): A simple adapter pattern demonstration.

14. SWPUE1-Singleton (C++, 1 star): A simple singleton pattern demonstration.

15. SWPUE4-Composite (C++, 1 star): A simple composite pattern demonstration.