[Model Zoo] Refactor Model Zoo for Chemistry (dmlc#839)
* Update

* Update

* Update

* Update fix

* Update

* Update

* Refactor

* Update

* Update

* Update

* Update

* Update

* Update

* Fix style
mufeili authored Sep 7, 2019
1 parent 5b41768 commit 189c2c0
Showing 24 changed files with 936 additions and 772 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -5,13 +5,17 @@
[Documentation](https://docs.dgl.ai) | [DGL at a glance](https://docs.dgl.ai/tutorials/basics/1_first.html#sphx-glr-tutorials-basics-1-first-py) |
[Model Tutorials](https://docs.dgl.ai/tutorials/models/index.html) | [Discussion Forum](https://discuss.dgl.ai)

Model Zoos: [Chemistry](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo) | [Citation Networks](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/citation_network)

DGL is a Python package that interfaces between existing tensor libraries and data being expressed as
graphs.

It makes implementing graph neural networks (including Graph Convolution Networks, TreeLSTM, and many others) easy while
maintaining high computation efficiency.

A summary of the model accuracy and training speed with the Pytorch backend (on Amazon EC2 p3.2x instance (w/ V100 GPU)), as compared with the best open-source implementations:
All model examples can be found [here](https://github.com/dmlc/dgl/tree/master/examples).

A partial summary of model accuracy and training speed with the PyTorch backend (on an Amazon EC2 p3.2xlarge instance with a V100 GPU), compared with the best open-source implementations:

| Model | Reported <br> Accuracy | DGL <br> Accuracy | Author's training speed (epoch time) | DGL speed (epoch time) | Improvement |
| ----- | ----------------- | ------------ | ------------------------------------ | ---------------------- | ----------- |
55 changes: 55 additions & 0 deletions docs/source/api/python/data.rst
@@ -15,6 +15,10 @@ Utils
utils.download
utils.check_sha1
utils.extract_archive
utils.split_dataset

.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
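
As a rough illustration of what splitting a dataset into subsets involves, the sketch below partitions an indexable dataset by fractions and exposes each part through `__getitem__` and `__len__`. `SimpleSubset` and `split_by_fractions` are hypothetical names written for this example; they are not the DGL implementation of `utils.split_dataset` or `Subset`, and the rounding rule is an assumption.

```python
from typing import List, Sequence


class SimpleSubset:
    """Hypothetical stand-in: a view over selected indices of a dataset."""

    def __init__(self, dataset: Sequence, indices: List[int]):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, item: int):
        # Map the subset-local index back to the underlying dataset.
        return self.dataset[self.indices[item]]

    def __len__(self) -> int:
        return len(self.indices)


def split_by_fractions(dataset: Sequence, frac_list: Sequence[float]) -> List[SimpleSubset]:
    """Split a dataset into consecutive subsets whose sizes follow frac_list."""
    assert abs(sum(frac_list) - 1.0) < 1e-8, "fractions must sum to 1"
    n = len(dataset)
    subsets = []
    start = 0
    cumulative = 0.0
    for i, frac in enumerate(frac_list):
        cumulative += frac
        # The last subset always ends at n so that no sample is dropped.
        end = n if i == len(frac_list) - 1 else int(round(cumulative * n))
        subsets.append(SimpleSubset(dataset, list(range(start, end))))
        start = end
    return subsets


data = list(range(10))
train, val, test = split_by_fractions(data, [0.8, 0.1, 0.1])
print(len(train), len(val), len(test))
```

Because each subset only stores indices, the underlying data is never copied, which is why such a view can be passed directly to a data loader.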

Dataset Classes
---------------
@@ -57,3 +61,54 @@ Protein-Protein Interaction dataset

.. autoclass:: PPIDataset
:members: __getitem__, __len__

Molecular Graphs
----------------

To work on molecular graphs, make sure you have installed `RDKit 2018.09.3 <https://www.rdkit.org/docs/Install.html>`__.

Featurization
`````````````

To apply graph neural networks to molecules, we need to featurize nodes (atoms) and edges (bonds). Some featurization
methods and utilities are listed below:

.. autosummary::
:toctree: ../../generated/

chem.one_hot_encoding
chem.BaseAtomFeaturizer
chem.CanonicalAtomFeaturizer
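
To make the idea of atom featurization concrete, here is a minimal sketch of one-hot encoding applied to two atom attributes. The helper names `one_hot` and `featurize_atom`, and the small allowable sets, are assumptions made for illustration; they do not reproduce the signatures or feature sets of `chem.one_hot_encoding` or `chem.CanonicalAtomFeaturizer`.

```python
from typing import List, Sequence


def one_hot(value, allowable_set: Sequence) -> List[int]:
    """Return a 0/1 vector with a 1 where value matches an allowed entry."""
    return [int(value == candidate) for candidate in allowable_set]


def featurize_atom(symbol: str, degree: int) -> List[int]:
    """Concatenate one-hot encodings of an atom's symbol and degree."""
    symbols = ["C", "N", "O", "F"]   # illustrative, not a complete element list
    degrees = [0, 1, 2, 3, 4]        # illustrative range of atom degrees
    return one_hot(symbol, symbols) + one_hot(degree, degrees)


features = featurize_atom("N", 3)
print(features)
```

A value outside the allowable set yields an all-zero segment in this sketch; a real featurizer may instead reserve an explicit "unknown" slot.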

Graph Construction
``````````````````

Several methods for constructing DGLGraphs from SMILES/RDKit molecule objects are listed below:

.. autosummary::
:toctree: ../../generated/

chem.mol_to_graph
chem.smile_to_bigraph
chem.mol_to_bigraph
chem.smile_to_complete_graph
chem.mol_to_complete_graph
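
The function names above suggest two graph topologies: a bidirected graph that follows chemical bonds, and a complete graph that connects every pair of atoms. The sketch below builds both edge lists in plain Python under that reading. `bigraph_edges` and `complete_graph_edges` are hypothetical helpers for illustration, not DGL functions, and the self-loop option is an assumption.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def bigraph_edges(bonds: List[Edge]) -> List[Edge]:
    """Add one directed edge per direction for every bond."""
    edges = []
    for u, v in bonds:
        edges.append((u, v))
        edges.append((v, u))
    return edges


def complete_graph_edges(num_atoms: int, add_self_loop: bool = False) -> List[Edge]:
    """Connect every ordered pair of atoms, optionally including self loops."""
    return [
        (u, v)
        for u in range(num_atoms)
        for v in range(num_atoms)
        if add_self_loop or u != v
    ]


# Heavy atoms of ethanol (C-C-O): atoms 0, 1, 2 with bonds 0-1 and 1-2.
bond_edges = bigraph_edges([(0, 1), (1, 2)])
all_pair_edges = complete_graph_edges(3)
print(len(bond_edges), len(all_pair_edges))
```

The bond-based graph grows with the number of bonds, while the complete graph grows quadratically with the number of atoms; models that rely on pairwise distances typically need the latter.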

Dataset Classes
```````````````

If your dataset is stored in a ``.csv`` file, you may find it helpful to use

.. autoclass:: dgl.data.chem.CSVDataset
:members: __getitem__, __len__

Currently two datasets are supported:

* Tox21
* TencentAlchemyDataset

.. autoclass:: dgl.data.chem.Tox21
:members: __getitem__, __len__, task_pos_weights

.. autoclass:: dgl.data.chem.TencentAlchemyDataset
:members: __getitem__, __len__, set_mean_and_std
1 change: 1 addition & 0 deletions docs/source/api/python/index.rst
@@ -19,3 +19,4 @@ API Reference
graph_store
nodeflow
random
model_zoo
57 changes: 57 additions & 0 deletions docs/source/api/python/model_zoo.rst
@@ -0,0 +1,57 @@
.. _apimodelzoo:

Model Zoo
=========

.. currentmodule:: dgl.model_zoo

Chemistry
---------

Utils
`````

.. autosummary::
:toctree: ../../generated/

chem.load_pretrained

Property Prediction
```````````````````

Currently supported model architectures:

* GCNClassifier
* GATClassifier
* MPNN
* SchNet
* MGCN

.. autoclass:: dgl.model_zoo.chem.GCNClassifier
:members: forward

.. autoclass:: dgl.model_zoo.chem.GATClassifier
:members: forward

.. autoclass:: dgl.model_zoo.chem.MPNNModel
:members: forward

.. autoclass:: dgl.model_zoo.chem.SchNet
:members: forward

.. autoclass:: dgl.model_zoo.chem.MGCNModel
:members: forward

Generative Models
`````````````````

Currently supported model architectures:

* DGMG
* JTNN

.. autoclass:: dgl.model_zoo.chem.DGMG
:members: forward

.. autoclass:: dgl.model_zoo.chem.DGLJTNNVAE
:members: forward
4 changes: 4 additions & 0 deletions examples/pytorch/model_zoo/chem/README.md
@@ -86,6 +86,7 @@ are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]:

### Example Usage of Pre-trained Models

@@ -143,3 +144,6 @@ Machine Learning* JMLR. 1263-1272.
[11] Li et al. (2018) Learning Deep Generative Models of Graphs. *arXiv preprint arXiv:1803.03324*.

[12] Goh et al. (2017) Deep learning for computational chemistry. *Journal of Computational Chemistry* 16, 1291-1307.

[13] Jin et al. (2018) Junction Tree Variational Autoencoder for Molecular Graph Generation.
*Proceedings of the 35th International Conference on Machine Learning (ICML)*, 2323-2332.
51 changes: 28 additions & 23 deletions examples/pytorch/model_zoo/chem/property_prediction/README.md
@@ -15,7 +15,7 @@ into training, validation and test set with a 80/10/10 ratio. By default we foll
- **Graph Convolutional Network** [2], [3]. Graph Convolutional Networks (GCN) have been one of the most popular graph neural
networks and they can be easily extended for graph level prediction. MoleculeNet [1] reports baseline results of graph
convolutions over multiple datasets.
- **Graph Attention Networks** [7]: Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
- **Graph Attention Networks** [7]. Graph Attention Networks (GATs) incorporate multi-head attention into GCNs,
explicitly modeling the interactions between adjacent atoms.

### Usage
@@ -49,16 +49,11 @@ a real difference.
| ---------------- | ---------------------- |
| Pretrained model | 0.827 |

## Dataset Customization

To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).

## Regression

Regression tasks require assigning continuous labels to a molecule, e.g. molecular energy.

### Dataset
### Datasets

- **Alchemy**. The [Alchemy Dataset](https://alchemy.tencent.com/) is introduced by Tencent Quantum Lab to facilitate the development of new
machine learning models useful for chemistry and materials science. The dataset lists 12 quantum mechanical properties of 130,000+ organic
@@ -68,29 +63,39 @@ These properties have been calculated using the open-source computational chemis

### Models

- **SchNet**: SchNet is a novel deep learning architecture modeling quantum interactions in molecules which utilize the continuous-filter
convolutional layers [4].
- **Multilevel Graph Convolutional neural Network**: Multilevel Graph Convolutional neural Network (MGCN) is a hierarchical
graph neural network directly extracts features from the conformation and spatial information followed by the multilevel interactions [5].
- **Message Passing Neural Network**: Message Passing Neural Network (MPNN) is a network with edge network (enn) as front end
and Set2Set for output prediction [6].
- **Message Passing Neural Network** [6]. Message Passing Neural Networks (MPNNs) held the best performance on
the QM9 dataset for some time.
- **SchNet** [4]. SchNet employs continuous filter convolutional layers to model quantum interactions in molecules
without requiring them to lie on grids.
- **Multilevel Graph Convolutional Neural Network** [5]. Multilevel Graph Convolutional Neural Networks (MGCN) are
hierarchical graph neural networks that extract features from the conformation and spatial information followed by the
multilevel interactions.

### Usage

```py
python regression.py --model sch --epoch 200
```
The model option must be one of 'sch', 'mgcn' or 'mpnn'.
Use `regression.py` with arguments
```
-m {MPNN,SCHNET,MGCN}, Model to use
-d {Alchemy}, Dataset to use
```
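
The argument listing above can be mirrored with `argparse` roughly as follows. This is a sketch, not the contents of `regression.py`: the long option names `--model` and `--dataset`, the default dataset, and the way the experiment name is assembled are assumptions, chosen so that the result matches keys such as `SCHNET_Alchemy` in `configure.py`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser exposing the -m and -d options listed above."""
    parser = argparse.ArgumentParser(description="Molecule regression (illustrative)")
    # Long option names are assumed; only -m and -d appear in the README.
    parser.add_argument("-m", "--model", choices=["MPNN", "SCHNET", "MGCN"],
                        required=True, help="Model to use")
    # Defaulting to Alchemy is an assumption made for this sketch.
    parser.add_argument("-d", "--dataset", choices=["Alchemy"],
                        default="Alchemy", help="Dataset to use")
    return parser


# Parse an explicit argument list so the example runs without a command line.
args = build_parser().parse_args(["-m", "SCHNET", "-d", "Alchemy"])

# Join model and dataset into a configuration key such as 'SCHNET_Alchemy'.
exp_name = "_".join([args.model, args.dataset])
print(exp_name)
```

Restricting both options with `choices` makes an unsupported model or dataset fail at parse time rather than later during configuration lookup.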

### Performance

#### Alchemy
#### Alchemy

|Model |Mean Absolute Error (MAE)|
|-------------|-------------------------|
|SchNet[4] |0.065|
|MGCN[5] |0.050|
|MPNN[6] |0.056|
The Alchemy contest is still ongoing. Before the test set is fully released, we only include the performance numbers
on the training and validation set for reference.

| Model | Training MAE | Validation MAE |
| ---------- | ------------ | -------------- |
| SchNet [4] | 0.2665 | 0.6139 |
| MGCN [5] | 0.2395 | 0.6463 |
| MPNN [6] | 0.2452 | 0.6259 |
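
For readers unfamiliar with the metric in the table, mean absolute error averages the absolute differences between predictions and targets. The sketch below uses made-up numbers and a hypothetical helper name; how the example scripts aggregate the error over the 12 Alchemy properties is not stated here, so this shows only the single-task definition.

```python
from typing import Sequence


def mean_absolute_error(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Average of |prediction - target| over all samples."""
    assert len(predictions) == len(targets) and len(targets) > 0
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


# Illustrative values only, not taken from any experiment.
mae = mean_absolute_error([0.5, 1.0, 2.0], [1.0, 1.0, 1.0])
print(mae)
```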

## Dataset Customization

To customize your own dataset, see the instructions
[here](https://github.com/dmlc/dgl/tree/master/python/dgl/data/chem).

## References
[1] Wu et al. (2017) MoleculeNet: a benchmark for molecular machine learning. *Chemical Science* 9, 513-530.
@@ -1,11 +1,12 @@
from dgl.data.utils import split_dataset
from dgl import model_zoo
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from utils import Meter, EarlyStopping, collate_molgraphs, set_random_seed
from dgl import model_zoo
from dgl.data.utils import split_dataset

from utils import Meter, EarlyStopping, collate_molgraphs_for_classification, set_random_seed

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
model.train()
@@ -45,15 +46,18 @@ def main(args):
args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
set_random_seed()

# Interchangeable with other Dataset
# Interchangeable with other datasets
if args['dataset'] == 'Tox21':
from dgl.data.chem import Tox21
dataset = Tox21()

trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
train_loader = DataLoader(trainset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
val_loader = DataLoader(valset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
test_loader = DataLoader(testset, batch_size=args['batch_size'], collate_fn=collate_molgraphs)
train_loader = DataLoader(trainset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)
test_loader = DataLoader(testset, batch_size=args['batch_size'],
collate_fn=collate_molgraphs_for_classification)

if args['pre_trained']:
args['num_epochs'] = 0
31 changes: 30 additions & 1 deletion examples/pytorch/model_zoo/chem/property_prediction/configure.py
@@ -23,9 +23,38 @@
'patience': 10
}

MPNN_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}

SCHNET_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}

MGCN_Alchemy = {
'batch_size': 16,
'num_epochs': 250,
'norm': True,
'output_dim': 12,
'lr': 0.0001,
'patience': 50
}

experiment_configures = {
'GCN_Tox21': GCN_Tox21,
'GAT_Tox21': GAT_Tox21
'GAT_Tox21': GAT_Tox21,
'MPNN_Alchemy': MPNN_Alchemy,
'SCHNET_Alchemy': SCHNET_Alchemy,
'MGCN_Alchemy': MGCN_Alchemy
}

def get_exp_configure(exp_name):