
Making self-supervised learning work on molecules by using their 3D geometry to pre-train GNNs. Implemented in DGL and PyTorch Geometric.

3D Infomax improves GNNs for Molecular Property Prediction

Video | Paper

We pre-train GNNs to understand the geometry of molecules given only their 2D molecular graph, which they can then use for better molecular property predictions. Below is a three-step guide for how to use the code and reproduce our results, as well as a guide for creating molecular fingerprints. If you have questions, don't hesitate to open an issue or ask me via [email protected] or social media. I am happy to hear from you!

This repository additionally adapts different self-supervised learning methods to graphs, such as "Bootstrap Your Own Latent", "Barlow Twins", and "VICReg".
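To make the Barlow Twins idea concrete, here is a minimal PyTorch sketch of its loss on two batches of embeddings (two "views" of the same molecules); the function name and the off-diagonal weight are illustrative and not taken from this repository's code:

import torch

def barlow_twins_loss(z_a, z_b, lambda_offdiag=0.005):
    # Standardize each embedding dimension across the batch.
    z_a = (z_a - z_a.mean(0)) / z_a.std(0)
    z_b = (z_b - z_b.mean(0)) / z_b.std(0)
    n, d = z_a.shape
    # Cross-correlation matrix between the two views.
    c = (z_a.T @ z_b) / n
    # Pull the diagonal towards 1 (invariance) and the off-diagonal
    # towards 0 (redundancy reduction between embedding dimensions).
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
    return on_diag + lambda_offdiag * off_diag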

Generating fingerprints for arbitrary SMILES

To generate fingerprints that carry 3D information, set up the environment as in Step 1 below, place your SMILES strings into the file dataset/inference_smiles.txt, and run

python inference.py --config=configs_clean/fingerprint_inference.yml

Your fingerprints are saved as a pickle file in the dataset_directory.
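As an illustration, the saved pickle can be loaded with a few lines of Python; the file name below is a placeholder, so use the actual path of the pickle file in your dataset directory:

import pickle

with open('dataset/fingerprints.pkl', 'rb') as f:  # placeholder path
    fingerprints = pickle.load(f)
print(type(fingerprints))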

Step 1: Setup Environment

We will set up the environment using Anaconda. Clone the current repo:

git clone https://github.com/HannesStark/3DInfomax

Create a new environment with all required packages using environment.yml (this can take a while). While in the project directory run:

conda env create

Activate the environment:

conda activate 3DInfomax

Step 2: 3D Pre-train a model

Let's pre-train a GNN with 50,000 molecules and their structures from the QM9 dataset (you can also skip to Step 3 and use the pre-trained model weights provided in this repo). For other datasets, see the Data section below.

python train.py --config=configs_clean/pre-train_QM9.yml

This will first create the processed data from dataset/QM9/qm9.csv with the 3D information in qm9_eV.npz. Then the model starts pre-training, and all logs are saved in the runs folder, which will also contain the pre-trained model as best_checkpoint.pt that can later be loaded for fine-tuning.

You can start TensorBoard and navigate to localhost:6006 in your browser to monitor the training process:

tensorboard --logdir=runs --port=6006

Explanation:

The config files in configs_clean provide additional examples and blueprints to train different models. Each file contains a model_type that should be pre-trained (the 2D network) and a model3d_type (the 3D network), and you can specify the parameters of these networks there. To find out more about all the other parameters in the config file, have a look at their descriptions by running python train.py --help.
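For example, you can peek at these two parameters of a config file with a few lines of Python (assuming PyYAML is installed; this is a convenience sketch, not a utility shipped with the repository):

import yaml

with open('configs_clean/pre-train_QM9.yml') as f:
    config = yaml.safe_load(f)
# model_type is the 2D network, model3d_type the 3D network.
print(config['model_type'], config['model3d_type'])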

Step 3: Fine-tune a model

During pre-training, a directory is created in the runs directory that contains the pre-trained model. We provide an example of such a directory with already pre-trained weights, runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52, which we can fine-tune for predicting QM9's HOMO property as follows:

python train.py --config=configs_clean/tune_QM9_homo.yml

You can monitor the fine-tuning process on TensorBoard as well. In the end, the results are printed to the console and also saved in the file evaluation_test.txt inside the runs directory that was created for fine-tuning.

The model which we are fine-tuning from is specified in configs_clean/tune_QM9_homo.yml via the parameter:

pretrain_checkpoint: runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52/best_checkpoint_35epochs.pt
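If you want to inspect such a checkpoint yourself, plain PyTorch is enough; note that the internal layout of the checkpoint file is an assumption here:

import torch

ckpt_path = 'runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52/best_checkpoint_35epochs.pt'
checkpoint = torch.load(ckpt_path, map_location='cpu')
# Assuming the checkpoint is a dictionary, list its top-level entries.
print(list(checkpoint.keys()) if isinstance(checkpoint, dict) else type(checkpoint))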

Multiple seeds:

This is a second fine-tuning example where we predict non-quantum properties of the OGB datasets and train multiple seeds (we always use the seeds 1, 2, 3, 4, 5, 6 in our experiments):

python train.py --config=configs_clean/tune_freesolv.yml

After all runs are done, the averaged results are saved in the runs directory of each seed in the file multiple_seed_test_statistics.txt.
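Conceptually, these averaged statistics are the mean and standard deviation of the per-seed test metrics, as in this sketch with placeholder values:

from statistics import mean, stdev

seed_metrics = [1.02, 0.98, 1.05, 1.01, 0.97, 1.03]  # hypothetical per-seed test scores
print(f'mean: {mean(seed_metrics):.3f}, std: {stdev(seed_metrics):.3f}')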

Data

You can pre-train or fine-tune on different datasets by specifying the dataset: parameter in a .yml file such as dataset: drugs to use GEOM-Drugs.

The QM9 dataset and the OGB datasets are already provided with this repository. The QMugs and GEOM-Drugs datasets need to be downloaded and placed in the correct location.

GEOM-Drugs: Download GEOM-Drugs here (the rdkit_folder.tar.gz file), unzip it, and place it into dataset/GEOM.

QMugs: Download QMugs here (the structures.tar and summary.csv files), unzip structures.tar, and place the resulting structures folder and the summary.csv file into a new folder QMugs that you have to create NEXT TO the repository root, not in the repository root (sorry for this).

Reference

📃 Paper on arXiv

@article{stark2021_3dinfomax,
  title={3D Infomax improves GNNs for Molecular Property Prediction},
  author={Hannes Stärk and Dominique Beaini and Gabriele Corso and Prudencio Tossou and Christian Dallago and Stephan Günnemann and Pietro Liò},
  journal={arXiv preprint arXiv:2110.04126},
  year={2021}
}

More Repositories

1. EquiBind: geometric deep learning for fast predictions of the 3D structure in which a small molecule binds to a protein (Python, 473 stars)
2. FlowSite: implementation of FlowSite and HarmonicFlow from the paper "Harmonic Self-Conditioned Flow Matching for Multi-Ligand Docking and Binding Site Design" (Python, 83 stars)
3. dirichlet-flow-matching (Python, 75 stars)
4. SMPL-NeRF: embed human pose information into neural radiance fields (NeRF) to render images of humans in desired poses 🏃 from novel views (Python, 58 stars)
5. protein-localization: using Transformer protein embeddings with a linear attention mechanism to make SOTA de-novo predictions for the subcellular location of proteins 🔬 (Jupyter Notebook, 54 stars)
6. gnn-reinforcement-learning: representing robots as graphs for reinforcement learning in PyBullet locomotion environments (Jupyter Notebook, 26 stars)
7. hannes-stark: code for my website, built with Angular and running on GitHub Pages (HTML, 13 stars)
8. GNN-primer (Jupyter Notebook, 8 stars)
9. CodonMPNN (Python, 6 stars)
10. attention-to-binding-sites: unsupervised method for binding site prediction using attention patterns of protein language models (Jupyter Notebook, 3 stars)
11. molecule-ELECTRA: pre-train and evaluate Graph Neural Networks or Transformers on molecules with the ELECTRA method (Python, 3 stars)
12. audioImprovement: removing background noise from clips of speech and improving audio quality (PyTorch) (Python, 3 stars)
13. genie (Python, 2 stars)
14. bachelorThesis: TensorFlow code and LaTeX for the Bachelor thesis "Understanding Variational Autoencoders' Latent Representations of Remote Sensing Images" 🌍 (TeX, 1 star)
15. ec-number-prediction: using similarity in embedding space for predicting EC numbers (Jupyter Notebook, 1 star)
16. dependencyNodeRanking: R code for NetworkCentralityCalculator, a web tool with 5 different centrality measures; LaTeX and PDF documentation of the different measures with a focus on "dependency centrality" (TeX, 1 star)
17. logag (1 star)
18. HannesStark (1 star)