Commit

[DGL-LifeSci] Allow Generating Vocabulary from a New Dataset (dmlc#1577)
* Generate vocabulary from a new dataset

* CI"
mufeili authored Jun 2, 2020
1 parent a936f9d commit 38b9c0f
Showing 3 changed files with 92 additions and 44 deletions.
89 changes: 46 additions & 43 deletions apps/life_sci/examples/generative_models/jtnn/README.md
@@ -25,62 +25,25 @@ molecules for training and 5000 molecules for validation.

### Preprocessing

-Class `JTNNDataset` will process a SMILES into a dict, including the junction tree, graph with
+Class `JTNNDataset` will process a SMILES string into a dict, consisting of a junction tree, a graph with
encoded nodes(atoms) and edges(bonds), and other information for model to use.
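
As a quick illustration of what preprocessing yields, here is a minimal sketch of loading one sample. The import path and the constructor arguments are assumptions based on the example code in this directory, not a documented API:

```
# A minimal sketch, assuming the local jtnn package exposes JTNNDataset
# with a (data, vocab, training) constructor -- both the import path and
# the argument names are assumptions.
from jtnn import JTNNDataset

dataset = JTNNDataset(data='train', vocab='vocab', training=True)
sample = dataset[0]
# The dict bundles the junction tree, the molecular graph with encoded
# nodes (atoms) and edges (bonds), and other information for the model.
print(sample.keys())
```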

## Usage

### Training

To start training, use `python train.py`. By default, the script will use ZINC dataset
-with preprocessed vocabulary, and save model checkpoint at the current working directory.
-```
--s SAVE_PATH, Path to save checkpoint models, default to be current
-              working directory (default: ./)
--m MODEL_PATH, Path to load pre-trained model (default: None)
--b BATCH_SIZE, Batch size (default: 40)
--w HIDDEN_SIZE, Size of representation vectors (default: 200)
--l LATENT_SIZE, Latent Size of node(atom) features and edge(atom)
-               features (default: 56)
--d DEPTH, Depth of message passing hops (default: 3)
--z BETA, Coefficient of KL Divergence term (default: 1.0)
--q LR, Learning Rate (default: 0.001)
-```
-
-Model will be saved periodically.
-All training checkpoints will be stored at `SAVE_PATH`, passed by command line or by default.
-
-#### Dataset configuration
-
-If you want to use your own dataset, please create a file containing one SMILES per line,
-and pass the file path to the `-t` or `--train` option.
-```
--t TRAIN, --train TRAIN
-    Training file name (default: train)
-```
+with preprocessed vocabulary, and save model checkpoints periodically in the current working directory.
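
For example, a full training invocation with every option spelled out might look like the following; the dataset file and checkpoint directory are placeholders, and the values shown are the defaults from the help text above:

```
python train.py -t train.txt -s ./checkpoints -b 40 -w 200 -l 56 -d 3 -z 1.0 -q 0.001
```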

### Evaluation

-To start evaluation, use `python reconstruct_eval.py`, and the following arguments
-```
--t TRAIN, Training file name (default: test)
--m MODEL_PATH, Pre-trained model to be loaded for evaluation. If not
-   specified, would use pre-trained model from model zoo (default: None)
--w HIDDEN_SIZE, Hidden size of representation vector, should be
-   consistent with pre-trained model (default: 450)
--l LATENT_SIZE, Latent Size of node(atom) features and edge(atom)
-   features, should be consistent with pre-trained model (default: 56)
--d DEPTH, Depth of message passing hops, should be consistent
-   with pre-trained model (default: 3)
-```
-
-And it would print out the success rate of reconstructing the same molecules.
+To start evaluation, use `python reconstruct_eval.py`. By default, we will perform evaluation with
+DGL's pre-trained model. During the evaluation, the program will print out the success rate of
+molecule reconstruction.
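
For example, evaluating a specific checkpoint with settings consistent with a model trained at the defaults above might look like the following; the file names are placeholders, and omitting `-m` falls back to the pre-trained model from the model zoo:

```
python reconstruct_eval.py -t test.txt -m ./checkpoints/model_final -w 450 -l 56 -d 3
```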

### Pre-trained models

-Below gives the statistics of pre-trained `JTNN_ZINC` model.
+Below gives the statistics of our pre-trained `JTNN_ZINC` model.

| Pre-trained model | % Reconstruction Accuracy
| ------------------ | -------
@@ -96,3 +59,43 @@ Please put this script at the current directory (`examples/pytorch/model_zoo/che
![image](https://user-images.githubusercontent.com/8686776/63773593-0d37da00-c90e-11e9-8933-0abca4b430db.png)
#### Neighbor Molecules
![image](https://user-images.githubusercontent.com/8686776/63773602-1163f780-c90e-11e9-8341-5122dc0d0c82.png)

+### Dataset configuration
+
+If you want to use your own dataset, please create a file with one SMILES per line, as below
+
+```
+CCO
+Fc1ccccc1
+```
+
+You can generate the vocabulary file corresponding to your dataset with `python vocab.py -d X -v Y`, where `X`
+is the path to the dataset and `Y` is the path to the vocabulary file to save. An example vocabulary file
+corresponding to the two molecules above would be
+
+```
+CC
+CF
+C1=CC=CC=C1
+CO
+```
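
Each vocabulary entry is the SMILES string of one junction-tree cluster. As a minimal sketch of how these entries arise, using the same `DGLMolTree` class that `vocab.py` below relies on:

```
from jtnn.mol_tree import DGLMolTree

# Decompose one molecule into its junction-tree clusters.
mol = DGLMolTree('Fc1ccccc1')
clusters = {mol.nodes_dict[i]['smiles'] for i in mol.nodes_dict}
# Expected to contain the ring 'C1=CC=CC=C1' and the bond cluster 'CF'.
print(clusters)
```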

+If you want to develop a model based on DGL's pre-trained model, it's important to make sure that the vocabulary
+generated above is a subset of the vocabulary we use for the pre-trained model. By running `vocab.py` as above, we
+also check whether the new vocabulary is a subset of the vocabulary we use for the pre-trained model and print the
+result in the terminal as follows:
+
+```
+The new vocabulary is a subset of the default vocabulary: True
+```
+
+To train on this new dataset, run
+
+```
+python train.py -t X
+```
+
+where `X` is the path to the new dataset. If you want to use the vocabulary generated above, also add `-v Y`, where
+`Y` is the path to the vocabulary file we just saved.
+
+To evaluate on this new dataset, run `python reconstruct_eval.py` with the same arguments as above.
46 changes: 46 additions & 0 deletions apps/life_sci/examples/generative_models/jtnn/vocab.py
@@ -0,0 +1,46 @@
"""Generate vocabulary for a new dataset."""
if __name__ == '__main__':
import argparse
import os
import rdkit

from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive

from jtnn.mol_tree import DGLMolTree

parser = argparse.ArgumentParser('Generate vocabulary for a molecule dataset')
parser.add_argument('-d', '--data-path', type=str,
help='Path to the dataset')
parser.add_argument('-v', '--vocab', type=str,
help='Path to the vocabulary file to save')
args = parser.parse_args()

lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)

vocab = set()
with open(args.data_path, 'r') as f:
for line in f:
smiles = line.strip()
mol = DGLMolTree(smiles)
for i in mol.nodes_dict:
vocab.add(mol.nodes_dict[i]['smiles'])

with open(args.vocab, 'w') as f:
for v in vocab:
f.write(v + '\n')

# Get the vocabulary used for the pre-trained model
default_dir = get_download_dir()
vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
if not os.path.exists(vocab_file):
zip_file_path = '{}/jtnn.zip'.format(default_dir)
download(_get_dgl_url('dgllife/jtnn.zip'), path=zip_file_path)
extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))
default_vocab = set()
with open(vocab_file, 'r') as f:
for line in f:
default_vocab.add(line.strip())

print('The new vocabulary is a subset of the default vocabulary: {}'.format(
vocab.issubset(default_vocab)))
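
For example, to build the vocabulary for a new SMILES file and save it (both paths are placeholders):

```
python vocab.py -d new_data.txt -v new_vocab.txt
```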
1 change: 0 additions & 1 deletion tutorials/models/1_gnn/4_rgcn.py
@@ -268,7 +268,6 @@ def forward(self, g):

# load graph data
from dgl.contrib.data import load_data
-import numpy as np
data = load_data(dataset='aifb')
num_nodes = data.num_nodes
num_rels = data.num_rels
