Amortized Causal Discovery
This repo contains the official PyTorch implementation of:
Sindy Löwe*, David Madras*, Richard Zemel, Max Welling - Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data
With Amortized Causal Discovery we learn to infer causal relations from samples with different underlying causal graphs but shared dynamics. This enables us to generalize across samples and thus improve our performance with increasing training data size.
*equal contribution
What is Amortized Causal Discovery?
With Amortized Causal Discovery, we separate causal relation prediction from dynamics modelling. Our amortized encoder learns to infer causal relations across samples with different underlying graphs. Our decoder learns to model the shared dynamics of the predicted relations.
This separation allows us to train a joint model for samples with different underlying causal graphs. This is in contrast to previous approaches, which need to refit a new model whenever they encounter samples with a different underlying causal graph.
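The split between encoder and decoder can be sketched in a few lines of plain Python. This is only a toy illustration, not the code in this repo: the linear dynamics, the names `simulate`, `encode`, `decode` and `prediction_loss`, and all constants are assumptions made for the example. A hand-written lagged-correlation heuristic stands in for the learned encoder (it will not recover graphs reliably), and the decoder parameters are simply set to the true dynamics instead of being trained.

```python
import random

# Assumed toy dynamics shared by every sample (hypothetical values).
SELF_WEIGHT = 0.5
COUPLING = 0.8


def simulate(graph, steps, rng, noise=0.1):
    """Roll out a toy linear system; graph[i][j] = 1 means i influences j."""
    n = len(graph)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    series = [x]
    for _ in range(steps):
        x = [
            SELF_WEIGHT * x[j]
            + COUPLING * sum(graph[i][j] * x[i] for i in range(n))
            + rng.gauss(0.0, noise)
            for j in range(n)
        ]
        series.append(x)
    return series


def correlation(a, b):
    """Pearson correlation of two equally long lists (0.0 if degenerate)."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    denom = (var_a * var_b) ** 0.5
    return cov / denom if denom > 0 else 0.0


def encode(series, params):
    """Encoder stand-in: one shared rule maps any sample to a predicted graph."""
    n = len(series[0])
    graph = [[0] * n for _ in range(n)]
    for i in range(n):
        source = [state[i] for state in series[:-1]]
        for j in range(n):
            if i == j:
                continue
            target = [state[j] for state in series[1:]]
            # Crude heuristic: strong lagged correlation counts as an edge i -> j.
            if abs(correlation(source, target)) > params["threshold"]:
                graph[i][j] = 1
    return graph


def decode(x, graph, params):
    """Decoder: predict the next state from the current state and a graph."""
    n = len(x)
    return [
        params["self_weight"] * x[j]
        + params["coupling"] * sum(graph[i][j] * x[i] for i in range(n))
        for j in range(n)
    ]


def prediction_loss(series, graph, params):
    """Mean squared one-step prediction error of the decoder on one sample."""
    errors = [
        (pred - target) ** 2
        for x, x_next in zip(series, series[1:])
        for pred, target in zip(decode(x, graph, params), x_next)
    ]
    return sum(errors) / len(errors)


if __name__ == "__main__":
    rng = random.Random(0)
    # One parameter set is reused for every sample, whatever its graph.
    params = {"threshold": 0.5, "self_weight": SELF_WEIGHT, "coupling": COUPLING}
    chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
    fork = [[0, 1, 1], [0, 0, 0], [0, 0, 0]]
    for name, true_graph in [("chain", chain), ("fork", fork)]:
        sample = simulate(true_graph, steps=200, rng=rng)
        predicted = encode(sample, params)
        print(name, predicted, prediction_loss(sample, predicted, params))
```

The point of the sketch is the interface: `encode` and `decode` share one set of parameters across samples with different graphs, and the decoder's prediction error is the signal that a training loop would minimise.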
What we found exciting is that this allows us to achieve tremendous improvements in causal inference performance with increasing training data size. Amortized Causal Discovery (ACD) outperforms previous causal discovery approaches with as few as 50 training samples; with 50,000 samples, it outperforms them by more than 30 percentage points.
How to run the code
Dependencies
- Set up the conda environment ACD by running:
bash setup_dependencies.sh
If you want to make use of your GPU, you might have to install a CUDA-enabled PyTorch version manually. Use the appropriate command provided here to achieve this.
- Don't forget to activate the environment and cd into the codebase directory when playing with the code later on:
source activate ACD
cd codebase
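To check whether the PyTorch build in the activated environment can actually see a GPU, a small helper like the one below may be useful. The function name `pick_device` is hypothetical and not part of this codebase; the only PyTorch call used is `torch.cuda.is_available()`.

```python
def pick_device():
    """Return "cuda" if a CUDA-enabled PyTorch build can see a GPU, else "cpu"."""
    try:
        # Imported lazily so the check also works when PyTorch is not installed.
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(pick_device())
```

If this prints `cpu` on a machine with a GPU, the installed PyTorch version is probably a CPU-only build and needs to be replaced manually as described above.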
Datasets
- To generate the particles with springs dataset from our paper, run
python -m data.generate_dataset
- To generate a particles dataset with varying latent temperature, run
python -m data.generate_dataset --temperature_dist --temperature_alpha 2 --temperature_num_cats 3
- To generate the Kuramoto dataset from our paper, run
python -m data.generate_ODE_dataset
- The Netsim dataset is available here
Experiments
- Run the Springs experiment with
python -m train --suffix _springs5
the Kuramoto experiment with
python -m train --suffix _kuramoto5 --encoder cnn
and the Netsim experiment with
python -m train --suffix netsim
- To run the experiment with an unobserved temperature variable, run
python -m train --suffix _springs5 --encoder cnn --decoder sim --global_temp --load_temperatures
- To run the experiment with an unobserved time-series, run
python -m train --suffix _springs5 --unobserved 1
- View all possible command-line options by running
python -m train --help
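To run the three main experiments back to back, the commands above can be wrapped in a short driver script. This is a sketch under a few assumptions: the names `EXPERIMENTS`, `build_command` and `run_all` are made up for this example, the script has to be started from the codebase directory, and it only uses the flags listed above.

```python
import shlex
import subprocess
import sys

# Flags copied from the experiment commands in this README.
EXPERIMENTS = {
    "springs": ["--suffix", "_springs5"],
    "kuramoto": ["--suffix", "_kuramoto5", "--encoder", "cnn"],
    "netsim": ["--suffix", "netsim"],
}


def build_command(name):
    """Build the `python -m train ...` command for one named experiment."""
    return [sys.executable, "-m", "train", *EXPERIMENTS[name]]


def run_all(dry_run=True):
    """Print every command; execute them one after another unless dry_run is set."""
    commands = [build_command(name) for name in EXPERIMENTS]
    for command in commands:
        print(shlex.join(command))
        if not dry_run:
            # check=True stops the sweep as soon as one experiment fails.
            subprocess.run(command, check=True)
    return commands


if __name__ == "__main__":
    run_all(dry_run=True)
```

With `dry_run=True` the script only prints the commands, which is a cheap way to verify them before starting long training runs.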
Cite
Please cite our paper if you use this code in your own work:
@article{lowe2022amortized,
title={Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data},
author={L{\"o}we, Sindy and Madras, David and Zemel, Richard and Welling, Max},
journal={Causal Learning and Reasoning (CLeaR)},
year={2022}
}
Acknowledgements
The Robert Bosch GmbH is acknowledged for financial support.