TabLLM: Few-shot Classification of Tabular Data with Large Language Models

[Figure: Overview]

This repository contains the code to reproduce the results of the paper TabLLM: Few-shot Classification of Tabular Data with Large Language Models by Stefan Hegselmann, Alejandro Buendia, Hunter Lang, Monica Agrawal, Xiaoyi Jiang, and David Sontag.

Update 10/30/2023: Added Additional Instructions to the Readme

Since several issues were raised regarding the code, we decided to add some additional instructions to the readme. We now provide all steps to reproduce an entry of our final results table; reproducing the remaining results mainly consists of changing the experimental parameters. Thanks to everyone who provided feedback!

Overview

Reproducing the main results consists of three steps:

  1. Creating textual serializations of the nine public tabular datasets
  2. Training and evaluating TabLLM (using code from the t-few project) on the serialized datasets
  3. Running the baseline models on the tabular datasets

We did not include the code to serialize and evaluate the private healthcare dataset due to privacy concerns. Also, code for some additional experiments is not included. Feel free to contact us if you have any questions concerning these experiments.

Setting the Correct Paths

TabLLM and the t-few project use the path /root/<project> by default, and for this readme we assume that you cloned both repositories to this location, i.e., /root/TabLLM for TabLLM and /root/t-few for t-few. You will most likely have to adapt these paths to your own setup. The easiest way is to replace all occurrences of /root with your own path. If you get an error running the code, first ensure that all paths are set correctly.
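
For example, a recursive search and replace can rewrite the default prefix in one step (a sketch assuming GNU sed and that you cloned both repositories under /home/you, which stands in for your own path):

grep -rl '/root' /home/you/TabLLM /home/you/t-few | xargs sed -i 's|/root|/home/you|g'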

Preparing the Environments

We used conda to create the necessary virtual environments. For the TabLLM environment, we used Python 3.8:

conda create -n tabllm python==3.8
conda activate tabllm

Next, install the necessary requirements for TabLLM:

conda install numpy scipy pandas scikit-learn
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install datasets transformers sentencepiece protobuf xgboost lightgbm tabpfn
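
To verify that PyTorch was installed with GPU support, you can run a quick sanity check (it should print the pinned version and True on a machine with a compatible GPU):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"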

If you want to run training and inference for TabLLM, you also have to set up the environment for t-few. You can follow their readme to do so, but we had some dependency issues when following their instructions. Here are the commands that worked for us (taken and adapted from their instructions):

conda create -n tfew python==3.7
conda activate tfew
pip install fsspec==2021.05.0
pip install --use-deprecated=legacy-resolver -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install urllib3==1.26.6
pip install importlib-metadata==4.13.0
pip install scikit-learn

To ensure that the t-few project is set up correctly, you can run the command given in their repository:

export HF_HOME=~/.cache/huggingface
CUDA_VISIBLE_DEVICES=0 python -m src.pl_train -c t03b.json+rte.json -k save_model=False exp_name=first_exp

The result of the experiment should be stored in /root/t-few/exp_out/first_exp.
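
You can confirm that the run succeeded by listing the output directory:

ls /root/t-few/exp_out/first_exp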

1. Creating Serialized Datasets

To create a textual serialization for one of the tabular datasets, execute the following script with additional optional arguments for a specific serialization type. This creates a folder containing a Hugging Face dataset in datasets_serialized:

create_external_datasets.py --dataset (car|income|diabetes|heart|bank|blood|calhousing|creditg|jungle) (--list) (--list (--tabletotext|--t0serialization|--values|--permuted|--shuffled))
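
For example, to create a serialization of the heart dataset (assuming that omitting the optional flags yields the default Text serialization):

create_external_datasets.py --dataset heart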

For the Text GPT serialization, we used a script that queried the GPT-3 API with each row entry encoded as a list, using the prompts given in the paper.
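
For illustration, a list encoding of a single row could look as follows (a minimal sketch in Python; the column names and exact formatting are made up for this example and do not reproduce our script):

# Illustrative only: encode one row of a tabular dataset as a list of features.
row = {"age": 54, "chest_pain_type": "asymptomatic", "max_heart_rate": 150}
serialization = "\n".join(f"- {column}: {value}" for column, value in row.items())
print(serialization)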

We provide the Text serializations in datasets_serialized. The other serializations are omitted here due to size constraints. The Text serialization achieved the best results in our experiments.
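
Each serialized dataset is stored as a Hugging Face dataset on disk, so you can inspect one directly (a quick check assuming the default paths from above; the exact column names depend on the serialization):

from datasets import load_from_disk

# Load one of the provided Text serializations and look at its structure.
dataset = load_from_disk("/root/TabLLM/datasets_serialized/heart")
print(dataset)  # features and number of rows
# Depending on how the dataset was saved, access an example directly
# (dataset[0]) or through a split first (dataset["train"][0]).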

2. Train and Evaluate TabLLM on Serialized Datasets

We used the codebase of the t-few project for our experiments and made some small modifications to their code to enable experiments with our custom datasets and templates. All changed files are included in the t-few folder of this repository, and you have to copy them over:

cp /root/TabLLM/t-few/bin/few-shot-pretrained-100k.sh /root/t-few/bin/
cp /root/TabLLM/t-few/configs/* /root/t-few/configs/
cp /root/TabLLM/t-few/src/models/EncoderDecoder.py /root/t-few/src/models/
cp /root/TabLLM/t-few/src/data/* /root/t-few/src/data/
cp /root/TabLLM/t-few/src/scripts/get_result_table.py /root/t-few/src/scripts/

Please check that you also set the paths correctly for the t-few project. In particular, check /root/t-few/src/data/dataset_readers.py to ensure that DATASETS_OFFLINE in line 75 points to /root/TabLLM/datasets_serialized and that yaml_dict = yaml.load(open(...)) in line 233 uses the path /root/TabLLM/templates/templates_.
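
A quick way to locate both places:

grep -n 'DATASETS_OFFLINE' /root/t-few/src/data/dataset_readers.py
grep -n 'yaml.load' /root/t-few/src/data/dataset_readers.py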

The script /root/t-few/bin/few-shot-pretrained-100k.sh runs all our TabLLM experiments for the different serializations and stores them in /root/t-few/exp_out. To run the 4-shot heart experiment with the Text serialization using the T0-3B model, set the for-loops going over the different experimental settings in /root/t-few/bin/few-shot-pretrained-100k.sh to:

for model in 't03b'
do
  [...]
  for num_shot in 4
  do
    [...]
    for dataset in heart 
    do
      [...]
      for seed in 42 1024 0 1 32  # Keep this for-loop as it is
      do
        [...]
      done
    done
  done
done

Then, you can run the specified setup from the t-few folder /root/t-few via:

./bin/few-shot-pretrained-100k.sh

The results of the experiment should be stored in /root/t-few/exp_out/t03b_heart_numshot4_seed*. Note that we do not use a validation set; hence, in the code the test data is treated as the validation (=pred) set. Consequently, you can find the test performance for seed 42 in /root/t-few/exp_out/t03b_heart_numshot4_seed42_ia3_pretrained100k/dev_scores.json:

cat /root/t-few/exp_out/t03b_heart_numshot4_seed42_ia3_pretrained100k/dev_scores.json
{"AUC": 0.617825311942959, "PR": 0.6409831261754565, "micro_f1": 0.5869565217391305, "macro_f1": 0.5511042629686697, "accuracy": 0.5869565217391305, "num": 184, "num_steps": -1, "score_gt": 0.8486858865489131, "score_cand": 0.9136485224184783}

To collect the results of several runs, we slightly changed the /root/t-few/src/scripts/get_result_table.py script to report the mean AUC and standard deviation. For the above example, using the script looks as follows:

python /root/t-few/src/scripts/get_result_table.py -e t03b* -d heart
================================================================================
Find 5 experiments fit into t03b*
heart: 67.65 (12.87)
Save result to exp_out/summary.csv
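
If you want to aggregate results without the script, the same summary can be computed directly from the dev_scores.json files (a sketch assuming the directory layout above; each file may contain one JSON object per evaluation, so we take the last line):

import glob
import json
import statistics

# Collect the final test AUC of every seed for the 4-shot heart experiment.
scores = []
for path in glob.glob("/root/t-few/exp_out/t03b_heart_numshot4_seed*/dev_scores.json"):
    with open(path) as f:
        last_entry = [line for line in f if line.strip()][-1]
    scores.append(json.loads(last_entry)["AUC"])

print(f"heart: {statistics.mean(scores) * 100:.2f} ({statistics.stdev(scores) * 100:.2f})")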

This result corresponds to the entry "TabLLM (T0 3B + Text Template)" for the heart dataset with 4 training examples (shots) on page 21 of our paper. To obtain the other results, adapt /root/t-few/bin/few-shot-pretrained-100k.sh accordingly. For more information, please consult the original t-few repository or raise an issue.

3. Running the Baseline Models

We tested TabLLM against several baselines, which use the standard non-serialized datasets. The hyperparameter ranges are given in the paper. You can specify the baseline models and datasets you want to run in the code. To run a baseline model, execute:

evaluate_external_datasets.py
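
As a rough illustration of what such a baseline run involves, here is a sketch using scikit-learn and the blood dataset fetched from OpenML (this is not the script above, and the hyperparameter search from the paper is omitted):

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Fetch the blood transfusion dataset and frame it as binary classification.
X, y = fetch_openml("blood-transfusion-service-center", version=1, return_X_y=True, as_frame=False)
y = (y == "2").astype(int)  # class "2" (donated blood) as the positive class

# Train a logistic regression baseline and report the test AUC.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
scaler = StandardScaler().fit(X_train)
model = LogisticRegression(max_iter=1000).fit(scaler.transform(X_train), y_train)
probabilities = model.predict_proba(scaler.transform(X_test))[:, 1]
print("AUC:", roc_auc_score(y_test, probabilities))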

We hope these instructions help you to reproduce our results. Feel free to contact us if you have any questions!

Citation

If you want to cite our work, please use:

@inproceedings{hegselmann2023tabllm,
  title={{TabLLM}: Few-shot Classification of Tabular Data with Large Language Models},
  author={Hegselmann, Stefan and Buendia, Alejandro and Lang, Hunter and Agrawal, Monica and Jiang, Xiaoyi and Sontag, David},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={5549--5581},
  year={2023},
  organization={PMLR}
}

We use the code of:

@article{liu2022few,
  title={Few-shot parameter-efficient fine-tuning is better and cheaper than in-context learning},
  author={Liu, Haokun and Tam, Derek and Muqeeth, Mohammed and Mohta, Jay and Huang, Tenghao and Bansal, Mohit and Raffel, Colin A},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  pages={1950--1965},
  year={2022}
}
@inproceedings{bach2022promptsource,
  title={PromptSource: An Integrated Development Environment and Repository for Natural Language Prompts},
  author={Bach, Stephen and Sanh, Victor and Yong, Zheng Xin and Webson, Albert and Raffel, Colin and Nayak, Nihal V and Sharma, Abheesht and Kim, Taewoon and Bari, M Saiful and F{\'e}vry, Thibault and others},
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics: System Demonstrations},
  pages={93--104},
  year={2022}
}
