[Example] Add DimeNet(++) for Molecular Graph Property Prediction (dmlc#2706)

* [example] arma

* [example] dimenet

* [docs] update dimenet

* [docs] update tf results

Co-authored-by: Mufei Li <[email protected]>
xnuohz and mufeili authored Mar 9, 2021
1 parent c88fca5 commit 0b47e86
Showing 23 changed files with 1,712 additions and 7 deletions.
15 changes: 10 additions & 5 deletions examples/README.md
@@ -80,6 +80,7 @@ The folder contains example implementations of selected research papers related
| [Dynamic Graph CNN for Learning on Point Clouds](#dgcnnpoint) | | | | | |
| [Supervised Community Detection with Line Graph Neural Networks](#lgnn) | | | | | |
| [Text Generation from Knowledge Graphs with Graph Transformers](#graphwriter) | | | | | |
| [Directional Message Passing for Molecular Graphs](#dimenet) | | | :heavy_check_mark: | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |

@@ -101,15 +102,19 @@ The folder contains example implementations of selected research papers related
- Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction
- <a name="gnnfilm"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).
- Example code: [Pytorch](../examples/pytorch/GNN-FiLM)
- Example code: [PyTorch](../examples/pytorch/GNN-FiLM)
- Tags: multi-relational graphs, hypernetworks, GNN architectures
- <a name="gxn"></a> Li, Maosen, et al. Graph Cross Networks with Vertex Infomax Pooling. [Paper link](https://arxiv.org/abs/2010.01804).
- Example code: [Pytorch](../examples/pytorch/gxn)
- Example code: [PyTorch](../examples/pytorch/gxn)
- Tags: pooling, graph classification
- <a name="dagnn"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296).
- Example code: [Pytorch](../examples/pytorch/dagnn)
- Example code: [PyTorch](../examples/pytorch/dagnn)
- Tags: over-smoothing, node classification

- <a name="dimenet"></a> Klicpera et al. Directional Message Passing for Molecular Graphs. [Paper link](https://arxiv.org/abs/2003.03123).
- Example code: [PyTorch](../examples/pytorch/dimenet)
- Tags: molecules, molecular property prediction, quantum chemistry

## 2019


@@ -180,10 +185,10 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/hgp_sl)
- Tags: graph classification, pooling
- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks [Paper link](https://arxiv.org/abs/1907.04652).
- Example code: [Pytorch](../examples/pytorch/hardgat)
- Example code: [PyTorch](../examples/pytorch/hardgat)
- Tags: node classification, graph attention
- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. [Paper link](https://arxiv.org/abs/1905.08108).
- Example code: [Pytorch](../examples/pytorch/NGCF)
- Example code: [PyTorch](../examples/pytorch/NGCF)
- Tags: Collaborative Filtering, Recommendation, Graph Neural Network


128 changes: 128 additions & 0 deletions examples/pytorch/dimenet/README.md
@@ -0,0 +1,128 @@
# DGL Implementation of DimeNet and DimeNet++

This DGL example implements the GNN models proposed in the papers [Directional Message Passing for Molecular Graphs](https://arxiv.org/abs/2003.03123) and [Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules](https://arxiv.org/abs/2011.14115). For the original implementation, see [here](https://github.com/klicperajo/dimenet).

Contributor: [xnuohz](https://github.com/xnuohz)

* This example implements both DimeNet and DimeNet++.
* Advantages of DimeNet++ over DimeNet (see the sketch after this list):
  - Fast interactions: replaces the bilinear layer with a simple Hadamard product
  - Embedding hierarchy: uses a larger number of embeddings by reducing the embedding size in blocks via down- and up-projection layers
  - Other improvements: uses fewer interaction blocks
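
The contrast between the two interaction styles can be summarized in a few lines. Below is a minimal sketch with hypothetical tensor shapes (the actual blocks live in the `modules` package of this example):

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration only.
num_edges, emb_size, num_sbf, num_bilinear, int_emb_size = 32, 128, 42, 8, 64

sbf = torch.randn(num_edges, num_sbf)    # spherical basis representation
x_kj = torch.randn(num_edges, emb_size)  # incoming directional messages

# DimeNet: a full bilinear tensor mixes basis and message channels (expensive).
W_bilinear = torch.randn(num_sbf, emb_size, num_bilinear)
out_dimenet = torch.einsum('wj,wl,jlk->wk', sbf, x_kj, W_bilinear)

# DimeNet++: down-project both inputs, take a cheap Hadamard product in a
# smaller space, then up-project (the down-/up-projection layers).
down_sbf = nn.Linear(num_sbf, int_emb_size, bias=False)
down_msg = nn.Linear(emb_size, int_emb_size, bias=False)
up = nn.Linear(int_emb_size, emb_size, bias=False)
out_dimenet_pp = up(down_sbf(sbf) * down_msg(x_kj))  # elementwise product
```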

### Requirements
The codebase is implemented in Python 3.6. For the version requirements of packages, see below.

```
click 7.1.2
dgl 0.6.0
logzero 1.6.3
numpy 1.19.5
ruamel.yaml 0.16.12
scikit-learn 0.24.1
scipy 1.5.4
sympy 1.7.1
torch 1.7.0
tqdm 4.56.0
```

### The graph dataset used in this example

This example uses DGL's built-in QM9 dataset (a minimal loading sketch follows the summary). Dataset summary:

* Number of Molecular Graphs: 130,831
* Number of Tasks: 12
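
As a quick orientation, here is a minimal loading sketch, assuming DGL's built-in `QM9Dataset` API (`label_keys` selects among the 12 targets):

```python
# A minimal sketch of loading QM9 with DGL (API as in dgl >= 0.6).
from dgl.data import QM9Dataset

dataset = QM9Dataset(label_keys=['mu'], cutoff=5.0)  # edges within the 5.0 cutoff
g, label = dataset[0]                                # one molecular graph + target
print(g.num_nodes(), label)
```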

### Usage

**Note: DimeNet++ is recommended over DimeNet.**

##### Examples

The following commands train a model and evaluate it on the test set.
To train DimeNet on the QM9 dataset:
```bash
python main.py --model-cnf config/dimenet.yaml
```
To train DimeNet++ on the QM9 dataset:
```bash
python main.py --model-cnf config/dimenet_pp.yaml
```
For faster experimentation, first place the authors' [pretrained](https://github.com/klicperajo/dimenet/tree/master/pretrained) folder here; it contains pre-trained TensorFlow models. You can then convert a TensorFlow checkpoint to a PyTorch one with the following command.
```bash
python convert_tf_ckpt_to_pytorch.py --model-cnf config/dimenet_pp.yaml --convert-cnf config/convert.yaml
```
Then set `flag: True` under `pretrain` in `config/dimenet_pp.yaml` and rerun the training command above; DimeNet++ will use the converted pretrained weights to predict on the test set.
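
For reference, enabling the flag amounts to restoring the converted checkpoint into the model. A hedged sketch (the actual loading logic lives in `main.py`; `load_pretrained` is a hypothetical helper, and `mu` is just the example target):

```python
import torch

def load_pretrained(model: torch.nn.Module,
                    path: str = 'pretrained/converted/mu.pt') -> torch.nn.Module:
    """Restore converted TensorFlow weights into the PyTorch model."""
    model.load_state_dict(torch.load(path))
    return model.eval()
```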

##### Configuration

For more details, please see `config/dimenet.yaml` and `config/dimenet_pp.yaml`.

###### Model options
```
// The following parameters are only used in DimeNet++
out_emb_size       int     Output embedding size. Default is 256
int_emb_size       int     Input embedding size. Default is 64
basis_emb_size     int     Basis embedding size. Default is 8
extensive          bool    Readout operator for generating a graph-level representation. Default is True
// The following parameter is only used in DimeNet
num_bilinear       int     Third dimension of the bilinear layer tensor in DimeNet. Default is 8
// The following parameters are used in both DimeNet and DimeNet++
emb_size           int     Embedding size used throughout the model. Default is 128
num_blocks         int     Number of building blocks to be stacked. Default is 6 in DimeNet and 4 in DimeNet++
num_spherical      int     Number of spherical harmonics. Default is 7
num_radial         int     Number of radial basis functions. Default is 6
envelope_exponent  int     Shape of the smooth cutoff. Default is 5
cutoff             float   Cutoff distance for interatomic interactions. Default is 5.0
num_before_skip    int     Number of residual layers in the interaction block before the skip connection. Default is 1
num_after_skip     int     Number of residual layers in the interaction block after the skip connection. Default is 2
num_dense_output   int     Number of dense layers in the output blocks. Default is 3
targets            list    List of targets to predict. Default is ['mu']
output_init        string  Initialization function for the output layer. Default is 'GlorotOrthogonal'
```
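
As a side note on `extensive`: a hedged sketch of the two readout choices, assuming the per-atom outputs are stored in `ndata['h']` (sum for extensive targets such as U0, mean for intensive ones):

```python
import dgl
import torch

g = dgl.rand_graph(6, 12)         # toy graph standing in for a molecule
g.ndata['h'] = torch.randn(6, 1)  # per-atom contributions to the target

extensive_pred = dgl.readout_nodes(g, 'h', op='sum')   # extensive: True
intensive_pred = dgl.readout_nodes(g, 'h', op='mean')  # extensive: False
```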

###### Training options
```
num_train       int     Number of training samples. Default is 110000
num_valid       int     Number of validation samples. Default is 10000
data_seed       int     Random seed for the data split. Default is 42
lr              float   Learning rate. Default is 0.001
weight_decay    float   Weight decay. Default is 0.0001
ema_decay       float   Exponential moving average (EMA) decay. Default is 0 (EMA disabled)
batch_size      int     Batch size. Default is 100
epochs          int     Number of training epochs. Default is 300
early_stopping  int     Number of epochs to wait without improvement before early stopping. Default is 20
num_workers     int     Number of subprocesses to use for data loading. Default is 18
gpu             int     GPU index. Default is 0 (cuda:0)
interval        int     Interval for model evaluation. Default is 50
step_size       int     Period of learning rate decay. Default is 100
gamma           float   Multiplicative factor of learning rate decay. Default is 0.3
```

### Performance

Differences from the original training setup:

- A different batch size is used
- Linear learning rate warm-up is not used
- Exponential learning rate decay is not used
- Exponential moving average (EMA) is not used

Notes on the table below:

- In the first two rows, values for all targets except mu, alpha, r2, and Cv should be multiplied by 10^-3 to match the units of the remaining rows
- The authors' code does not provide a pretrained model for the gap task
- MAE (DimeNet in Table 1) is from [the DimeNet paper](https://arxiv.org/abs/2003.03123)
- MAE (DimeNet++ in Table 2) is from [the DimeNet++ paper](https://arxiv.org/abs/2011.14115)

| Target | mu | alpha | homo | lumo | gap | r2 | zpve | U0 | U | H | G | Cv |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| MAE(DimeNet in Table 1) | 0.0286 | 0.0469 | 27.8 | 19.7 | 34.8 | 0.331 | 1.29 | 8.02 | 7.89 | 8.11 | 8.98 | 0.0249 |
| MAE(DimeNet++ in Table 2) | 0.0297 | 0.0435 | 24.6 | 19.5 | 32.6 | 0.331 | 1.21 | 6.32 | 6.28 | 6.53 | 7.56 | 0.0230 |
| MAE(DimeNet++, TF, pretrain) | 0.0297 | 0.0435 | 0.0246 | 0.0195 | - | 0.3312 | 0.00121 | 0.0063 | 0.00628 | 0.00653 | 0.00756 | 0.0230 |
| MAE(DimeNet++, TF, scratch) | 0.0330 | 0.0447 | 0.0251 | 0.0227 | 0.0486 | 0.3574 | 0.00123 | 0.0065 | 0.00635 | 0.00658 | 0.00747 | 0.0224 |
| MAE(DimeNet++, DGL) | 0.0326 | 0.0537 | 0.0311 | 0.0255 | 0.0490 | 0.4801 | 0.0043 | 0.0141 | 0.0109 | 0.0117 | 0.0150 | 0.0254 |

### Speed

Lower is better; the Improvement column is the ratio of the two timings (e.g., 2839 / 1345 ≈ 2.1).

| Model | Original Implementation | DGL Implementation | Improvement |
| :-: | :-: | :-: | :-: |
| DimeNet | 2839 | 1345 | 2.1x |
| DimeNet++ | 624 | 238 | 2.6x |
4 changes: 4 additions & 0 deletions examples/pytorch/dimenet/config/convert.yaml
@@ -0,0 +1,4 @@
tf:
  ckpt_path: 'pretrained/dimenet_pp/mu'
torch:
  dump_path: 'pretrained/converted'
35 changes: 35 additions & 0 deletions examples/pytorch/dimenet/config/dimenet.yaml
@@ -0,0 +1,35 @@
name: "dimenet"

model:
emb_size: 128
num_blocks: 6
num_bilinear: 8
num_spherical: 7
num_radial: 6
envelope_exponent: 5
cutoff: 5.0
num_before_skip: 1
num_after_skip: 2
num_dense_output: 3
# ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
targets: ['U0']

train:
num_train: 110000
num_valid: 10000
data_seed: 42
lr: 0.001
weight_decay: 0.0001
ema_decay: 0
batch_size: 45
epochs: 300
early_stopping: 20
num_workers: 18
gpu: 0
interval: 50
step_size: 100
gamma: 0.3

pretrain:
flag: False
path: 'pretrained/converted/'
38 changes: 38 additions & 0 deletions examples/pytorch/dimenet/config/dimenet_pp.yaml
@@ -0,0 +1,38 @@
name: "dimenet++"

model:
emb_size: 128
out_emb_size: 256
int_emb_size: 64
basis_emb_size: 8
num_blocks: 4
num_spherical: 7
num_radial: 6
envelope_exponent: 5
cutoff: 5.0
extensive: True
num_before_skip: 1
num_after_skip: 2
num_dense_output: 3
# ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
targets: ['mu']

train:
num_train: 110000
num_valid: 10000
data_seed: 42
lr: 0.001
weight_decay: 0.0001
ema_decay: 0
batch_size: 100
epochs: 300
early_stopping: 20
num_workers: 18
gpu: 0
interval: 50
step_size: 100
gamma: 0.3

pretrain:
flag: False
path: 'pretrained/converted/'
89 changes: 89 additions & 0 deletions examples/pytorch/dimenet/convert_tf_ckpt_to_pytorch.py
@@ -0,0 +1,89 @@
import tensorflow as tf
import torch
import torch.nn as nn
import click
import numpy as np
import os

from logzero import logger
from pathlib import Path
from ruamel.yaml import YAML
from modules.initializers import GlorotOrthogonal
from modules.dimenet_pp import DimeNetPP

@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
@click.option('-c', '--convert-cnf', type=click.Path(exists=True), help='Path of convert config yaml.')
def main(model_cnf, convert_cnf):
    yaml = YAML(typ='safe')
    model_cnf = yaml.load(Path(model_cnf))
    convert_cnf = yaml.load(Path(convert_cnf))
    model_name, model_params, _ = model_cnf['name'], model_cnf['model'], model_cnf['train']
    logger.info(f'Model name: {model_name}')
    logger.info(f'Model params: {model_params}')

    # 'targets' is a list, so inspect its (single) entry.
    if model_params['targets'][0] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
        model_params['output_init'] = nn.init.zeros_
    else:
        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
        model_params['output_init'] = GlorotOrthogonal

    # model initialization
    logger.info('Loading Model')
    model = DimeNetPP(emb_size=model_params['emb_size'],
                      out_emb_size=model_params['out_emb_size'],
                      int_emb_size=model_params['int_emb_size'],
                      basis_emb_size=model_params['basis_emb_size'],
                      num_blocks=model_params['num_blocks'],
                      num_spherical=model_params['num_spherical'],
                      num_radial=model_params['num_radial'],
                      cutoff=model_params['cutoff'],
                      envelope_exponent=model_params['envelope_exponent'],
                      num_before_skip=model_params['num_before_skip'],
                      num_after_skip=model_params['num_after_skip'],
                      num_dense_output=model_params['num_dense_output'],
                      num_targets=len(model_params['targets']),
                      extensive=model_params['extensive'],
                      output_init=model_params['output_init'])
    logger.info(model.state_dict())
    tf_path, torch_path = convert_cnf['tf']['ckpt_path'], convert_cnf['torch']['dump_path']
    init_vars = tf.train.list_variables(tf_path)
    tf_vars_dict = {}

    # Collect all TF variables (147 keys) except the checkpoint object graph.
    for name, shape in init_vars:
        if name == '_CHECKPOINTABLE_OBJECT_GRAPH':
            continue
        array = tf.train.load_variable(tf_path, name)
        logger.info(f'Loading TF weight {name} with shape {shape}')
        tf_vars_dict[name] = array

    for name, array in tf_vars_dict.items():
        # Drop the trailing '.ATTRIBUTES/VARIABLE_VALUE' components, then walk
        # the remaining path down the PyTorch module hierarchy.
        name = name.split('/')[:-2]
        pointer = model

        for m_name in name:
            if m_name == 'kernel':
                pointer = getattr(pointer, 'weight')
            elif m_name == 'int_blocks':
                pointer = getattr(pointer, 'interaction_blocks')
            elif m_name == 'embeddings':
                pointer = getattr(pointer, 'embedding')
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, m_name)
        if name[-1] == 'kernel':
            # TF dense kernels are (in, out); PyTorch Linear weights are (out, in).
            array = np.transpose(array)
        assert array.shape == pointer.shape
        logger.info(f'Initialize PyTorch weight {name}')
        pointer.data = torch.from_numpy(array)

    logger.info(f'Save PyTorch model to {torch_path}')
    if not os.path.exists(torch_path):
        os.makedirs(torch_path)
    target = model_params['targets'][0]
    torch.save(model.state_dict(), f'{torch_path}/{target}.pt')
    logger.info(model.state_dict())


if __name__ == "__main__":
    main()