Language: Python
License: MIT License

Official implementation of A* Networks

A*Net: A* Networks

This is the official codebase of the paper

A*Net: A Scalable Path-based Reasoning Approach for Knowledge Graphs

Zhaocheng Zhu*, Xinyu Yuan*, Mikhail Galkin, Sophie Xhonneux, Ming Zhang, Maxime Gazeau, Jian Tang

Overview

A*Net is a scalable path-based method for knowledge graph reasoning. Inspired by the classical A* algorithm, A*Net learns a neural priority function to select important nodes and edges at each iteration, which significantly reduces the time and memory footprint of both training and inference.
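The per-iteration node and edge selection can be pictured with a small sketch. This is a simplified illustration, not the repository's code: it assumes the priority scores are already given as a plain dict (in A*Net they come from the learned neural priority function), keeps the top-k nodes, and keeps only edges that leave a selected node. The names `select_nodes`, `select_edges`, and `k` are hypothetical.

```python
import heapq
from typing import Dict, List, Set, Tuple

Edge = Tuple[int, int, int]  # (head, tail, relation)


def select_nodes(priority: Dict[int, float], k: int) -> Set[int]:
    """Keep the k nodes with the highest priority scores (all nodes if k is larger)."""
    return set(heapq.nlargest(k, priority, key=priority.__getitem__))


def select_edges(edges: List[Edge], selected: Set[int]) -> List[Edge]:
    """Keep only edges whose head is a selected node, so messages flow from those nodes only."""
    return [edge for edge in edges if edge[0] in selected]
```

Because messages are only passed along the retained edges, the work done in one iteration grows with the size of the selected subgraph rather than with the whole graph in this sketch.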

A*Net is the first path-based method that scales to ogbl-wikikg2 (2.5M entities, 16M triplets). It also enjoys the advantages of path-based methods such as inductive capacity and interpretability.

Here is a demo of A*Net with a ChatGPT interface. By reasoning on the Wikidata knowledge graph, ChatGPT produces more grounded predictions and fewer hallucinations.

[Video: A*Net with ChatGPT interface]

This codebase contains implementations of A*Net and its predecessor NBFNet.

Installation

The dependencies can be installed via either conda or pip. A*Net is compatible with 3.7 <= Python <= 3.10 and PyTorch >= 1.13.0.

From Conda

conda install pytorch cudatoolkit torchdrug pytorch-sparse -c pytorch -c pyg -c milagraph
conda install ogb easydict pyyaml openai -c conda-forge

From Pip

pip install torch torchdrug torch-sparse
pip install ogb easydict pyyaml openai
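Before running anything, you can check that the interpreter and PyTorch fall within the supported range stated above (3.7 <= Python <= 3.10, PyTorch >= 1.13.0). A minimal sketch; `parse_version`, `python_supported`, and `torch_supported` are hypothetical helpers, and a local build tag such as `+cu117` is ignored when comparing.

```python
import re
import sys
from typing import Sequence, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn '2.0.1+cu117' into (2, 0, 1); a missing patch number counts as 0."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def python_supported(version_info: Sequence[int] = sys.version_info) -> bool:
    """True for 3.7 <= Python <= 3.10, comparing (major, minor) only."""
    return (3, 7) <= tuple(version_info[:2]) <= (3, 10)


def torch_supported(torch_version: str) -> bool:
    """True for PyTorch >= 1.13.0."""
    return parse_version(torch_version) >= (1, 13, 0)
```

For example, call `torch_supported(torch.__version__)` after `import torch`.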

Usage

To run A*Net, use the following command. The argument -c specifies the experiment configuration file, which includes the dataset, model architecture, and hyperparameters. You can find all configuration files in config/.../*.yaml. All datasets are downloaded automatically by the code.

python script/run.py -c config/transductive/fb15k237_astarnet.yaml --gpus [0]

For each experiment, you can specify the GPUs to use with the argument --gpus. You may use --gpus null to run A*Net on a CPU, though it will be very slow. To run A*Net with multiple GPUs, launch the experiment with torchrun:

torchrun --nproc_per_node=4 script/run.py -c config/transductive/fb15k237_astarnet.yaml --gpus [0,1,2,3]
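The --gpus value is written either as a list literal such as [0,1,2,3] or as null. A minimal sketch of how such a value could be parsed and checked against torchrun's --nproc_per_node, assuming one process per listed GPU; `parse_gpus` and `check_world_size` are hypothetical helpers, and script/run.py may parse the flag differently.

```python
import json
from typing import List, Optional


def parse_gpus(value: str) -> Optional[List[int]]:
    """Parse a --gpus value such as "[0,1,2,3]" or "null" (CPU)."""
    gpus = json.loads(value)  # "null" -> None, "[0,1]" -> [0, 1]
    if gpus is None:
        return None
    if not isinstance(gpus, list) or not all(isinstance(g, int) for g in gpus):
        raise ValueError(f"expected a list of GPU ids or null, got {value!r}")
    return gpus


def check_world_size(gpus: Optional[List[int]], nproc_per_node: int) -> None:
    """Assumption: each torchrun process drives exactly one GPU from the list."""
    if gpus is None:
        return  # CPU run: there is no GPU count to match
    if len(gpus) != nproc_per_node:
        raise ValueError(
            f"--nproc_per_node={nproc_per_node} but --gpus lists {len(gpus)} device(s)"
        )
```

With the command above, `check_world_size(parse_gpus("[0,1,2,3]"), 4)` passes silently.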

For the inductive setting, there are 4 different splits for each dataset. You need to additionally specify the split version, for example with --version v1.
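Since each inductive dataset has 4 splits, evaluating all of them means repeating the run once per split. A small sketch that builds those invocations, assuming the splits are named v1 to v4 (only v1 appears above) and that you pass your own inductive config path; `inductive_commands` is a hypothetical helper.

```python
from typing import List


def inductive_commands(config_path: str, gpus: str = "[0]") -> List[List[str]]:
    """Build one script/run.py argument list per inductive split (assumed v1..v4)."""
    return [
        ["python", "script/run.py", "-c", config_path, "--gpus", gpus, "--version", f"v{i}"]
        for i in range(1, 5)
    ]
```

Each list can be passed to `subprocess.run(argv, check=True)`. Argument lists are used instead of shell strings so that values like [0] reach the script literally, without shell globbing.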

ChatGPT Interface

We provide a ChatGPT interface for A*Net, through which users can interact with A*Net in natural language. To try the ChatGPT interface, download the checkpoint here and run the following command. Note that you need an OpenAI API key to run the demo.

export OPENAI_API_KEY=your-openai-api-key
python script/chat.py -c config/transductive/wikikg2_astarnet_visualize.yaml --checkpoint wikikg2_astarnet.pth --gpus [0]
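The demo expects the key in the OPENAI_API_KEY environment variable, as exported above. A minimal pre-flight sketch that fails early with a clear message when the variable is missing, and builds the same chat command as an argument list; `require_api_key` and `chat_command` are hypothetical helpers, not part of the repository.

```python
import os
from typing import List, Mapping


def require_api_key(env: Mapping[str, str] = os.environ) -> str:
    """Return OPENAI_API_KEY, or raise a clear error if it is unset or empty."""
    key = env.get("OPENAI_API_KEY", "")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set; export it before running script/chat.py")
    return key


def chat_command(checkpoint: str, gpus: str = "[0]") -> List[str]:
    """Argument list equivalent to the chat.py command above."""
    return [
        "python", "script/chat.py",
        "-c", "config/transductive/wikikg2_astarnet_visualize.yaml",
        "--checkpoint", checkpoint,
        "--gpus", gpus,
    ]
```

For example: `require_api_key()` followed by `subprocess.run(chat_command("wikikg2_astarnet.pth"), check=True)`.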

Visualization

A*Net supports visualization of the important paths behind its predictions. With a trained model, you can visualize these paths with the following command. Replace the checkpoint path with your own.

python script/visualize.py -c config/transductive/fb15k237_astarnet_visualize.yaml --checkpoint /path/to/astarnet/experiment/model_epoch_20.pth --gpus [0]
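If an experiment directory holds several checkpoints, you may want the most recent one. A sketch, assuming checkpoints follow the model_epoch_<N>.pth naming seen in the command above; `latest_checkpoint` is a hypothetical helper. It compares epochs numerically, so model_epoch_20.pth wins over model_epoch_9.pth even though it sorts earlier as a string.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

CHECKPOINT_PATTERN = re.compile(r"model_epoch_(\d+)\.pth$")


def latest_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return the path with the highest epoch number, or None if nothing matches."""
    best_epoch, best_path = -1, None
    for path in paths:
        match = CHECKPOINT_PATTERN.search(Path(path).name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path
```

For example: `latest_checkpoint(str(p) for p in Path("/path/to/astarnet/experiment").glob("*.pth"))`.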

Parameterize with your favourite GNNs

A*Net is designed to be a general framework for knowledge graph reasoning. This means you can parameterize it with a broad range of message-passing GNNs. To do so, implement a convolution layer in reasoning/layer.py and register it with @R.register. The GNN layer is expected to have the following member functions:

def message(self, graph, input):
    ...
    return message

def aggregate(self, graph, message):
    ...
    return update

def combine(self, input, update):
    ...
    return output

where the arguments and the return values are:

  • graph (data.PackedGraph): a batch of subgraphs selected by A*Net, with graph.query being the query embeddings of shape (batch_size, input_dim).
  • input (Tensor): node representations of shape (graph.num_node, input_dim).
  • message (Tensor): messages of shape (graph.num_edge, input_dim).
  • update (Tensor): aggregated messages of shape (graph.num_node, *).
  • output (Tensor): output representations of shape (graph.num_node, output_dim).
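To make the interface concrete, here is a toy layer written against it. It is a sketch under loud assumptions: plain Python lists stand in for Tensors, the hypothetical `ToyGraph` stands in for data.PackedGraph (it carries no graph.query), the message is a DistMult-style elementwise product of the head node and relation embeddings, and the layer is not registered with @R.register. A real layer in reasoning/layer.py would be written with PyTorch and TorchDrug.

```python
from dataclasses import dataclass
from typing import List, Tuple

Vector = List[float]


@dataclass
class ToyGraph:
    """Hypothetical stand-in for data.PackedGraph; edges are (head, tail, relation)."""
    num_node: int
    edge_list: List[Tuple[int, int, int]]

    @property
    def num_edge(self) -> int:
        return len(self.edge_list)


class ToyDistMultConv:
    """Toy layer following the message / aggregate / combine interface."""

    def __init__(self, relation: List[Vector]):
        self.relation = relation  # one embedding of size input_dim per relation type

    def message(self, graph: ToyGraph, input: List[Vector]) -> List[Vector]:
        # one row per edge: head representation scaled by the relation embedding
        return [[h * r for h, r in zip(input[head], self.relation[rel])]
                for head, _, rel in graph.edge_list]

    def aggregate(self, graph: ToyGraph, message: List[Vector]) -> List[Vector]:
        # sum incoming messages into each tail node: one row per node
        dim = len(self.relation[0])
        update = [[0.0] * dim for _ in range(graph.num_node)]
        for (_, tail, _), msg in zip(graph.edge_list, message):
            for i, value in enumerate(msg):
                update[tail][i] += value
        return update

    def combine(self, input: List[Vector], update: List[Vector]) -> List[Vector]:
        # residual-style sum, so output_dim equals input_dim in this toy
        return [[x + u for x, u in zip(x_row, u_row)] for x_row, u_row in zip(input, update)]

    def forward(self, graph: ToyGraph, input: List[Vector]) -> List[Vector]:
        return self.combine(input, self.aggregate(graph, self.message(graph, input)))
```

The three methods keep the shapes listed above: message returns graph.num_edge rows, while aggregate and combine return graph.num_node rows.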

To support the neural priority function in A*Net, the layer must additionally provide an interface for computing messages:

def compute_message(self, node_input, edge_input):
    ...
    return msg_output
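A minimal sketch of compute_message, assuming node_input holds one row per edge (for example, the representation of each edge's head node) and edge_input holds the matching relation embeddings. The DistMult-style product is an illustrative choice and `ToyMessageLayer` is hypothetical. Keeping the arithmetic in compute_message means it can be applied to any batch of (node, edge) pairs without a graph object.

```python
from typing import List

Vector = List[float]


class ToyMessageLayer:
    """Hypothetical layer exposing only the compute_message interface."""

    def compute_message(self, node_input: List[Vector], edge_input: List[Vector]) -> List[Vector]:
        # node_input and edge_input are assumed to be row-aligned, one row per edge
        if len(node_input) != len(edge_input):
            raise ValueError("node_input and edge_input must have the same number of rows")
        return [[n * e for n, e in zip(node_row, edge_row)]
                for node_row, edge_row in zip(node_input, edge_input)]
```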

For more details, you may refer to the TorchDrug tutorials.

Frequently Asked Questions

  1. The code is stuck at the beginning of epoch 0.

    This is probably because the JIT cache is broken. Try rm -r ~/.cache/torch_extensions/* and run the code again.
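The same cleanup can be done from Python, for example in a setup script. A sketch with `clear_torch_extensions_cache` as a hypothetical helper: it deletes the contents of the cache directory but keeps the directory itself, like the rm command above, except that it also removes hidden entries, which the shell glob * skips. If you have set TORCH_EXTENSIONS_DIR, PyTorch builds extensions there instead, so pass that path.

```python
import shutil
from pathlib import Path


def clear_torch_extensions_cache(
    cache_dir: Path = Path.home() / ".cache" / "torch_extensions",
) -> int:
    """Delete everything inside cache_dir and return the number of top-level entries removed."""
    if not cache_dir.is_dir():
        return 0  # nothing to clear
    removed = 0
    for entry in cache_dir.iterdir():
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)  # compiled extension build directories
        else:
            entry.unlink()  # plain files and symlinks
        removed += 1
    return removed
```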

Citation

If you find this project useful, please consider citing the following paper:

@article{zhu2022scalable,
  title={A*Net: A Scalable Path-based Reasoning Approach for Knowledge Graphs},
  author={Zhu, Zhaocheng and Yuan, Xinyu and Galkin, Mikhail and Xhonneux, Sophie and Zhang, Ming and Gazeau, Maxime and Tang, Jian},
  journal={arXiv preprint arXiv:2206.04798},
  year={2022}
}
