Official PyTorch code for the Structure-Aware Transformer.

Structure-Aware Transformer

Updates: We have added the script for model visualization (Figure 4 in our paper)!

This repository implements the Structure-Aware Transformer (SAT) in PyTorch Geometric, as described in the following paper:

Dexiong Chen*, Leslie O'Bray*, and Karsten Borgwardt. Structure-Aware Transformer for Graph Representation Learning. ICML 2022.
*Equal contribution

TL;DR: A class of simple and flexible graph transformers built upon a new self-attention mechanism, which incorporates structural information into the original self-attention by extracting a subgraph representation rooted at each node before computing the attention. Our structure-aware framework can leverage any existing GNN to extract the subgraph representation, and it systematically improves performance relative to the base GNN.

Citation

Please use the following to cite our work:

@InProceedings{Chen22a,
	author = {Dexiong Chen and Leslie O'Bray and Karsten Borgwardt},
	title = {Structure-Aware Transformer for Graph Representation Learning},
	year = {2022},
	booktitle = {Proceedings of the 39th International Conference on Machine Learning~(ICML)},
	series = {Proceedings of Machine Learning Research}
}

A short description of SAT

SAT vs the vanilla Transformer

[Figure: the SAT architecture vs. the vanilla Transformer architecture]

The SAT architecture compared with the vanilla Transformer architecture is shown above. We make the self-attention calculation in each Transformer layer structure-aware by leveraging structure-aware node embeddings. We generate these embeddings using a structure extractor (for example, any GNN) on the $k$-hop subgraphs centered at each node of interest. Then, the updated node embeddings are used to compute the query ($\mathbf{Q}$) and key ($\mathbf{K}$) matrices. We provide example structure extractors in the next figure.
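In symbols, and as a schematic sketch rather than the paper's full multi-head formulation, the structure extractor $\phi$ replaces the raw node features $\mathbf{X}$ wherever attention scores are computed:

$$\mathbf{Q} = \phi(\mathbf{X})\,\mathbf{W}_Q, \qquad \mathbf{K} = \phi(\mathbf{X})\,\mathbf{W}_K, \qquad \mathrm{Attn}(\mathbf{X}) = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{X}\mathbf{W}_V,$$

where row $v$ of $\phi(\mathbf{X})$ is the subgraph representation rooted at node $v$ (e.g., the output of a $k$-layer GNN), while the values are still computed from the original node features.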

Example structure extractors

[Figure: example structure extractors ($k$-subtree and $k$-subgraph)]

The figure above shows the two example structure extractors used in our paper ($k$-subtree and $k$-subgraph). Structure-aware node representations are generated in the $k$-subtree GNN extractor by using the $k$-hop subtree centered at each node (here, $k=1$) and using a GNN to generate updated node representations. The explicit extraction of the subtree as an initial step is not strictly necessary, as a GNN by nature will use the $k$-hop subtree and generate updated node embeddings using the subtree information. For the $k$-subgraph GNN extractor, we first extract the $k$-hop subgraph centered at each node, and then use a GNN on each subgraph to generate node representations using the full subgraph information. The updated node embeddings are then used to compute the query ($\mathbf{Q}$) and key ($\mathbf{K}$) matrices shown in the first figure.
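As a concrete illustration of the extraction step (a minimal sketch, not the repository's extractor code), PyTorch Geometric's k_hop_subgraph utility returns exactly the rooted subgraphs that the $k$-subgraph extractor operates on:

import torch
from torch_geometric.utils import k_hop_subgraph

# A toy path graph 0-1-2-3-4, stored as a bidirectional edge list
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                           [1, 0, 2, 1, 3, 2, 4, 3]])

# Extract the 2-hop subgraph rooted at node 0, relabeling its nodes to 0..n-1
subset, sub_edge_index, root_pos, edge_mask = k_hop_subgraph(
    node_idx=0, num_hops=2, edge_index=edge_index, relabel_nodes=True)

print(subset)          # nodes within 2 hops of node 0: tensor([0, 1, 2])
print(sub_edge_index)  # induced edges, relabeled to the subset indices
print(root_pos)        # position of the root node within `subset`

Running a GNN on each such subgraph gives the $k$-subgraph embeddings, whereas simply running a $k$-layer GNN on the full graph realizes the $k$-subtree extractor.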

A quick-start example

Below you can find a quick-start example on the ZINC dataset; see ./experiments/train_zinc.py for more details.

import torch
from torch_geometric import datasets
from torch_geometric.loader import DataLoader
from sat.data import GraphDataset
from sat import GraphTransformer

# Load the ZINC dataset using our wrapper GraphDataset,
# which automatically creates the fully connected graph.
# For datasets with large graphs, we recommend setting return_complete_index=False
# for faster computation
dset = datasets.ZINC('./datasets/ZINC', subset=True, split='train')
dset = GraphDataset(dset)

# Create a PyG data loader
train_loader = DataLoader(dset, batch_size=16, shuffle=True)

# Create a SAT model
dim_hidden = 16
gnn_type = 'gcn' # use GCN as the structure extractor
k_hop = 2 # use a 2-layer GCN

model = GraphTransformer(
    in_size=28, # number of node labels for ZINC
    num_class=1, # regression task
    d_model=dim_hidden,
    dim_feedforward=2 * dim_hidden,
    num_layers=2,
    batch_norm=True,
    gnn_type=gnn_type, # defined above
    use_edge_attr=True,
    num_edge_features=4, # number of edge labels
    edge_dim=dim_hidden,
    k_hop=k_hop,
    se='gnn', # we use the k-subtree structure extractor
    global_pool='add'
)

for data in train_loader:
    output = model(data) # batch_size x 1
    break
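Building on the quick-start code, a minimal training-loop sketch might look as follows. The optimizer, learning rate, and the L1 loss (ZINC is usually evaluated by mean absolute error) are illustrative assumptions here; see ./experiments/train_zinc.py for the actual training setup.

import torch.nn.functional as F

# Illustrative choices, not the repository's defaults
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for data in train_loader:
    optimizer.zero_grad()
    output = model(data)                          # batch_size x 1
    loss = F.l1_loss(output.squeeze(-1), data.y)  # MAE, standard for ZINC
    loss.backward()
    optimizer.step()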

Installation

The dependencies are managed with Miniconda:

python=3.9
numpy
scipy
pytorch=1.9.1
pytorch-geometric=2.0.2
einops
ogb

Once you have activated the environment and installed all dependencies, run:

source s

Datasets will be downloaded via PyTorch Geometric and the OGB package.

Train SAT on graph and node prediction datasets

All our experiment scripts are in the experiments folder. To start, after having run source s, run cd experiments. The hyperparameters used below are the ones selected as optimal for each task.

Graph regression on ZINC dataset

Train a k-subtree SAT with PNA:

python train_zinc.py --abs-pe rw --se gnn --gnn-type pna2 --dropout 0.3 --k-hop 3 --use-edge-attr

Train a k-subgraph SAT with PNA:

python train_zinc.py --abs-pe rw --se khopgnn --gnn-type pna2 --dropout 0.2 --k-hop 3 --use-edge-attr

Node classification on PATTERN and CLUSTER datasets

Train a k-subtree SAT on PATTERN:

python train_SBMs.py --dataset PATTERN --weight-class --abs-pe rw --abs-pe-dim 7 --se gnn --gnn-type pna3 --dropout 0.2 --k-hop 3 --num-layers 6 --lr 0.0003

and on CLUSTER:

python train_SBMs.py --dataset CLUSTER --weight-class --abs-pe rw --abs-pe-dim 3 --se gnn --gnn-type pna2 --dropout 0.4 --k-hop 3 --num-layers 16 --dim-hidden 48 --lr 0.0005

Graph classification on OGB datasets

--gnn-type can be gcn, gine, or pna; pna obtains the best performance.

# Train SAT on OGBG-PPA
python train_ppa.py --gnn-type gcn --use-edge-attr
# Train SAT on OGBG-CODE2
python train_code2.py --gnn-type gcn --use-edge-attr

Model visualization

We showcase here how to visualize the attention weights of the [CLS] node learned by SAT and by the vanilla Transformer with the random walk positional encoding. We provide pre-trained models on the Mutagenicity dataset. To visualize them, you need to install the networkx and matplotlib packages, then run:

python model_visu.py --graph-idx 2003

This will generate the following image, the same as Figure 4 in our paper:

[Figure: model interpretation, attention weights of the [CLS] node (Figure 4 in our paper)]
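If you want to adapt the visualization to your own models, the drawing step itself is straightforward with networkx: color each node by the attention weight it receives from the [CLS] node. The sketch below is a hypothetical illustration; graph and cls_attn are assumed inputs, and extracting them from a trained checkpoint is what model_visu.py handles.

import networkx as nx
import matplotlib.pyplot as plt

def draw_cls_attention(graph: nx.Graph, cls_attn) -> None:
    # cls_attn: one attention weight per node, in node order (assumed input)
    pos = nx.kamada_kawai_layout(graph)  # layout suited to small molecule graphs
    nx.draw(graph, pos, node_color=cls_attn, cmap=plt.cm.viridis,
            node_size=120, with_labels=False)
    plt.show()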
