• Stars
    star
    157
  • Rank 237,051 (Top 5 %)
  • Language
    Python
  • License
    MIT License
  • Created over 3 years ago
  • Updated about 2 years ago

Repository Details

Implementation of "GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings" in PyTorch

PyGAS: Auto-Scaling GNNs in PyG


PyGAS is the practical realization of our GNNAutoScale (GAS) framework, which scales arbitrary message-passing GNNs to large graphs, as described in our paper:

Matthias Fey, Jan E. Lenssen, Frank Weichert, Jure Leskovec: GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings (ICML 2021)

GAS prunes entire sub-trees of the computation graph by utilizing historical embeddings from prior training iterations, leading to constant GPU memory consumption with respect to input mini-batch size, and maximal expressivity.
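The idea of historical embeddings can be illustrated with a minimal, pure-Python sketch. It is not the PyGAS implementation: `ToyHistory` and this standalone `push_and_pull` are hypothetical names chosen for illustration, and plain lists stand in for tensors. Embeddings computed for in-batch nodes are stored, and embeddings of neighbours outside the mini-batch are read back from storage instead of being recomputed.

```python
class ToyHistory:
    """Toy stand-in for a historical embedding store (not the PyGAS class)."""

    def __init__(self, num_nodes, dim):
        # One embedding row per node, initialised to zeros.
        self.emb = [[0.0] * dim for _ in range(num_nodes)]

    def push(self, node_ids, values):
        # Overwrite the stored rows with freshly computed embeddings.
        for n, v in zip(node_ids, values):
            self.emb[n] = list(v)

    def pull(self, node_ids):
        # Read (possibly stale) embeddings saved in earlier iterations.
        return [list(self.emb[n]) for n in node_ids]


def push_and_pull(history, x_batch, batch_ids, halo_ids):
    """Store fresh in-batch rows, then append stored rows for outside neighbours."""
    history.push(batch_ids, x_batch)
    return x_batch + history.pull(halo_ids)


history = ToyHistory(num_nodes=4, dim=2)

# Iteration 1: nodes 0 and 1 are in the mini-batch; node 2 is an outside neighbour.
out1 = push_and_pull(history, [[1.0, 1.0], [2.0, 2.0]], [0, 1], [2])

# Iteration 2: nodes 2 and 3 are in the mini-batch; node 1 is the outside neighbour.
out2 = push_and_pull(history, [[3.0, 3.0], [4.0, 4.0]], [2, 3], [1])
```

In the second iteration the row for node 1 comes from storage rather than from a recursive expansion of its neighbourhood, which is why the per-layer work stays bounded by the mini-batch and its direct neighbours.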

PyGAS is implemented in PyTorch and utilizes the PyTorch Geometric (PyG) library. It provides an easy-to-use interface to convert a common or custom GNN from PyG into its scalable variant:

from torch.nn import ModuleList
from torch_geometric.nn import SAGEConv
from torch_geometric_autoscale import ScalableGNN
from torch_geometric_autoscale import metis, permute, SubgraphLoader

class GNN(ScalableGNN):
    def __init__(self, num_nodes, in_channels, hidden_channels,
                 out_channels, num_layers):
        # * pool_size determines the number of pinned CPU buffers
        # * buffer_size determines the size of pinned CPU buffers,
        #   i.e. the maximum number of out-of-mini-batch nodes

        super().__init__(num_nodes, hidden_channels, num_layers,
                         pool_size=2, buffer_size=5000)

        self.convs = ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, adj_t, *args):
        for conv, history in zip(self.convs[:-1], self.histories):
            x = conv(x, adj_t).relu_()
            x = self.push_and_pull(history, x, *args)
        return self.convs[-1](x, adj_t)

perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GNN(...)
for batch, *args in loader:
    out = model(batch.x, batch.adj_t, *args)

A detailed description of ScalableGNN can be found in its implementation.
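The partition-and-batch step in the snippet above can be pictured with a small pure-Python sketch. It rests on an assumption that is not stated in this README: that `ptr` holds partition boundaries, so that after permutation partition `i` covers the contiguous node indices `ptr[i]` to `ptr[i + 1] - 1`. The helper `partition_batches` is hypothetical and only mimics how a loader might group whole partitions into mini-batches.

```python
import random


def partition_batches(ptr, batch_size, shuffle=False, seed=None):
    """Group contiguous node partitions into mini-batches of node indices.

    Assumes partition i covers nodes ptr[i] .. ptr[i + 1] - 1.
    """
    part_ids = list(range(len(ptr) - 1))
    if shuffle:
        # Shuffle whole partitions, not individual nodes.
        random.Random(seed).shuffle(part_ids)

    batches = []
    for start in range(0, len(part_ids), batch_size):
        chosen = part_ids[start:start + batch_size]
        nodes = [n for p in chosen for n in range(ptr[p], ptr[p + 1])]
        batches.append(nodes)
    return batches


# Hypothetical boundaries: 4 partitions over 10 nodes.
ptr = [0, 3, 5, 8, 10]
batches = partition_batches(ptr, batch_size=2)

# With 40 partitions and 10 partitions per batch, an epoch has 4 mini-batches.
num_batches = len(partition_batches(list(range(41)), batch_size=10))
```

Under this assumption, `num_parts=40` with `batch_size=10` yields four mini-batches per epoch, each made up of ten partitions.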

Requirements

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric

where ${TORCH} should be replaced by either 1.7.0 or 1.8.0, and ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, cu110 or cu111, depending on your PyTorch installation.
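As a small illustration of the substitution, the following sketch builds the wheel index URL from the two values. The helper `wheel_index_url` is hypothetical and not part of any package; the accepted values are only those listed above.

```python
SUPPORTED_TORCH = ("1.7.0", "1.8.0")
SUPPORTED_CUDA = ("cpu", "cu92", "cu101", "cu102", "cu110", "cu111")


def wheel_index_url(torch_version, cuda):
    """Return the wheel index URL for a given PyTorch version and CUDA tag."""
    if torch_version not in SUPPORTED_TORCH:
        raise ValueError(f"unsupported PyTorch version: {torch_version}")
    if cuda not in SUPPORTED_CUDA:
        raise ValueError(f"unsupported CUDA tag: {cuda}")
    return f"https://pytorch-geometric.com/whl/torch-{torch_version}+{cuda}.html"


url = wheel_index_url("1.8.0", "cu111")
```

The resulting string is what would be passed to the `-f` option of the `pip install` commands above.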

Installation

pip install git+https://github.com/rusty1s/pyg_autoscale.git

or, from a local clone of the repository:

python setup.py install

Project Structure

  • torch_geometric_autoscale/ contains the source code of PyGAS
  • examples/ contains examples to demonstrate how to apply GAS in practice
  • small_benchmark/ includes experiments to evaluate GAS performance on small-scale graphs
  • large_benchmark/ includes experiments to evaluate GAS performance on large-scale graphs

We use Hydra to manage hyperparameter configurations.

Cite

Please cite our paper if you use this code in your own work:

@inproceedings{Fey/etal/2021,
  title={{GNNAutoScale}: Scalable and Expressive Graph Neural Networks via Historical Embeddings},
  author={Fey, M. and Lenssen, J. E. and Weichert, F. and Leskovec, J.},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2021},
}

More Repositories

1

pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
Python
1,531
star
2

pytorch_sparse

PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations
Python
990
star
3

pytorch_cluster

PyTorch Extension Library of Optimized Graph Cluster Algorithms
C++
798
star
4

deep-graph-matching-consensus

Implementation of "Deep Graph Matching Consensus" in PyTorch
Python
256
star
5

pytorch_spline_conv

Implementation of the Spline-Based Convolution Operator of SplineCNN in PyTorch
C++
169
star
6

table2excel

Convert and download html tables to a xlsx-file that can be opened in Microsoft Excel
JavaScript
112
star
7

deep-learning-cheatsheet

TeX
92
star
8

embedded_gcnn

Embedded Graph Convolutional Neural Networks (EGCNN) in TensorFlow
Jupyter Notebook
78
star
9

himp-gnn

Hierarchical Inter-Message Passing for Learning on Molecular Graphs
Python
75
star
10

koa2-rest-api

ES6 RESTFul Koa2 API with Mongoose and OAuth2
JavaScript
75
star
11

graph-based-image-classification

Implementation of Planar Graph Convolutional Networks in TensorFlow
Python
43
star
12

pytorch_unique

PyTorch Extension Library of Optimized Unique Operation
Python
37
star
13

deep-learning-on-graphs

TeX
31
star
14

mongoose-i18n-localize

Mongoose plugin to support i18n and localization
JavaScript
22
star
15

dotfiles

Shell
18
star
16

RSClipperWrapper

A small and simple wrapper for the Clipper library to perform polygon clipping (Swift)
C++
17
star
17

RSShapeNode

A RSShapeNode object draws a shape by rendering a Core Graphics path offscreen using a disconnected CAShapeLayer and snapshots the image to a SKSpriteNode (Swift)
Swift
8
star
18

rusty1s.github.io

HTML
6
star
19

pytorch_bincount

Python
6
star
20

vim-happy-hacking

Vim Script
5
star
21

rusty1s

4
star
22

table-select

Allows you to select table row elements like in your standard finder environment
JavaScript
3
star
23

DigDeeper

the Mining / Crafting / Trading game (Swift 2.0)
C++
3
star
24

react-pattern-library

React Pattern Library for various UI components
JavaScript
3
star
25

mongoose-i18n-error

lightweight module for node.js/express.js to create beautiful mongoose i18n validation error messages
JavaScript
2
star
26

react-dev-config

Customizable Configuration for modern React apps
JavaScript
2
star
27

mongoose-integer

mongoose plugin to validate integer values within a Mongoose Schema
JavaScript
2
star
28

hyper-happy-hacking

JavaScript
1
star
29

RSRoundBorderedButton

Round bordered Button like the ones used in the Apple AppStore (Swift)
Swift
1
star
30

ComputationOffloading

Energy efficiency through computation offloading in the cloud
1
star
31

react-documentviewer

React Documentviewer for various mimetypes
JavaScript
1
star
32

RSRandomPolygon

Swift
1
star
33

tensorflow-graph-plugin

Python
1
star
34

dependent-select-boxes

Allows a child select box to change its options dependent on its parent select box
JavaScript
1
star
35

texture-synthesis

TeX
1
star
36

RSScene

An inheritance of SKScene that adds a game logic loop to the runtime of a scene (Swift)
Swift
1
star
37

OCF-andCP-Networks

Qualitative semantics for DAGs: a comparison of OCF and CP networks
1
star
38

js-dev-utils

JavaScript
1
star
39

RSContactGrid

A triangular/square/rotated square/hexagonal grid tile map with contact detection for any path (Swift 2.0)
Swift
1
star