Uplift modeling and causal inference with machine learning algorithms

Disclaimer

This project is stable and being incubated for long-term support. It may contain new experimental code, for which APIs are subject to change.

Causal ML: A Python Package for Uplift Modeling and Causal Inference with ML

Causal ML is a Python package that provides a suite of uplift modeling and causal inference methods using machine learning algorithms based on recent research [1]. It provides a standard interface that allows users to estimate the Conditional Average Treatment Effect (CATE) or Individual Treatment Effect (ITE) from experimental or observational data. Essentially, it estimates the causal impact of an intervention T on an outcome Y for users with observed features X, without strong assumptions on the model form. Typical use cases include:

  • Campaign targeting optimization: An important lever to increase ROI in an advertising campaign is to target the ad to the set of customers who will have a favorable response in a given KPI such as engagement or sales. CATE identifies these customers by estimating the effect of ad exposure on the KPI at the individual level from A/B experiments or historical observational data.

  • Personalized engagement: A company has multiple options for interacting with its customers, such as different product choices for up-selling or different messaging channels for communications. One can use CATE to estimate the heterogeneous treatment effect for each combination of customer and treatment option, which can then drive an optimal personalized recommendation system.
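Both use cases end in a decision rule applied on top of CATE estimates. The sketch below assumes you already have a matrix of per-user CATE estimates with one column per treatment option (each relative to control), for example from one of the learners this package provides. `choose_treatment` is a hypothetical helper, not part of the causalml API, and assigning control whenever no estimated effect is positive is just one simple policy that ignores treatment costs.

```python
import numpy as np


def choose_treatment(cate, treatment_names, control_name="control"):
    """Hypothetical helper: pick, per user, the treatment with the largest estimated CATE.

    cate: array of shape (n_users, n_treatments); column j holds the estimated
    effect of treatment_names[j] relative to control.
    Users whose best estimated effect is not positive are assigned control.
    """
    cate = np.asarray(cate, dtype=float)
    best_idx = cate.argmax(axis=1)                          # best option per user
    best_effect = cate[np.arange(cate.shape[0]), best_idx]  # its estimated effect
    names = np.asarray(treatment_names, dtype=object)[best_idx]
    return np.where(best_effect > 0, names, control_name)


cate = np.array([
    [0.05, 0.02],    # user 0: email has the larger estimated lift
    [-0.01, 0.03],   # user 1: push has the larger estimated lift
    [-0.02, -0.04],  # user 2: no option is estimated to help
])
print(choose_treatment(cate, ["email", "push"]).tolist())
# ['email', 'push', 'control']
```

For campaign targeting with a single treatment, the same rule reduces to a one-column matrix: target users whose estimated CATE is positive.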

The package currently supports the following methods:

  • Tree-based algorithms
    • Uplift tree/random forests on KL divergence, Euclidean Distance, and Chi-Square [2]
    • Uplift tree/random forests on Contextual Treatment Selection [3]
    • Uplift tree/random forests on DDP [4]
    • Uplift tree/random forests on IDDP [5]
    • Interaction Tree [6]
    • Conditional Interaction Tree [7]
    • Causal Tree [8] - Work-in-progress
  • Meta-learner algorithms
    • S-learner [9]
    • T-learner [9]
    • X-learner [9]
    • R-learner [10]
    • Doubly Robust (DR) learner [11]
    • TMLE learner [12]
  • Instrumental variables algorithms
    • 2-Stage Least Squares (2SLS)
    • Doubly Robust (DR) IV [13]
  • Neural-network-based algorithms
    • CEVAE [14]
    • DragonNet [15]
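To make the meta-learner idea concrete, here is a from-scratch sketch of a T-learner [9]: fit one outcome model on treated units and another on control units, then take the difference of their predictions as the CATE. It uses plain NumPy least squares as the base learner only to stay self-contained; causalml's T-learners plug in other regressors (e.g. XGBoost or an MLP) and also report confidence intervals. The helper names and the synthetic data are illustrative assumptions.

```python
import numpy as np


def fit_ols(X, y):
    """Least-squares coefficients for y ~ 1 + X (intercept first)."""
    X1 = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return coef[0] + X @ coef[1:]


def t_learner_cate(X, treatment, y):
    """T-learner: one outcome model per arm, CATE(x) = mu1(x) - mu0(x)."""
    treated = treatment == 1
    coef1 = fit_ols(X[treated], y[treated])    # outcome model on treated units
    coef0 = fit_ols(X[~treated], y[~treated])  # outcome model on control units
    return predict_ols(coef1, X) - predict_ols(coef0, X)


rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
treatment = rng.integers(0, 2, size=n)  # randomized assignment
true_tau = 1.0 + 0.5 * X[:, 0]          # effect varies with the first feature
y = X @ np.array([2.0, -1.0]) + treatment * true_tau + rng.normal(scale=0.1, size=n)

tau_hat = t_learner_cate(X, treatment, y)
print(tau_hat.mean())  # estimated ATE, roughly 1.0 for this simulation
```

An S-learner would instead fit a single model with the treatment indicator as an extra feature; the T-learner's separate models make it easy to capture effects that differ across X.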

Installation

Installation with conda is recommended. Conda environment files for Python 3.7, 3.8, and 3.9 are available in the repository. To use models under the inference.tf module (e.g., DragonNet), the additional tensorflow dependency is required. For detailed instructions, see below.
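Because tensorflow is an optional dependency, code that relies on the inference.tf models may want to check for it before importing them. A minimal sketch using only the standard library's `importlib.util.find_spec`; `has_optional_dependency` is a hypothetical helper, not part of causalml.

```python
import importlib.util


def has_optional_dependency(module_name):
    """Return True if `module_name` can be found, without actually importing it."""
    return importlib.util.find_spec(module_name) is not None


if has_optional_dependency("tensorflow"):
    print("tensorflow found: models under causalml.inference.tf can be used")
else:
    print("tensorflow missing: install causalml[tf] to use models such as DragonNet")
```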

Install using conda:

Install from conda-forge

Directly install from the conda-forge channel using conda.

$ conda install -c conda-forge causalml

Install with the conda virtual environment

This will create a new conda virtual environment named causalml-[tf-]py3x, where x is in [6, 7, 8, 9] (e.g., causalml-py37 or causalml-tf-py38). To change the name of the environment, update the relevant YAML file in envs/.

$ git clone https://github.com/uber/causalml.git
$ cd causalml/envs/
$ conda env create -f environment-py38.yml	# for the virtual environment with Python 3.8 and CausalML
$ conda activate causalml-py38
(causalml-py38)

Install causalml with tensorflow

$ git clone https://github.com/uber/causalml.git
$ cd causalml/envs/
$ conda env create -f environment-tf-py38.yml	# for the virtual environment with Python 3.8, CausalML, and tensorflow
$ conda activate causalml-tf-py38
(causalml-tf-py38) pip install -U numpy	# this step is necessary to fix issue #338 (https://github.com/uber/causalml/issues/338)

Install from PyPI:

$ pip install causalml

Install causalml with tensorflow

$ pip install causalml[tf]
$ pip install -U numpy	# this step is necessary to fix issue #338 (https://github.com/uber/causalml/issues/338)

Install from source:

$ git clone https://github.com/uber/causalml.git
$ cd causalml
$ pip install .

Install from source with tensorflow:

$ pip install .[tf]

Quick Start

Average Treatment Effect Estimation with S, T, X, and R Learners

from causalml.inference.meta import LRSRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor
from causalml.inference.meta import BaseXRegressor
from causalml.inference.meta import BaseRRegressor
from xgboost import XGBRegressor
from causalml.dataset import synthetic_data

y, X, treatment, _, _, e = synthetic_data(mode=1, n=1000, p=5, sigma=1.0)

lr = LRSRegressor()
te, lb, ub = lr.estimate_ate(X, treatment, y)
print('Average Treatment Effect (Linear Regression): {:.2f} ({:.2f}, {:.2f})'.format(te[0], lb[0], ub[0]))

xg = XGBTRegressor(random_state=42)
te, lb, ub = xg.estimate_ate(X, treatment, y)
print('Average Treatment Effect (XGBoost): {:.2f} ({:.2f}, {:.2f})'.format(te[0], lb[0], ub[0]))

nn = MLPTRegressor(hidden_layer_sizes=(10, 10),
                 learning_rate_init=.1,
                 early_stopping=True,
                 random_state=42)
te, lb, ub = nn.estimate_ate(X, treatment, y)
print('Average Treatment Effect (Neural Network (MLP)): {:.2f} ({:.2f}, {:.2f})'.format(te[0], lb[0], ub[0]))

xl = BaseXRegressor(learner=XGBRegressor(random_state=42))
te, lb, ub = xl.estimate_ate(X, treatment, y, e)
print('Average Treatment Effect (BaseXRegressor using XGBoost): {:.2f} ({:.2f}, {:.2f})'.format(te[0], lb[0], ub[0]))

rl = BaseRRegressor(learner=XGBRegressor(random_state=42))
te, lb, ub = rl.estimate_ate(X=X, p=e, treatment=treatment, y=y)
print('Average Treatment Effect (BaseRRegressor using XGBoost): {:.2f} ({:.2f}, {:.2f})'.format(te[0], lb[0], ub[0]))

See the Meta-learner example notebook for details.
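As a rough mental model of what an S-learner with linear regression (the idea behind LRSRegressor) computes, here is a NumPy-only sketch: regress y on an intercept, the treatment indicator, and X, then read the ATE off the treatment coefficient and form a normal-approximation interval from classical OLS standard errors. This is an illustrative assumption about the approach, not causalml's implementation, and `s_learner_ols_ate` is a hypothetical name.

```python
import numpy as np


def s_learner_ols_ate(X, treatment, y, z=1.96):
    """S-learner with a linear model: regress y on [1, treatment, X].

    Without treatment-by-feature interactions, the coefficient on the
    treatment column is the estimated ATE. The interval uses classical
    OLS standard errors and a normal approximation (z = 1.96 for ~95%).
    """
    n = len(y)
    design = np.column_stack([np.ones(n), treatment, X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    resid = y - design @ coef
    dof = n - design.shape[1]
    sigma2 = resid @ resid / dof                      # residual variance
    cov = sigma2 * np.linalg.inv(design.T @ design)   # OLS covariance of coef
    te, se = coef[1], np.sqrt(cov[1, 1])
    return te, te - z * se, te + z * se


rng = np.random.default_rng(42)
n = 1000
X = rng.normal(size=(n, 5))
treatment = rng.integers(0, 2, size=n)
y = X.sum(axis=1) + 0.5 * treatment + rng.normal(size=n)  # simulated ATE = 0.5

te, lb, ub = s_learner_ols_ate(X, treatment, y)
print('Average Treatment Effect (OLS S-learner sketch): {:.2f} ({:.2f}, {:.2f})'.format(te, lb, ub))
```

Because a linear S-learner forces a single treatment coefficient, it can only recover an average effect; the T-, X-, and R-learners used above relax that restriction.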

Interpretable Causal ML

Causal ML provides the following methods to interpret trained treatment effect models:

Meta Learner Feature Importances

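import numpy as np
from lightgbm import LGBMRegressor
from sklearn.ensemble import RandomForestRegressor
from causalml.inference.meta import BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor
from causalml.dataset.regression import synthetic_data

# Load synthetic data
y, X, treatment, tau, b, e = synthetic_data(mode=1, n=10000, p=25, sigma=0.5)
w_multi = np.array(['treatment_A' if x==1 else 'control' for x in treatment]) # customize treatment/control names
feature_names = ['feature_{}'.format(i) for i in range(X.shape[1])]  # names used to label importances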

slearner = BaseSRegressor(LGBMRegressor(), control_name='control')
slearner.estimate_ate(X, w_multi, y)
slearner_tau = slearner.fit_predict(X, w_multi, y)

model_tau_feature = RandomForestRegressor()  # specify model for model_tau_feature

slearner.get_importance(X=X, tau=slearner_tau, model_tau_feature=model_tau_feature,
                        normalize=True, method='auto', features=feature_names)

# Using the feature_importances_ attribute of the model fit to tau (an LGBMRegressor by default when model_tau_feature is not given)
slearner.plot_importance(X=X, tau=slearner_tau, normalize=True, method='auto')

# Using eli5's PermutationImportance
slearner.plot_importance(X=X, tau=slearner_tau, normalize=True, method='permutation')

# Using SHAP
shap_slearner = slearner.get_shap_values(X=X, tau=slearner_tau)

# Plot shap values without specifying shap_dict
slearner.plot_shap_values(X=X, tau=slearner_tau)

# Plot shap values WITH specifying shap_dict
slearner.plot_shap_values(X=X, shap_dict=shap_slearner)

# interaction_idx set to 'auto' (searches for feature with greatest approximate interaction)
slearner.plot_shap_dependence(treatment_group='treatment_A',
                              feature_idx=1,
                              X=X,
                              tau=slearner_tau,
                              interaction_idx='auto')

See the feature interpretations example notebook for details.
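The importance methods above work with a model (model_tau_feature) trained to predict the estimated tau from X. The permutation idea can be sketched without any library: shuffle one feature at a time and measure how much worse the tau model's predictions become. The NumPy code below is a hypothetical from-scratch illustration in which a linear model stands in for model_tau_feature; it is not how causalml computes importances internally.

```python
import numpy as np


def permutation_importance_tau(predict, X, tau, n_repeats=5, seed=0):
    """Increase in MSE of `predict` against tau after shuffling each column of X.

    Larger values mean the tau model relies more on that feature.
    """
    rng = np.random.default_rng(seed)
    base_mse = np.mean((predict(X) - tau) ** 2)
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        for _ in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break feature j's link to tau
            importances[j] += np.mean((predict(X_perm) - tau) ** 2) - base_mse
    return importances / n_repeats


# Toy setup: tau depends only on feature 0; a linear model is fit to tau.
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
tau = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=500)
coef, *_ = np.linalg.lstsq(np.column_stack([np.ones(500), X]), tau, rcond=None)


def predict(A):
    return coef[0] + A @ coef[1:]


imp = permutation_importance_tau(predict, X, tau)
print(np.round(imp, 2))  # feature 0 should dominate
```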

Uplift Tree Visualization

from IPython.display import Image
from causalml.dataset import make_uplift_classification
from causalml.inference.tree import UpliftTreeClassifier, UpliftRandomForestClassifier
from causalml.inference.tree import uplift_tree_string, uplift_tree_plot

# Synthetic uplift data: a DataFrame plus the list of feature column names
df, features = make_uplift_classification()

uplift_model = UpliftTreeClassifier(max_depth=5, min_samples_leaf=200, min_samples_treatment=50,
                                    n_reg=100, evaluationFunction='KL', control_name='control')

uplift_model.fit(df[features].values,
                 treatment=df['treatment_group_key'].values,
                 y=df['conversion'].values)

graph = uplift_tree_plot(uplift_model.fitted_uplift_tree, features)
Image(graph.create_png())

See the Uplift Tree visualization example notebook for details.
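With evaluationFunction='KL', splits are scored by how much they increase the divergence between the treatment and control outcome distributions. A simplified sketch of that criterion for a binary outcome and a single treatment group: the KL divergence between treatment and control conversion rates in each child, weighted by child size, minus the divergence in the parent. The helper names are hypothetical, and the sketch deliberately omits refinements that the library's implementation may apply, such as a normalization factor and the n_reg-based regularization.

```python
import math


def kl_binary(p, q, eps=1e-6):
    """KL divergence between Bernoulli(p) and Bernoulli(q), clipped away from 0 and 1."""
    p = min(max(p, eps), 1 - eps)
    q = min(max(q, eps), 1 - eps)
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def node_divergence(conv_t, n_t, conv_c, n_c):
    """Divergence between treatment and control conversion rates in one node."""
    return kl_binary(conv_t / n_t, conv_c / n_c)


def kl_split_gain(left, right):
    """Size-weighted child divergence minus parent divergence.

    Each child is a tuple (conversions_treated, n_treated, conversions_control, n_control).
    """
    parent = tuple(l + r for l, r in zip(left, right))
    n_left = left[1] + left[3]
    n_right = right[1] + right[3]
    n = n_left + n_right
    children = (n_left / n) * node_divergence(*left) + (n_right / n) * node_divergence(*right)
    return children - node_divergence(*parent)


# Left child: treatment lifts conversion (30% vs 15%); right child: no lift (10% vs 10%).
left = (60, 200, 30, 200)
right = (20, 200, 20, 200)
print(kl_split_gain(left, right))  # positive: the split separates responders
```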

Contributing

We welcome community contributors to the project. Before you start, please read our code of conduct and check out the contributing guidelines.

Versioning

We document versions and changes in our changelog.

License

This project is licensed under the Apache 2.0 License - see the LICENSE file for details.

References

Documentation

Conference Talks and Publications by CausalML Team

Citation

To cite CausalML in publications, you can refer to the following sources:

Whitepaper: CausalML: Python Package for Causal Machine Learning

Bibtex:

@misc{chen2020causalml,
    title={CausalML: Python Package for Causal Machine Learning},
    author={Huigang Chen and Totte Harinen and Jeong-Yoon Lee and Mike Yung and Zhenyu Zhao},
    year={2020},
    eprint={2002.11631},
    archivePrefix={arXiv},
    primaryClass={cs.CY}
}

Literature

  1. Chen, Huigang, Totte Harinen, Jeong-Yoon Lee, Mike Yung, and Zhenyu Zhao. "CausalML: Python package for causal machine learning." arXiv preprint arXiv:2002.11631 (2020).
  2. Radcliffe, Nicholas J., and Patrick D. Surry. "Real-world uplift modelling with significance-based uplift trees." White Paper TR-2011-1, Stochastic Solutions (2011): 1-33.
  3. Zhao, Yan, Xiao Fang, and David Simchi-Levi. "Uplift modeling with multiple treatments and general response types." Proceedings of the 2017 SIAM International Conference on Data Mining. Society for Industrial and Applied Mathematics, 2017.
  4. Hansotia, Behram, and Brad Rukstales. "Incremental value modeling." Journal of Interactive Marketing 16.3 (2002): 35-46.
  5. Rößler, Jannik, Richard Guse, and Detlef Schoder. "The Best of Two Worlds: Using Recent Advances from Uplift Modeling and Heterogeneous Treatment Effects to Optimize Targeting Policies." International Conference on Information Systems (2022).
  6. Su, Xiaogang, et al. "Subgroup analysis via recursive partitioning." Journal of Machine Learning Research 10.2 (2009).
  7. Su, Xiaogang, et al. "Facilitating score and causal inference trees for large observational studies." Journal of Machine Learning Research 13 (2012): 2955.
  8. Athey, Susan, and Guido Imbens. "Recursive partitioning for heterogeneous causal effects." Proceedings of the National Academy of Sciences 113.27 (2016): 7353-7360.
  9. Künzel, Sören R., et al. "Metalearners for estimating heterogeneous treatment effects using machine learning." Proceedings of the National Academy of Sciences 116.10 (2019): 4156-4165.
  10. Nie, Xinkun, and Stefan Wager. "Quasi-oracle estimation of heterogeneous treatment effects." arXiv preprint arXiv:1712.04912 (2017).
  11. Bang, Heejung, and James M. Robins. "Doubly robust estimation in missing data and causal inference models." Biometrics 61.4 (2005): 962-973.
  12. Van Der Laan, Mark J., and Daniel Rubin. "Targeted maximum likelihood learning." The International Journal of Biostatistics 2.1 (2006).
  13. Kennedy, Edward H. "Optimal doubly robust estimation of heterogeneous causal effects." arXiv preprint arXiv:2004.14497 (2020).
  14. Louizos, Christos, et al. "Causal effect inference with deep latent-variable models." arXiv preprint arXiv:1705.08821 (2017).
  15. Shi, Claudia, David M. Blei, and Victor Veitch. "Adapting neural networks for the estimation of treatment effects." 33rd Conference on Neural Information Processing Systems (NeurIPS 2019), 2019.
  16. Zhao, Zhenyu, Yumin Zhang, Totte Harinen, and Mike Yung. "Feature Selection Methods for Uplift Modeling." arXiv preprint arXiv:2005.03447 (2020).
  17. Zhao, Zhenyu, and Totte Harinen. "Uplift modeling for multiple treatments with cost optimization." In 2019 IEEE International Conference on Data Science and Advanced Analytics (DSAA), pp. 422-431. IEEE, 2019.

Related projects

  • uplift: uplift models in R
  • grf: generalized random forests that include heterogeneous treatment effect estimation in R
  • rlearner: An R package that implements the R-learner
  • DoWhy: Causal inference in Python based on Judea Pearl's do-calculus
  • EconML: A Python package that implements heterogeneous treatment effect estimators from econometrics and machine learning methods
