  • Language: Python
  • License: MIT License
  • Created about 2 years ago, updated 4 months ago


[ICML 2023] The official implementation of the paper "TabDDPM: Modelling Tabular Data with Diffusion Models"

TabDDPM: Modelling Tabular Data with Diffusion Models

This is the official code for our paper "TabDDPM: Modelling Tabular Data with Diffusion Models" (paper)

Set up the environment

  1. Install conda (just to manage the env).
  2. Run the following commands:
    export REPO_DIR=/path/to/the/code
    cd $REPO_DIR
    
    conda create -n tddpm python=3.9.7
    conda activate tddpm
    
    pip install torch==1.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirements.txt
    
    # if the following commands do not succeed, update conda
    conda env config vars set PYTHONPATH=${PYTHONPATH}:${REPO_DIR}
    conda env config vars set PROJECT_DIR=${REPO_DIR}
    
    conda deactivate
    conda activate tddpm
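After reactivating the environment, you may want to confirm that the two variables set above actually took effect. The following is a minimal sketch, not a script from the repository: `check_tddpm_env` is a hypothetical helper, and it assumes only what the commands above configure (PROJECT_DIR pointing at the repository root, and the root added to PYTHONPATH).

```python
import os
import sys


def check_tddpm_env(repo_dir: str) -> list[str]:
    """Return a list of problems with the conda env setup (empty list = OK).

    Hypothetical helper, not part of the repository. It checks the two
    variables configured with `conda env config vars set` in step 2.
    """
    problems = []
    repo_dir = os.path.abspath(repo_dir)

    project_dir = os.environ.get("PROJECT_DIR")
    if project_dir is None:
        problems.append("PROJECT_DIR is not set")
    elif os.path.abspath(project_dir) != repo_dir:
        problems.append(f"PROJECT_DIR={project_dir!r} does not point to {repo_dir!r}")

    # Entries from PYTHONPATH are added to sys.path at interpreter start-up.
    # Empty entries (meaning "current directory") are skipped here.
    search_path = {os.path.abspath(p) for p in sys.path if p}
    if repo_dir not in search_path:
        problems.append(f"{repo_dir!r} is not on sys.path (check PYTHONPATH)")

    return problems


if __name__ == "__main__":
    # REPO_DIR is only exported in the shell session from step 2; fall back to cwd.
    issues = check_tddpm_env(os.environ.get("REPO_DIR", "."))
    print("environment looks OK" if not issues else "\n".join(issues))
```

Run it with `python` from inside the activated tddpm environment; an empty problem list means both variables are in place.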

Running the experiments

Here we describe the information necessary for reproducing the experimental results.
Use agg_results.ipynb to print the results for all datasets and all methods.

Datasets

We provide the datasets used in the paper together with our train/val/test splits (link below). We do not impose additional restrictions on the original dataset licenses; the sources of the data are listed in the paper appendix.

You can download the datasets with the following commands:

conda activate tddpm
cd $PROJECT_DIR
wget "https://www.dropbox.com/s/rpckvcs3vx7j605/data.tar?dl=0" -O data.tar
tar -xvf data.tar
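If wget or tar is not available (for example on Windows), the same two steps can be done with Python's standard library. This is a sketch, not part of the repository: `download` and `extract` are hypothetical helper names, and the URL is copied unchanged from the wget command above. The path check in `extract` is an extra safety measure we added; the original commands do not include it.

```python
import tarfile
import urllib.request
from pathlib import Path

# Copied from the wget command above.
DATA_URL = "https://www.dropbox.com/s/rpckvcs3vx7j605/data.tar?dl=0"


def download(url: str, dest: Path) -> Path:
    """Download `url` to `dest` (the equivalent of `wget URL -O dest`)."""
    urllib.request.urlretrieve(url, dest)
    return dest


def extract(tar_path: Path, out_dir: Path) -> list[str]:
    """Extract `tar_path` into `out_dir` (the equivalent of `tar -xvf`).

    Refuses archives whose member names would land outside `out_dir`
    (e.g. names containing "../"). Returns the extracted member names.
    """
    out_dir = out_dir.resolve()
    with tarfile.open(tar_path) as tar:
        members = tar.getmembers()
        for member in members:
            target = (out_dir / member.name).resolve()
            if target != out_dir and out_dir not in target.parents:
                raise ValueError(f"unsafe path in archive: {member.name}")
        tar.extractall(out_dir, members=members)
    return [member.name for member in members]
```

From PROJECT_DIR, `extract(download(DATA_URL, Path("data.tar")), Path("."))` mirrors the two shell commands.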

File structure

tab-ddpm/ -- implementation of the proposed method
tuned_models/ -- tuned hyperparameters of the evaluation models (CatBoost or MLP)

All main scripts are in the scripts/ folder:

  • scripts/pipeline.py -- train, sample and evaluate TabDDPM using a given config
  • scripts/tune_ddpm.py -- tune hyperparameters of TabDDPM
  • scripts/eval_[catboost|mlp|simple].py -- evaluate synthetic data using a tuned evaluation model or simple models
  • scripts/eval_seeds.py -- evaluate using multiple sampling seeds and multiple evaluation seeds
  • scripts/eval_seeds_simple.py -- evaluate using multiple sampling seeds and multiple evaluation seeds (for simple models)
  • scripts/tune_evaluation_model.py -- tune hyperparameters of eval model (CatBoost or MLP)
  • scripts/resample_privacy.py -- privacy calculation

Experiments folder (exp/):

  • All results and synthetic data are stored in exp/[ds_name]/[exp_name]/ folder
  • exp/[ds_name]/config.toml is a base config for tuning TabDDPM
  • exp/[ds_name]/eval_[catboost|mlp].json stores results of evaluation (scripts/eval_seeds.py)
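The layout above can be walked programmatically, for example to see which experiments already exist before running evaluation. This is a minimal sketch with a hypothetical helper, `list_experiments`. It assumes an experiment folder is any subfolder of exp/[ds_name]/ that contains a config.toml, as exp/churn2/ddpm_cb_best/config.toml does in the examples below. The base config exp/[ds_name]/config.toml is a file rather than a folder, so it is not reported as an experiment.

```python
from pathlib import Path


def list_experiments(exp_root: Path) -> dict[str, list[str]]:
    """Map each dataset folder under exp/ to its experiment folder names.

    Hypothetical helper: an experiment folder is assumed to be any
    exp/[ds_name]/[exp_name]/ directory that contains a config.toml.
    """
    experiments = {}
    for ds_dir in sorted(p for p in exp_root.iterdir() if p.is_dir()):
        experiments[ds_dir.name] = sorted(
            e.name
            for e in ds_dir.iterdir()
            if e.is_dir() and (e / "config.toml").is_file()
        )
    return experiments
```

For instance, `list_experiments(Path(os.environ["PROJECT_DIR"]) / "exp")` lists every dataset together with its experiment folders.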

To understand the structure of the config.toml file, read CONFIG_DESCRIPTION.md.
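When comparing a config.toml against CONFIG_DESCRIPTION.md, it can help to print the file's section/key tree. This sketch assumes nothing about the actual keys; `load_config` and `outline` are hypothetical helpers. It uses the standard-library `tomllib` (Python 3.11+). Because the environment above pins Python 3.9.7, it falls back to the `tomli` backport, which we assume you install separately; it may not be in requirements.txt.

```python
from pathlib import Path

try:
    import tomllib  # standard library since Python 3.11
except ModuleNotFoundError:  # e.g. the Python 3.9.7 env above
    import tomli as tomllib  # assumption: installed with `pip install tomli`


def load_config(path: Path) -> dict:
    """Parse a config.toml into nested dicts (TOML tables become dicts)."""
    with open(path, "rb") as f:  # tomllib/tomli require a binary file handle
        return tomllib.load(f)


def outline(config: dict, indent: int = 0) -> list[str]:
    """Return one line per key; nested tables appear as [name], indented."""
    lines = []
    pad = " " * indent
    for key, value in config.items():
        if isinstance(value, dict):
            lines.append(f"{pad}[{key}]")
            lines.extend(outline(value, indent + 2))
        else:
            lines.append(f"{pad}{key} = {value!r}")
    return lines
```

For example, `print("\n".join(outline(load_config(Path("exp/churn2/config.toml")))))` shows the base tuning config for churn2.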

Baselines:

Examples

Run TabDDPM tuning.

Template and example (--eval_seeds is optional):

python scripts/tune_ddpm.py [ds_name] [train_size] synthetic [catboost|mlp] [exp_name] --eval_seeds
python scripts/tune_ddpm.py churn2 6500 synthetic catboost ddpm_tune --eval_seeds
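To launch tuning from Python (for example, in a loop over datasets), you can build the argument list from the template and validate it before starting a long run. This is a sketch with a hypothetical helper, `tune_ddpm_cmd`. It does not wrap any repository API; it only reproduces the positional order shown in the template.

```python
import sys

EVAL_MODELS = ("catboost", "mlp")


def tune_ddpm_cmd(ds_name: str, train_size: int, eval_model: str,
                  exp_name: str, eval_seeds: bool = False) -> list[str]:
    """Build the argv for scripts/tune_ddpm.py following the template above.

    Hypothetical helper; argument order mirrors
    tune_ddpm.py [ds_name] [train_size] synthetic [catboost|mlp] [exp_name] --eval_seeds
    """
    if eval_model not in EVAL_MODELS:
        raise ValueError(f"eval_model must be one of {EVAL_MODELS}, got {eval_model!r}")
    if train_size <= 0:
        raise ValueError("train_size must be positive")
    cmd = [sys.executable, "scripts/tune_ddpm.py", ds_name, str(train_size),
           "synthetic", eval_model, exp_name]
    if eval_seeds:
        cmd.append("--eval_seeds")  # optional flag, per the template
    return cmd
```

Pass the result to `subprocess.run(cmd, cwd=os.environ["PROJECT_DIR"], check=True)` so the relative script path resolves from the project root. Returning a list instead of a shell string avoids quoting problems.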

Run TabDDPM pipeline.

Template and example (--train, --sample, --eval are optional):

python scripts/pipeline.py --config [path_to_your_config] --train --sample --eval
python scripts/pipeline.py --config exp/churn2/ddpm_cb_best/config.toml --train --sample
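Since the three stages are independent flags, a small wrapper can make the selection explicit. This is a sketch using a hypothetical helper, `pipeline_cmd`. Rejecting a call with no stage selected is our own choice; the README does not say what pipeline.py does in that case.

```python
import sys


def pipeline_cmd(config: str, train: bool = False, sample: bool = False,
                 evaluate: bool = False) -> list[str]:
    """Build the argv for scripts/pipeline.py with the chosen stages.

    Hypothetical helper. --train, --sample and --eval are emitted in the
    same order as the template; at least one must be selected.
    """
    stages = [("--train", train), ("--sample", sample), ("--eval", evaluate)]
    flags = [flag for flag, enabled in stages if enabled]
    if not flags:
        raise ValueError("select at least one of train, sample, evaluate")
    return [sys.executable, "scripts/pipeline.py", "--config", config, *flags]
```

`pipeline_cmd("exp/churn2/ddpm_cb_best/config.toml", train=True, sample=True)` reproduces the example command above.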

It takes approximately 7 minutes to run the example above on an NVIDIA GeForce RTX 2080 Ti.

Run evaluation over seeds
Before running the evaluation, you have to train the model with the given hyperparameters (see the pipeline example above).

Template and example:

python scripts/eval_seeds.py --config [path_to_your_config] [n_eval_seeds] [ddpm|smote|ctabgan|ctabgan-plus|tvae] synthetic [catboost|mlp] [n_sample_seeds]
python scripts/eval_seeds.py --config exp/churn2/ddpm_cb_best/config.toml 10 ddpm synthetic catboost 5
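The positional arguments of eval_seeds.py are easy to mix up, since the two seed counts sit at different positions. This sketch uses a hypothetical wrapper, `eval_seeds_cmd`, that validates them against the choices listed in the template. It does not check that the model has already been trained, which is still required as noted above.

```python
import sys

METHODS = ("ddpm", "smote", "ctabgan", "ctabgan-plus", "tvae")
EVAL_MODELS = ("catboost", "mlp")


def eval_seeds_cmd(config: str, n_eval_seeds: int, method: str,
                   eval_model: str, n_sample_seeds: int) -> list[str]:
    """Build the argv for scripts/eval_seeds.py following the template above.

    Hypothetical helper; the allowed methods and evaluation models are the
    ones listed in the template.
    """
    if method not in METHODS:
        raise ValueError(f"method must be one of {METHODS}, got {method!r}")
    if eval_model not in EVAL_MODELS:
        raise ValueError(f"eval_model must be one of {EVAL_MODELS}, got {eval_model!r}")
    if n_eval_seeds < 1 or n_sample_seeds < 1:
        raise ValueError("seed counts must be at least 1")
    return [sys.executable, "scripts/eval_seeds.py", "--config", config,
            str(n_eval_seeds), method, "synthetic", eval_model, str(n_sample_seeds)]
```

`eval_seeds_cmd("exp/churn2/ddpm_cb_best/config.toml", 10, "ddpm", "catboost", 5)` reproduces the example command above.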
