
Fast SHAP value computation for interpreting tree-based models

FastTreeSHAP


The FastTreeSHAP package is based on the paper Fast TreeSHAP: Accelerating SHAP Value Computation for Trees, published in the NeurIPS 2021 XAI4Debugging Workshop. It is a fast implementation of the TreeSHAP algorithm in the SHAP package.

For a more detailed introduction to the FastTreeSHAP package, please check out this blogpost.

Introduction

SHAP (SHapley Additive exPlanation) values are one of the leading tools for interpreting machine learning models. Even though computing SHAP values takes exponential time in general, TreeSHAP takes polynomial time on tree-based models (e.g., decision trees, random forest, gradient boosted trees). While the speedup is significant, TreeSHAP can still dominate the computation time of industry-level machine learning solutions on datasets with millions or more entries.

In the FastTreeSHAP package we implement two new algorithms, FastTreeSHAP v1 and FastTreeSHAP v2, designed to improve the computational efficiency of TreeSHAP for large datasets. We empirically find that FastTreeSHAP v1 is 1.5x faster than TreeSHAP while keeping the memory cost unchanged, and FastTreeSHAP v2 is 2.5x faster than TreeSHAP at the cost of slightly higher memory usage (performance is measured on a single core).

The table below summarizes the time and space complexities of each variant of the TreeSHAP algorithm ($M$ is the number of samples to be explained, $N$ is the number of features, $T$ is the number of trees, $L$ is the maximum number of leaves in any tree, and $D$ is the maximum depth of any tree). Note that the (theoretical) average running time of FastTreeSHAP v1 is reduced to 25% of TreeSHAP.

| TreeSHAP Version | Time Complexity | Space Complexity |
| --- | --- | --- |
| TreeSHAP | $O(MTLD^2)$ | $O(D^2 + N)$ |
| FastTreeSHAP v1 | $O(MTLD^2)$ | $O(D^2 + N)$ |
| FastTreeSHAP v2 (general case) | $O(TL2^DD + MTLD)$ | $O(L2^D)$ |
| FastTreeSHAP v2 (balanced trees) | $O(TL^2D + MTLD)$ | $O(L^2)$ |
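
To get a feel for how these complexity terms compare, the sketch below plugs concrete sizes into the leading terms. It assumes the time complexities reported in the FastTreeSHAP paper: $O(MTLD^2)$ for TreeSHAP and FastTreeSHAP v1, $O(TL2^DD + MTLD)$ for FastTreeSHAP v2 in the general case, and $O(TL^2D + MTLD)$ for balanced trees. The function name and the example sizes are hypothetical, and because constant factors are dropped, the resulting ratio is only an asymptotic indication, not a prediction of measured speedup.

```python
def leading_term_costs(M, T, L, D):
    """Leading-order operation counts with all constant factors dropped.

    M: samples to explain, T: trees, L: max leaves per tree, D: max depth.
    The expressions are the big-O terms only (an assumption of this sketch),
    so the numbers are useful for relative comparison, not for timing.
    """
    return {
        "TreeSHAP": M * T * L * D**2,
        "FastTreeSHAP v1": M * T * L * D**2,
        "FastTreeSHAP v2 (general case)": T * L * 2**D * D + M * T * L * D,
        "FastTreeSHAP v2 (balanced trees)": T * L**2 * D + M * T * L * D,
    }


# Hypothetical sizes: 10,000 samples, 500 trees of depth 8 with 256 leaves each.
costs = leading_term_costs(M=10_000, T=500, L=256, D=8)
ratio = costs["TreeSHAP"] / costs["FastTreeSHAP v2 (general case)"]

for name, value in costs.items():
    print(f"{name:35s} {value:>16,d}")
print(f"TreeSHAP / v2 leading-term ratio: {ratio:.2f}")
```

For large $M$ the pre-computation term of v2 is amortized over all samples, so the ratio approaches $D$. FastTreeSHAP v1 shares its leading term with TreeSHAP, so its gain comes from constant factors that this sketch ignores.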

Performance with Parallel Computing

Parallel computing is fully enabled in the FastTreeSHAP package. By comparison, the SHAP package enables parallel computing only through the "shortcut", which calls the TreeSHAP algorithms embedded in the XGBoost, LightGBM, and CatBoost packages for those three model types.

The table below compares the execution times of FastTreeSHAP v1 and FastTreeSHAP v2 in the FastTreeSHAP package against the TreeSHAP algorithm (or "shortcut") in the SHAP package on two datasets: Adult (binary classification) and Superconductor (regression). All evaluations were run in parallel on all available cores of an Azure Virtual Machine of size Standard_D8_v3 (8 cores and 32GB memory), except for scikit-learn models in the SHAP package. We ran each evaluation on 10,000 samples, and the results were averaged over 3 runs.

| Model | # Trees | Tree Depth | Dataset | SHAP (s) | FastTreeSHAP v1 (s) | Speedup | FastTreeSHAP v2 (s) | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| sklearn random forest | 500 | 8 | Adult | 318.44* | 43.89 | 7.26 | 27.06 | 11.77 |
| sklearn random forest | 500 | 8 | Super | 466.04 | 58.28 | 8.00 | 36.56 | 12.75 |
| sklearn random forest | 500 | 12 | Adult | 2446.12 | 293.75 | 8.33 | 158.93 | 15.39 |
| sklearn random forest | 500 | 12 | Super | 5282.52 | 585.85 | 9.02 | 370.09 | 14.27 |
| XGBoost | 500 | 8 | Adult | 17.35** | 12.31 | 1.41 | 6.53 | 2.66 |
| XGBoost | 500 | 8 | Super | 35.31 | 21.09 | 1.67 | 13.00 | 2.72 |
| XGBoost | 500 | 12 | Adult | 62.19 | 40.31 | 1.54 | 21.34 | 2.91 |
| XGBoost | 500 | 12 | Super | 152.23 | 82.46 | 1.85 | 51.47 | 2.96 |
| LightGBM | 500 | 8 | Adult | 7.64*** | 7.20 | 1.06 | 3.24 | 2.36 |
| LightGBM | 500 | 8 | Super | 8.73 | 7.11 | 1.23 | 3.58 | 2.44 |
| LightGBM | 500 | 12 | Adult | 9.95 | 7.96 | 1.25 | 4.02 | 2.48 |
| LightGBM | 500 | 12 | Super | 14.02 | 11.14 | 1.26 | 4.81 | 2.91 |

* Parallel computing is not enabled in SHAP package for scikit-learn models, thus TreeSHAP algorithm runs on a single core.
** SHAP package calls TreeSHAP algorithm in XGBoost package, which by default enables parallel computing on all cores.
*** SHAP package calls TreeSHAP algorithm in LightGBM package, which by default enables parallel computing on all cores.
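
The Speedup columns appear to be the SHAP time divided by the corresponding FastTreeSHAP time, rounded to two decimals. As a quick check under that assumption, the sketch below recomputes the speedups for three rows copied from the table (the depth-8 Adult rows). The helper name is hypothetical.

```python
# (model, tree depth, dataset) -> (SHAP s, FastTreeSHAP v1 s, FastTreeSHAP v2 s)
# Timings copied from the depth-8 Adult rows of the table above.
timings = {
    ("sklearn random forest", 8, "Adult"): (318.44, 43.89, 27.06),
    ("XGBoost", 8, "Adult"): (17.35, 12.31, 6.53),
    ("LightGBM", 8, "Adult"): (7.64, 7.20, 3.24),
}


def speedups(shap_s, v1_s, v2_s):
    """Return (v1 speedup, v2 speedup) relative to the SHAP baseline time."""
    return round(shap_s / v1_s, 2), round(shap_s / v2_s, 2)


results = {key: speedups(*row) for key, row in timings.items()}
for (model, depth, dataset), (v1, v2) in results.items():
    print(f"{model} (depth {depth}, {dataset}): v1 {v1}x, v2 {v2}x")
```

The large ratios for scikit-learn models combine the algorithmic gain with parallelism, since the SHAP baseline for those rows runs on a single core.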

Installation

FastTreeSHAP package is available on PyPI and can be installed with pip:

pip install fasttreeshap

Installation troubleshooting:

  • On a MacBook, if the error message ld: library not found for -lomp appears, run the following command before installation (Reference):
brew install libomp

Usage

The following screenshot shows a typical use case of FastTreeSHAP on Census Income Data. Note that the usage of FastTreeSHAP is exactly the same as the usage of SHAP, except for four additional arguments in the class TreeExplainer: algorithm, n_jobs, memory_tolerance, and shortcut.

algorithm: This argument specifies the TreeSHAP algorithm used to run FastTreeSHAP. It can take values "v0", "v1", "v2" or "auto", and its default value is "auto":

  • "v0": Original TreeSHAP algorithm in SHAP package.
  • "v1": FastTreeSHAP v1 algorithm proposed in FastTreeSHAP paper.
  • "v2": FastTreeSHAP v2 algorithm proposed in FastTreeSHAP paper.
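  • "auto" (default): Automatic selection between "v0", "v1" and "v2" according to the number of samples to be explained and the constraint on the allocated memory. Specifically, "v1" is always preferred to "v0", and "v2" is preferred to "v1" when the number of samples to be explained is sufficiently large ($M > 2^{D+1}/D$) and the memory constraint is also satisfied ($\min\{(MN + L2^D) \cdot C,\; TL2^D\} \cdot 8\,\text{Byte} < 0.25 \cdot \text{Total Memory}$). Here $M$ is the number of samples to be explained, $N$ the number of features, $T$ the number of trees, $L$ the maximum number of leaves in any tree, $D$ the maximum depth of any tree, and $C$ the number of threads. A more detailed discussion of these criteria can be found in the FastTreeSHAP paper and in the Notes section.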
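
The "auto" rule can be sketched in a few lines. This is an illustration of the stated criteria, not the package's actual selection code. It assumes the sample threshold $M > 2^{D+1}/D$ and the memory bound $\min\{(MN + L2^D) \cdot C,\; TL2^D\} \cdot 8$ bytes below 25% of total memory. The function name, parameter names, and example sizes are hypothetical.

```python
def choose_algorithm(M, N, T, L, D, C, total_memory_bytes):
    """Sketch of the "auto" choice between "v1" and "v2".

    M: samples, N: features, T: trees, L: max leaves, D: max depth, C: threads.
    "v0" is never returned because "v1" is always preferred to it.
    """
    # v2 pays a pre-computation cost, so it needs enough samples to amortize it.
    enough_samples = M > 2 ** (D + 1) / D
    # Cheaper of the two v2 parallel layouts, 8 bytes per stored value
    # (assumed formulas, see the lead-in).
    v2_bytes = min((M * N + L * 2**D) * C, T * L * 2**D) * 8
    memory_ok = v2_bytes < 0.25 * total_memory_bytes
    return "v2" if enough_samples and memory_ok else "v1"


GIB = 1024**3
# Hypothetical workload: 10,000 samples, 14 features, 500 depth-8 trees, 8 threads.
print(choose_algorithm(M=10_000, N=14, T=500, L=256, D=8, C=8,
                       total_memory_bytes=32 * GIB))
# Too few samples for the pre-computation to pay off.
print(choose_algorithm(M=10, N=14, T=500, L=256, D=8, C=8,
                       total_memory_bytes=32 * GIB))
```

Passing a custom memory_tolerance would replace the 0.25 * total memory term with that limit.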

n_jobs: This argument specifies the number of parallel threads used to run FastTreeSHAP. It can take values -1 or a positive integer. Its default value is -1, which means utilizing all available cores in parallel computing.

memory_tolerance: This argument specifies the upper limit of memory allocation (in GB) to run FastTreeSHAP v2. It can take values -1 or a positive number. Its default value is -1, which means allocating a maximum of 0.25 * total memory of the machine to run FastTreeSHAP v2.

shortcut: This argument determines whether to use the TreeSHAP algorithm embedded in the XGBoost, LightGBM, and CatBoost packages directly when computing SHAP values for XGBoost, LightGBM, and CatBoost models and when computing SHAP interaction values for XGBoost models. Its default value is False, which means bypassing the "shortcut" and using the code in the FastTreeSHAP package directly to compute SHAP values for XGBoost, LightGBM, and CatBoost models. Note that shortcut is currently set to True automatically for CatBoost models, as we are still working on the CatBoost component in the FastTreeSHAP package. More details on the usage of "shortcut" can be found in the notebooks Census Income, Superconductor, and Crop Mapping.
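
The four arguments can be passed together as in the minimal sketch below. The wrapper function and its validation are hypothetical conveniences, not part of the package. The only package calls used are fasttreeshap.TreeExplainer with the arguments described above and the shap_values method inherited from the SHAP interface. The import is placed inside the function so the snippet can be loaded without the package installed.

```python
VALID_ALGORITHMS = {"v0", "v1", "v2", "auto"}


def explain_tree_model(model, X, algorithm="auto", n_jobs=-1,
                       memory_tolerance=-1, shortcut=False):
    """Compute SHAP values for a fitted tree model (hypothetical helper)."""
    if algorithm not in VALID_ALGORITHMS:
        raise ValueError(f"algorithm must be one of {sorted(VALID_ALGORITHMS)}")
    if not (n_jobs == -1 or (isinstance(n_jobs, int) and n_jobs > 0)):
        raise ValueError("n_jobs must be -1 or a positive integer")

    import fasttreeshap  # requires `pip install fasttreeshap`

    explainer = fasttreeshap.TreeExplainer(
        model,
        algorithm=algorithm,                # "v0", "v1", "v2" or "auto"
        n_jobs=n_jobs,                      # -1 uses all available cores
        memory_tolerance=memory_tolerance,  # GB; -1 means 0.25 * total memory
        shortcut=shortcut,                  # False bypasses the embedded TreeSHAP
    )
    return explainer.shap_values(X)


# Typical call, assuming `model` is a fitted tree ensemble and `X_test` a
# feature matrix (both hypothetical names):
# shap_values = explain_tree_model(model, X_test, algorithm="v2", n_jobs=4)
```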

[Screenshot: FastTreeSHAP on the Adult (Census Income) dataset, part 1]

The code in the following screenshot was run on all available cores in a Macbook Pro (2.4 GHz 8-Core Intel Core i9 and 32GB Memory). We see that both "v1" and "v2" produce exactly the same SHAP value results as "v0". Meanwhile, "v2" has the shortest execution time, followed by "v1", and then "v0". "auto" selects "v2" as the most appropriate algorithm in this use case as desired. For more detailed comparisons between FastTreeSHAP v1, FastTreeSHAP v2 and the original TreeSHAP, check the notebooks Census Income, Superconductor, and Crop Mapping.

[Screenshot: FastTreeSHAP on the Adult (Census Income) dataset, part 2]
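
A comparison of this kind can be reproduced with a small harness like the one below. It is a sketch under stated assumptions: explain is a hypothetical callable that takes an algorithm name and returns a flat sequence of SHAP values (for example, a flattened result from a TreeExplainer built with that algorithm), and agreement is checked within a small floating-point tolerance rather than bit for bit.

```python
import time


def compare_algorithms(explain, algorithms=("v0", "v1", "v2", "auto"), tol=1e-8):
    """Time each algorithm and compare its output with the first one listed.

    `explain(name)` must return a flat sequence of floats (hypothetical
    interface). Returns a dict mapping each name to its timing and agreement.
    """
    baseline = None
    report = {}
    for name in algorithms:
        start = time.perf_counter()
        values = list(explain(name))
        elapsed = time.perf_counter() - start
        if baseline is None:
            baseline = values  # the first algorithm serves as the reference
        max_diff = max((abs(a - b) for a, b in zip(values, baseline)), default=0.0)
        report[name] = {
            "seconds": elapsed,
            "max_abs_diff": max_diff,
            "matches": len(values) == len(baseline) and max_diff <= tol,
        }
    return report


# Stand-in callable so the harness runs without a trained model.
demo = compare_algorithms(lambda name: [0.1, -0.2, 0.3])
for name, row in demo.items():
    print(name, row["matches"], f"{row['seconds']:.6f}s")
```

With a real model, explain could wrap an explainer call and flatten the resulting array. The timings will then include explainer construction unless that step is moved outside the callable.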

Notes

  • The FastTreeSHAP paper discusses two scenarios in model interpretation use cases: one-time usage (explaining all samples once), and multi-time usage (having a stable model in the backend and receiving new scoring data to be explained on a regular basis). The current version of the FastTreeSHAP package only supports the one-time usage scenario, and we are working on extending it to the multi-time usage scenario with parallel computing. Evaluation results in the FastTreeSHAP paper show that FastTreeSHAP v2 can achieve up to 3x faster explanation in the multi-time usage scenario.
  • The implementation of parallel computing is straightforward for FastTreeSHAP v1 and the original TreeSHAP, where a parallel for-loop is built over all samples. The implementation for FastTreeSHAP v2 is slightly more complicated, and two versions have been implemented. Version 1 builds a parallel for-loop over all trees, which requires $(MN + L2^D) \cdot C \cdot 8\,\text{Byte}$ of memory allocation (each thread has its own matrices to store both SHAP values and pre-computed values). Version 2 builds two consecutive parallel for-loops, over all trees and over all samples respectively, which requires $TL2^D \cdot 8\,\text{Byte}$ of memory allocation (the first parallel for-loop stores pre-computed values across all trees). Here $M$ is the number of samples to be explained, $N$ the number of features, $T$ the number of trees, $L$ the maximum number of leaves in any tree, $D$ the maximum depth of any tree, and $C$ the number of threads. In the FastTreeSHAP package, version 1 is selected for FastTreeSHAP v2 as long as its memory constraint is satisfied. If not, version 2 is selected as an alternative as long as its memory constraint is satisfied. If neither constraint is satisfied, FastTreeSHAP v1, which has a less strict memory constraint, replaces FastTreeSHAP v2.
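
The fallback order in the second note can be sketched as follows. This illustrates the described logic and is not the package's internal code. It assumes memory requirements of $(MN + L2^D) \cdot C \cdot 8$ bytes for version 1 and $TL2^D \cdot 8$ bytes for version 2. The function name, the returned labels, and the example sizes are hypothetical.

```python
def v2_parallel_plan(M, N, T, L, D, C, memory_limit_bytes):
    """Pick a parallel layout for FastTreeSHAP v2, or fall back to v1.

    M: samples, N: features, T: trees, L: max leaves, D: max depth, C: threads.
    Returns (choice, version1_bytes, version2_bytes).
    """
    # Version 1: each thread keeps its own SHAP-value and pre-computation matrices.
    version1_bytes = (M * N + L * 2**D) * C * 8
    # Version 2: pre-computed values are stored once for every tree.
    version2_bytes = T * L * 2**D * 8

    if version1_bytes < memory_limit_bytes:
        choice = "v2 (parallel version 1)"
    elif version2_bytes < memory_limit_bytes:
        choice = "v2 (parallel version 2)"
    else:
        choice = "v1"
    return choice, version1_bytes, version2_bytes


limit = 0.25 * 32 * 1024**3  # default tolerance on a hypothetical 32 GiB machine
for M, N in [(10_000, 14), (10**7, 100)]:
    choice, v1_bytes, v2_bytes = v2_parallel_plan(M, N, T=500, L=256, D=8, C=8,
                                                  memory_limit_bytes=limit)
    print(f"M={M:,} N={N}: {choice} "
          f"(version 1 {v1_bytes / 1e6:.1f} MB, version 2 {v2_bytes / 1e6:.1f} MB)")
```

Version 1 memory grows with the number of samples and threads, while version 2 memory grows with the number of trees, which is why a large batch can push the choice from version 1 to version 2.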

Notebooks

The notebooks below contain more detailed comparisons between FastTreeSHAP v1, FastTreeSHAP v2 and the original TreeSHAP in classification and regression problems using scikit-learn, XGBoost and LightGBM:

  • Census Income
  • Superconductor
  • Crop Mapping

Citation

Please cite FastTreeSHAP in your publications if it helps your research:

@article{yang2021fast,
  title={Fast TreeSHAP: Accelerating SHAP Value Computation for Trees},
  author={Yang, Jilei},
  journal={arXiv preprint arXiv:2109.09847},
  year={2021}
}

License

Copyright (c) LinkedIn Corporation. All rights reserved. Licensed under the BSD 2-Clause License.
