Language: Python
License: Apache License 2.0

SubTab:

Author: Talip Ucar

The official implementation of the paper,

SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning

🔶 Note: An extended version of SubTab, with code and pre-processed data for the Adult Income and BlogFeedback datasets, is available at: https://github.com/talipucar/SubTab_extended

Table of Contents:

  1. Model
  2. Environment
  3. Data
  4. Configuration
  5. Training and Evaluation
  6. Adding New Datasets
  7. Results
  8. Experiment tracking
  9. Citing the paper
  10. Citing this repo
NeurIPS 2021 slides NeurIPS 2021 poster

Model

Figure: Animated overview of the SubTab model (a slower version of the animation is available in the original repository).
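As a rough illustration of the feature subsetting named in the paper's title, the sketch below splits the columns of a tabular batch into several overlapping feature subsets. It is a simplified example, not the repository's implementation: the helper `make_feature_subsets` and its `n_subsets` and `overlap` parameters are hypothetical, and the way SubTab actually chooses its subsets may differ.

```python
import numpy as np


def make_feature_subsets(x, n_subsets=4, overlap=0.75):
    """Split the columns of x (n_samples, n_features) into n_subsets blocks.

    Hypothetical helper: each contiguous block of width n_features // n_subsets
    is widened by int(overlap * width) columns on each side (clipped to the
    valid range), so neighbouring subsets share some features.
    """
    n_features = x.shape[1]
    width = n_features // n_subsets
    extra = int(overlap * width)
    subsets = []
    for i in range(n_subsets):
        start = max(0, i * width - extra)
        # The last block also absorbs any leftover columns.
        if i == n_subsets - 1:
            stop = n_features
        else:
            stop = min(n_features, (i + 1) * width + extra)
        subsets.append(x[:, start:stop])
    return subsets


# Example: 784 MNIST-like pixel features split into 4 overlapping subsets.
x = np.random.rand(32, 784).astype(np.float32)
for s in make_feature_subsets(x, n_subsets=4, overlap=0.75):
    print(s.shape)  # prints (32, 343), (32, 490), (32, 490), (32, 343)
```

Widening each block on both sides keeps every feature covered while letting neighbouring subsets share information; larger `overlap` values make the subsets more similar to each other.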

Environment

We used Python 3.7 for our experiments. The environment can be set up in three steps:

pip install pipenv             # Install pipenv if you don't already have it
pipenv install --skip-lock     # Install the required packages
pipenv shell                   # Activate the virtual environment

If the second step fails, you can install the packages listed in the Pipfile individually with pip, e.g. "pip install package_name".
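If you take the individual-install route, the sketch below reads the [packages] table of the Pipfile and prints (or runs) one pip install command per package. It deliberately uses a simplified line-based parser instead of a full TOML parser, because the standard-library tomllib only exists from Python 3.11 onwards. The function names are hypothetical, and the commands are only printed unless `dry_run=False` is passed.

```python
import subprocess
import sys


def pipfile_requirements(text):
    """Return pip requirement strings from the [packages] table of a Pipfile.

    Simplified parser: handles  name = "*"  and  name = "==1.2.3" ; table
    entries such as  name = {version = "*"}  are reduced to the bare name.
    """
    reqs, in_packages = [], False
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if not line:
            continue
        if line.startswith("["):  # a new table header starts here
            in_packages = line == "[packages]"
            continue
        if in_packages and "=" in line:
            name, spec = (part.strip().strip("\"'") for part in line.split("=", 1))
            if spec == "*" or spec.startswith("{"):
                reqs.append(name)
            else:
                reqs.append(name + spec)
    return reqs


def install_individually(pipfile="Pipfile", dry_run=True):
    """Install each Pipfile package with its own pip call, so one failure doesn't stop the rest."""
    with open(pipfile) as fh:
        requirements = pipfile_requirements(fh.read())
    for req in requirements:
        cmd = [sys.executable, "-m", "pip", "install", req]
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=False)  # continue even if this package fails
```

Running `install_individually(dry_run=False)` from the repository root (inside the activated environment) would then install the packages one by one.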

Data

The MNIST dataset is provided to demo the framework. To use your own dataset, follow the instructions in Adding New Datasets.

Configuration

There are two types of configuration files:

  1. runtime.yaml is a high-level configuration file used by all datasets to:

    • define the random seed
    • turn mlflow on/off (Default: False)
    • turn the Python profiler on/off (Default: False)
    • set the data directory
    • set the results directory
  2. The second configuration file is dataset-specific (e.g. mnist.yaml) and is used to configure the model architecture, loss functions, and so on.

    • For example, we set up a configuration file for the MNIST dataset with the same name, mnist.yaml. Please note that the configuration file should have the same name as the dataset, in all lowercase letters.
    • You can add configuration files for other datasets, such as tcga.yaml and income.yaml for the tcga and income datasets respectively.
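The naming rule for dataset-specific configuration files can be expressed as a small lookup. In this sketch, `dataset_config_path` and `merge_configs` are hypothetical helpers, and the rule that dataset-specific values override runtime values on a key clash is an assumption, not something taken from the repository. Parsing the YAML files themselves (for example with PyYAML's `yaml.safe_load`) is left out so the example needs only the standard library.

```python
from pathlib import Path


def dataset_config_path(dataset_name, config_dir="config"):
    """Return config/<dataset name in lowercase>.yaml, following the naming rule above."""
    return Path(config_dir) / f"{dataset_name.lower()}.yaml"


def merge_configs(runtime_cfg, dataset_cfg):
    """Combine two config dicts; on a key clash the dataset-specific value wins (assumption)."""
    merged = dict(runtime_cfg)
    merged.update(dataset_cfg)
    return merged


print(dataset_config_path("MNIST"))  # config/mnist.yaml (path separator shown for POSIX)
print(dataset_config_path("TCGA"))   # config/tcga.yaml
```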

Training and Evaluation

You can train and evaluate the model by using:

python train.py   # For training
python eval.py    # For evaluation
  • train.py also runs evaluation at the end of training.
  • You can also run evaluation separately with eval.py.
  • For the list of arguments, see ./utils/arguments.py.
    • Use the -h argument to get help when running the scripts.
    • Use -d dataset_name to run the scripts on a new dataset.
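For orientation, the -d option described above behaves like a standard argparse flag. The sketch below is a hypothetical stand-in for ./utils/arguments.py, not a copy of it: the long option name `--dataset` and the default value `mnist` are assumptions.

```python
import argparse


def get_arguments(argv=None):
    """Hypothetical argument parser mirroring the -d option described above."""
    parser = argparse.ArgumentParser(description="Train or evaluate SubTab (sketch).")
    parser.add_argument(
        "-d", "--dataset",
        type=str,
        default="mnist",  # assumed default; MNIST is the bundled demo dataset
        help="Dataset name; also selects config/<dataset>.yaml",
    )
    return parser.parse_args(argv)


args = get_arguments(["-d", "tcga"])
print(args.dataset)  # tcga
```

With a parser like this, `-h` is added automatically by argparse and prints the help text for every option.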

Adding New Datasets

For each new dataset, you can use the following steps:

  1. Provide a _load_dataset_name() function, similar to the MNIST load function.

    • For example, you can add _load_tcga() for the tcga dataset, or _load_income() for the income dataset.
    • The function should return (x_train, y_train, x_test, y_test).
  2. Add a separate elif condition for the new dataset to the _load_data() method of the TabularDataset() class in utils/load_data.py.

  3. Create a new config file with the same name as the dataset.

    • For example, tcga.yaml for the tcga dataset, or income.yaml for the income dataset.

    • You can also duplicate one of the existing configuration files (e.g. mnist.yaml) and rename it.

    • Make sure that the new config file is in the config/ directory.

  4. Provide a data folder with the pre-processed training and test sets, and place it under the ./data/ directory. Alternatively, you can do the train-test split and pre-processing within your custom _load_dataset_name() function.

  5. (Optional) If you want to place the new dataset in a directory other than the local "./data/", then:

    • Place the dataset folder anywhere, and define the root directory containing it in ./config/runtime.yaml.

    • For example, if the path to the tcga dataset is /home/.../data/tcga/, you only need to include /home/.../data/ in runtime.yaml. The code fills in the tcga folder name from the dataset name given on the command line (e.g. -d dataset_name; in this case, dataset_name would be tcga).
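Steps 1, 2, 4 and 5 can be sketched together as follows. Everything here is hypothetical: the income dataset stored as a single data.npz file with arrays `x` and `y`, the 80/20 random split, the seed, and the min-max scaling are assumptions chosen for illustration, and `load_data` is only a stand-in for the elif chain in TabularDataset._load_data(), not the repository's code. What the sketch does follow from the steps above is the return signature (x_train, y_train, x_test, y_test) and building the dataset folder from a data root plus the dataset name.

```python
import os

import numpy as np


def _load_income(data_root, test_fraction=0.2, seed=57):
    """Hypothetical loader: reads <data_root>/income/data.npz and returns
    (x_train, y_train, x_test, y_test)."""
    path = os.path.join(data_root, "income", "data.npz")  # assumed file layout
    with np.load(path) as arrays:
        x = arrays["x"].astype(np.float32)
        y = arrays["y"]

    # Random train-test split done inside the loader (step 4 allows this).
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    n_test = int(len(x) * test_fraction)
    test_idx, train_idx = idx[:n_test], idx[n_test:]

    # Min-max scaling fitted on the training split only, to avoid test leakage.
    x_min = x[train_idx].min(axis=0)
    x_max = x[train_idx].max(axis=0)
    scale = np.where(x_max > x_min, x_max - x_min, 1.0)
    x = (x - x_min) / scale

    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def load_data(dataset_name, data_root="./data/"):
    """Stand-in for the per-dataset elif chain in TabularDataset._load_data()."""
    name = dataset_name.lower()
    if name == "income":
        return _load_income(data_root)
    # elif name == "tcga":
    #     return _load_tcga(data_root)
    raise ValueError(f"No loader registered for dataset {dataset_name!r}")
```

Fitting the scaler on the training split only keeps test-set statistics out of pre-processing; if your data folder already contains pre-processed train and test files, the loader can simply read and return them.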

Structure of the repo

- train.py
- eval.py

- src
    |-model.py
    
- config
    |-runtime.yaml
    |-mnist.yaml
    
- utils
    |-load_data.py
    |-arguments.py
    |-model_utils.py
    |-loss_functions.py
    ...
    
- data
    |-mnist
    ...
    
- results
    |
    ...

Results

Results are saved under the ./results directory at the end of training. The results directory has the following structure:

- results
    |-dataset name
            |-evaluation
                |-clusters (for plotting t-SNE and PCA plots of embeddings)
                |-reconstructions (not used)
            |-training
                |-model_mode (e.g. ae for autoencoder)   
                     |-model
                     |-plots
                     |-loss

You can save the results of evaluations under the "evaluation" folder.
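If you want to create the same layout from your own scripts, a pathlib sketch is shown below. The helper name `make_results_dirs` is hypothetical, and the repository creates these folders in its own way.

```python
from pathlib import Path


def make_results_dirs(results_root, dataset_name, model_mode="ae"):
    """Create the results tree shown above and return the paths by name (hypothetical helper)."""
    base = Path(results_root) / dataset_name
    training = base / "training" / model_mode
    paths = {
        "clusters": base / "evaluation" / "clusters",
        "reconstructions": base / "evaluation" / "reconstructions",
        "model": training / "model",
        "plots": training / "plots",
        "loss": training / "loss",
    }
    for path in paths.values():
        path.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return paths
```

For example, `make_results_dirs("./results", "mnist")` would create results/mnist/training/ae/model, results/mnist/evaluation/clusters, and the other folders listed above.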

Experiment tracking

MLflow is used to track experiments. It is turned off by default, but can be turned on by changing the corresponding option in the runtime config file, ./config/runtime.yaml.
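A common way to honour such an on/off switch is to import and call MLflow only when the flag is set. In the sketch below, the config key name `mlflow` and the helper names are assumptions; `mlflow.start_run()` and `mlflow.log_metric()` are standard MLflow tracking calls, and the mlflow package is only imported when tracking is enabled.

```python
import contextlib


def tracking_run(config):
    """Return an MLflow run context if tracking is on, else a no-op context."""
    if config.get("mlflow", False):  # assumed key name in runtime.yaml
        import mlflow  # imported lazily so mlflow is only needed when enabled
        return mlflow.start_run()
    return contextlib.nullcontext()


def log_metric(config, name, value, step):
    """Log a metric to MLflow only when tracking is enabled."""
    if config.get("mlflow", False):
        import mlflow
        mlflow.log_metric(name, value, step=step)


config = {"mlflow": False}  # default: tracking off
with tracking_run(config):
    for epoch, loss in enumerate([0.9, 0.7, 0.5]):  # placeholder loss values
        log_metric(config, "train_loss", loss, step=epoch)
```

Keeping the import inside the enabled branch means runs with tracking turned off do not require mlflow to be installed at all.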

Citing the paper

@article{ucar2021subtab,
  title={SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning},
  author={Ucar, Talip and Hajiramezanali, Ehsan and Edwards, Lindsay},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Citing this repo

If you use the SubTab framework in your own studies or work, please cite it as follows:

@Misc{talip_ucar_2021_SubTab,
  author =   {Talip Ucar},
  title =    {{SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning}},
  howpublished = {\url{https://github.com/AstraZeneca/SubTab}},
  month        = jun,
  year = {since 2021}
}
