[DGL-LifeSci] WLN for Reaction Prediction (dmlc#1530)

mufeili authored May 16, 2020
1 parent 70dc2ee commit 039fefc
Showing 18 changed files with 3,185 additions and 386 deletions.
13 changes: 7 additions & 6 deletions apps/life_sci/README.md
@@ -180,9 +180,10 @@ SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True)

Below we provide some reference numbers to show how DGL improves the per-epoch training time of models (in seconds).

| Model | Original Implementation | DGL Implementation | Improvement |
| ---------------------------------- | ----------------------- | -------------------------- | ---------------------------- |
| GCN on Tox21 | 5.5 (DeepChem) | 1.0 | 5.5x |
| AttentiveFP on Aromaticity | 6.0 | 1.2 | 5x |
| JTNN on ZINC | 1826 | 743 | 2.5x |
| WLN for reaction center prediction | 11657                   | 858 (1 GPU) / 134 (8 GPUs) | 13.6x (1 GPU) / 87.0x (8 GPUs) |
| WLN for candidate ranking | 40122 | 22268 | 1.8x |
11 changes: 9 additions & 2 deletions apps/life_sci/docs/source/api/data.rst
@@ -41,16 +41,23 @@ Reaction Prediction
USPTO
`````

.. autoclass:: dgllife.data.USPTOCenter
:members: __getitem__, __len__
:show-inheritance:

.. autoclass:: dgllife.data.USPTORank
:members: ignore_large, __getitem__, __len__
:show-inheritance:

Adapting to New Datasets for Weisfeiler-Lehman Networks
```````````````````````````````````````````````````````

.. autoclass:: dgllife.data.WLNCenterDataset
:members: __getitem__, __len__

.. autoclass:: dgllife.data.WLNRankDataset
:members: ignore_large, __getitem__, __len__

Protein-Ligand Binding Affinity Prediction
------------------------------------------

5 changes: 5 additions & 0 deletions apps/life_sci/docs/source/api/model.zoo.rst
@@ -74,6 +74,11 @@ WLN for Reaction Center Prediction
.. automodule:: dgllife.model.model_zoo.wln_reaction_center
:members:

WLN for Ranking Candidate Products
``````````````````````````````````
.. automodule:: dgllife.model.model_zoo.wln_reaction_ranking
:members:

Protein-Ligand Binding Affinity Prediction

ACNN
214 changes: 196 additions & 18 deletions apps/life_sci/examples/reaction_prediction/rexgen_direct/README.md
@@ -7,6 +7,16 @@ An earlier version of the work was published in NeurIPS 2017 as
["Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network"](https://arxiv.org/abs/1709.04555) with some
slight differences in modeling.

This work proposes a template-free approach for reaction prediction in two stages:
1) Identify the reaction center (pairs of atoms that will lose or form a bond)
2) Enumerate the possible combinations of bond changes and rank the corresponding candidate products

We provide a Jupyter notebook that walks through a demonstration with our pre-trained models. You can
download it with `wget https://data.dgl.ai/dgllife/reaction_prediction_pretrained.ipynb` and place it
in this directory. Below we visualize a reaction prediction by the model:

![](https://data.dgl.ai/dgllife/wln_reaction.png)

## Dataset

The example by default works with reactions from USPTO (United States Patent and Trademark Office) granted patents,
@@ -29,53 +39,104 @@ whose reaction centers have all been selected.
We use GPU whenever possible. To train the model with default options, simply do

```bash
python find_reaction_center_train.py
```

Once the training process starts, the progress will be printed in the terminal as follows:

```bash
Epoch 1/50, iter 8150/20452 | loss 8.4788 | grad norm 12.9927
Epoch 1/50, iter 8200/20452 | loss 8.6722 | grad norm 14.0833
```

Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `reaction_center_config`), we save a checkpoint of
the model and evaluate it on the validation set. The evaluation result is formatted as follows, where `total samples x` means
that we have trained the model on `x` samples and `acc@k` means top-k accuracy:

```bash
total samples 800000, (epoch 2/35, iter 2443/2557) | acc@12 0.9278 | acc@16 0.9419 | acc@20 0.9496 | acc@40 0.9596 | acc@80 0.9596 |
```
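
Here `acc@k` counts a reaction as correct only when all of its ground truth bond changes appear among the `k` highest-ranked atom pairs. A minimal sketch of this metric (function and argument names are hypothetical, not from the actual codebase):

```python
def top_k_accuracy(ranked_pairs, true_changes, k):
    """Compute top-k accuracy for reaction center prediction.

    ranked_pairs: list of per-reaction atom-pair lists, sorted by model score.
    true_changes: list of per-reaction sets of ground truth atom pairs.
    """
    hits = 0
    for pairs, truth in zip(ranked_pairs, true_changes):
        # Correct only if every ground truth bond change is in the top-k.
        if truth.issubset(set(pairs[:k])):
            hits += 1
    return hits / len(ranked_pairs)
```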

All model checkpoints and evaluation results can be found under `center_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples. `val_eval.txt` stores all
evaluation results on the validation set.

You may want to terminate the training process when the validation performance no longer improves for some time.

### Multi-GPU Training

By default we use a single GPU. We also support multi-GPU training. To use GPUs with IDs `id1,id2,...`, do

```bash
python find_reaction_center_train.py --gpus id1,id2,...
```

A summary of the training speedup with the DGL implementation is presented below.

| Item                    | Training time (s/epoch) | Speedup |
| ----------------------- | ----------------------- | ------- |
| Authors' implementation | 11657                   | 1x      |
| DGL with 1 GPU          | 858                     | 13.6x   |
| DGL with 2 GPUs         | 443                     | 26.3x   |
| DGL with 4 GPUs         | 243                     | 48.0x   |
| DGL with 8 GPUs         | 134                     | 87.0x   |

### Evaluation

```bash
python find_reaction_center_eval.py --model-path X
```

For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`center_results/model_800000.pkl`. The evaluation results will be stored at `center_results/test_eval.txt`.

For model evaluation, we can choose whether to exclude reactants not contributing heavy atoms to the product
(e.g. reagents and solvents) in top-k atom pair selection, which will make the task easier.
For the easier evaluation, do

```bash
python find_reaction_center_eval.py --easy
```
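
Concretely, the easy setting drops reactants that share no heavy atoms with the recorded product before selecting the top-k atom pairs. A hedged sketch of such filtering with RDKit on atom-mapped SMILES (`contributing_reactants` is a hypothetical helper, not part of the actual scripts):

```python
from rdkit import Chem

def contributing_reactants(reactant_smiles, product_smiles):
    """Keep only reactants sharing a mapped heavy atom with the product."""
    product = Chem.MolFromSmiles(product_smiles)
    product_maps = {atom.GetAtomMapNum() for atom in product.GetAtoms()
                    if atom.GetAtomicNum() > 1}
    kept = []
    for smiles in reactant_smiles:
        mol = Chem.MolFromSmiles(smiles)
        maps = {atom.GetAtomMapNum() for atom in mol.GetAtoms()
                if atom.GetAtomicNum() > 1}
        # A reactant contributes if at least one of its mapped heavy
        # atoms appears in the product.
        if maps & product_maps:
            kept.append(smiles)
    return kept
```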

A summary of the model performance under various settings is as follows:

| Item | Top 6 accuracy | Top 8 accuracy | Top 10 accuracy |
| --------------- | -------------- | -------------- | --------------- |
| Paper | 89.8 | 92.0 | 93.3 |
| Hard evaluation from authors' code | 87.7 | 90.6 | 92.1 |
| Easy evaluation from authors' code | 90.0 | 92.8 | 94.2 |
| Hard evaluation | 88.9 | 91.7 | 93.1 |
| Easy evaluation | 91.2 | 93.8 | 95.0 |
| Hard evaluation for model trained on 8 gpus | 88.1 | 91.0 | 92.5 |
| Easy evaluation for model trained on 8 gpus | 90.3 | 93.3 | 94.6 |

1. We are able to match the results obtained with the authors' code for both single-GPU and multi-GPU training.
2. While multi-GPU training provides a great speedup, the performance with the default hyperparameters drops slightly.

### Data Pre-processing with Multi-Processing

By default we use 32 processes for data pre-processing. If you encounter an error with
`BrokenPipeError: [Errno 32] Broken pipe`, you can specify a smaller number of processes with

```bash
python find_reaction_center_train.py -np X
```

```bash
python find_reaction_center_eval.py -np X
```

where `X` is the number of processes that you would like to use.

### Pre-trained Model

We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model, simply do

```bash
python find_reaction_center_eval.py
```

### Adapting to a New Dataset

New datasets should be processed such that each line corresponds to the SMILES for a reaction like below:

@@ -89,10 +150,127 @@ In addition, atom mapping information is provided.
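
As a purely hypothetical illustration (not a line from the actual dataset), an atom-mapped reaction SMILES takes the form `reactants>reagents>product`, where the `:n` atom map numbers link reactant atoms to their counterparts in the product:

```
[CH3:1][C:2](=[O:3])[OH:4].[NH2:5][CH3:6]>>[CH3:1][C:2](=[O:3])[NH:5][CH3:6]
```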
You can then train a model on new datasets with

```bash
python find_reaction_center_train.py --train-path X --val-path Y
```

where `X`, `Y` are paths to the new training/validation sets as described above.

For evaluation,

```bash
python find_reaction_center_eval.py --eval-path Z
```

where `Z` is the path to the new test set as described above.

## Candidate Ranking

### Additional Dependency

In addition to RDKit, we use MolVS as an alternative for checking whether two molecules are the same after sanitization.

- [molvs](https://molvs.readthedocs.io/en/latest/)
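
For intuition, the two notions of molecule equivalence used in evaluation can be sketched as below. This is a simplified illustration assuming canonical/standardized SMILES comparison, not the exact evaluation code:

```python
from rdkit import Chem
from molvs import standardize_smiles

def same_molecule_strict(smiles1, smiles2):
    # RDKit-sanitized equivalence: compare canonical SMILES.
    mol1, mol2 = Chem.MolFromSmiles(smiles1), Chem.MolFromSmiles(smiles2)
    return Chem.MolToSmiles(mol1) == Chem.MolToSmiles(mol2)

def same_molecule_molvs(smiles1, smiles2):
    # MolVS-standardized equivalence: normalizes charges and
    # functional-group representations first, so it is more forgiving.
    return standardize_smiles(smiles1) == standardize_smiles(smiles2)
```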

### Modeling

For candidate ranking, we assume that a model has been trained for reaction center prediction first.
The pipeline for predicting candidate products given a reaction proceeds as follows:
1. Select the top-k bond changes for atom pairs in the reactants, ranked by the model for reaction center prediction.
By default, we use k=80 and, as in the paper, exclude reactants not contributing heavy atoms to the ground truth
product when selecting the top-k bond changes.
2. Filter out candidate bond changes for bonds that already exist in the reactants.
3. Enumerate the possible combinations of up to C candidate bond changes, where C bounds the number of bond changes
(losing or forming a bond) in a reaction. A statistical analysis of USPTO suggests that setting C to 5 is enough.
4. Filter out invalid combinations where 1) atoms in candidate bond changes are not connected, 2) an atom pair is
predicted to have different types of bond changes
(e.g. two atoms are simultaneously predicted to form both a single and a double bond), or 3) valence constraints
are violated. A simplified sketch of steps 3-4 is given after this list.
5. Apply the candidate bond changes of each valid combination to get the corresponding candidate products.
6. Construct molecular graphs for the reactants and candidate products, and featurize their atoms and bonds.
7. Apply a Weisfeiler-Lehman Network to the molecular graphs of the reactants and candidate products to score the candidates.
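
Below is a minimal, pure-Python sketch of the enumeration and conflict filtering in steps 3-4. It is a simplification: it only drops combinations that assign two different changes to the same atom pair, while the connectivity and valence checks would additionally require the molecular graph (e.g. via RDKit). All names are illustrative, not from the actual implementation:

```python
from itertools import combinations

def enumerate_bond_change_combos(candidate_changes, max_changes=5):
    """candidate_changes: list of ((u, v), bond_order) tuples, e.g.
    ((0, 3), 1.0) for forming a single bond between atoms 0 and 3,
    or ((2, 5), 0.0) for losing the bond between atoms 2 and 5.
    """
    valid_combos = []
    for num_changes in range(1, max_changes + 1):
        for combo in combinations(candidate_changes, num_changes):
            pairs = [pair for pair, _ in combo]
            # Drop combinations that assign two different bond changes
            # to the same atom pair.
            if len(set(pairs)) == len(pairs):
                valid_combos.append(combo)
    return valid_combos
```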

### Training with Default Options

We use GPU whenever possible. To train the model with default options, simply do

```bash
python candidate_ranking_train.py -cmp X
```

where `X` is the path to a trained model for reaction center prediction. You can use our
pre-trained model by not specifying `-cmp`.

Once the training process starts, the progress will be printed in the terminal as follows:

```bash
Epoch 6/6, iter 16439/20061 | time 1.1124 | accuracy 0.8500 | grad norm 5.3218
Epoch 6/6, iter 16440/20061 | time 1.1124 | accuracy 0.9500 | grad norm 2.1163
```

Every time the learning rate is decayed (specified as `'decay_every'` in `configure.py`'s `candidate_ranking_config`),
we save a checkpoint of the model and evaluate it on the validation set. The evaluation result is formatted
as follows, where `total samples x` means that we have trained the model on `x` samples, `acc@k` means top-k accuracy,
and `gfound` is the proportion of reactions where the ground truth product can be recovered by the ground truth bond changes.
We perform the evaluation based on RDKit-sanitized molecule equivalence (marked with `[strict]`) and MolVS-sanitized
molecule equivalence (marked with `[molvs]`).

```bash
total samples 100000, (epoch 1/20, iter 5000/20061)
[strict] acc@1: 0.7732 acc@2: 0.8466 acc@3: 0.8763 acc@5: 0.8987 gfound 0.9864
[molvs] acc@1: 0.7779 acc@2: 0.8523 acc@3: 0.8826 acc@5: 0.9057 gfound 0.9953
```

All model checkpoints and evaluation results can be found under `candidate_results`. `model_x.pkl` stores a model
checkpoint after seeing `x` training samples in total. `val_eval.txt` stores all
evaluation results on the validation set.

You may want to terminate the training process when the validation performance no longer improves for some time.

### Evaluation

```bash
python candidate_ranking_eval.py --model-path X -cmp Y
```

where `X` is the path to a trained model for candidate ranking and `Y` is the path to a trained model
for reaction center prediction. For example, you can evaluate the model trained for 800000 samples by setting `X` to be
`candidate_results/model_800000.pkl`. The evaluation results will be stored at `candidate_results/test_eval.txt`. As
in training, you can use our pre-trained model by not specifying `-cmp`.

A summary of the model performance under various settings is as follows:

| Item | Top 1 accuracy | Top 2 accuracy | Top 3 accuracy | Top 5 accuracy |
| -------------------------- | -------------- | -------------- | -------------- | -------------- |
| Authors' strict evaluation | 85.6 | 90.5 | 92.8 | 93.4 |
| DGL's strict evaluation | 85.6 | 90.0 | 91.7 | 92.9 |
| Authors' molvs evaluation | 86.2 | 91.2 | 92.8 | 94.2 |
| DGL's molvs evaluation | 86.1 | 90.6 | 92.4 | 93.6 |

### Pre-trained Model

We provide a pre-trained model so users do not need to train from scratch. To evaluate the pre-trained model,
simply do

```bash
python candidate_ranking_eval.py
```

### Adapting to a New Dataset

You can train a model on new datasets with

```bash
python candidate_ranking_train.py --train-path X --val-path Y
```

where `X`, `Y` are paths to the new training/validation sets as described in the `Reaction Center Prediction` section.

For evaluation,

```bash
python candidate_ranking_eval.py --test-path Z
```

where `Z` is the path to the new test set as described in the `Reaction Center Prediction` section.

## References

@@ -0,0 +1,75 @@
import torch

from dgllife.data import USPTORank, WLNRankDataset
from dgllife.model import WLNReactionRanking, load_pretrained
from torch.utils.data import DataLoader

from configure import candidate_ranking_config, reaction_center_config
from utils import mkdir_p, prepare_reaction_center, collate_rank_eval, candidate_ranking_eval

def main(args, path_to_candidate_bonds):
    if args['test_path'] is None:
        test_set = USPTORank(
            subset='test', candidate_bond_path=path_to_candidate_bonds['test'],
            max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])
    else:
        test_set = WLNRankDataset(
            raw_file_path=args['test_path'],
            candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
            max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])

    test_loader = DataLoader(test_set, batch_size=1, collate_fn=collate_rank_eval,
                             shuffle=False, num_workers=args['num_workers'])

    if args['model_path'] is None:
        model = load_pretrained('wln_rank_uspto')
    else:
        model = WLNReactionRanking(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_hidden_feats=args['hidden_size'],
            num_encode_gnn_layers=args['num_encode_gnn_layers'])
        model.load_state_dict(torch.load(
            args['model_path'], map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    prediction_summary = candidate_ranking_eval(args, model, test_loader)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(prediction_summary)

if __name__ == '__main__':
    from argparse import ArgumentParser

    parser = ArgumentParser(description='Candidate Ranking')
    parser.add_argument('--model-path', type=str, default=None,
                        help='Path to saved model. If None, we will directly evaluate '
                             'a pretrained model on the test set.')
    parser.add_argument('--result-path', type=str, default='candidate_results',
                        help='Path to save modeling results')
    parser.add_argument('--test-path', type=str, default=None,
                        help='Path to a new test set. '
                             'If None, we will use the default test set in USPTO.')
    parser.add_argument('-cmp', '--center-model-path', type=str, default=None,
                        help='Path to a pre-trained model for reaction center prediction. '
                             'By default we use the official pre-trained model. If not None, '
                             'the model should follow the hyperparameters specified in '
                             'reaction_center_config.')
    parser.add_argument('-rcb', '--reaction-center-batch-size', type=int, default=200,
                        help='Batch size to use for preparing candidate bonds from a trained '
                             'model on reaction center prediction')
    parser.add_argument('-np', '--num-processes', type=int, default=8,
                        help='Number of processes to use for data pre-processing')
    parser.add_argument('-nw', '--num-workers', type=int, default=32,
                        help='Number of workers to use for data loading in PyTorch data loader')
    args = parser.parse_args().__dict__
    args.update(candidate_ranking_config)
    mkdir_p(args['result_path'])
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
    else:
        args['device'] = torch.device('cpu')

    path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
    main(args, path_to_candidate_bonds)