Skip to content

Commit

Permalink
[Model Zoo] Fix JTNN (dmlc#843)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update
  • Loading branch information
mufeili authored Sep 10, 2019
1 parent 4e0e669 commit 9df8cd3
Show file tree
Hide file tree
Showing 24 changed files with 62 additions and 158 deletions.
15 changes: 9 additions & 6 deletions examples/pytorch/jtnn/jtnn/chemutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import rdkit
import rdkit.Chem as Chem
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
Expand All @@ -29,7 +28,8 @@ def decode_stereo(smiles2D):
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]

chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
Expand Down Expand Up @@ -117,7 +117,8 @@ def tree_decomp(mol):
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
# In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
Expand Down Expand Up @@ -242,11 +243,13 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
(nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append( new_amap )

if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
(nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append( new_amap )
return att_confs

Expand Down
1 change: 0 additions & 1 deletion examples/pytorch/jtnn/jtnn/datautils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch.utils.data import Dataset
import numpy as np

import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
Expand Down
9 changes: 3 additions & 6 deletions examples/pytorch/jtnn/jtnn/jtmpn.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import torch
import torch.nn as nn
from .nnutils import cuda
from .chemutils import get_mol
#from mpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM
import rdkit.Chem as Chem
from dgl import DGLGraph, batch, unbatch, mean_nodes
from dgl import DGLGraph, mean_nodes
import dgl.function as DGLF
from .line_profiler_integration import profile
import os
import numpy as np

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
Expand Down
6 changes: 0 additions & 6 deletions examples/pytorch/jtnn/jtnn/jtnn_dec.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda
import copy
import itertools
from dgl import batch, dfs_labeled_edges_generator
import dgl.function as DGLF
import networkx as nx
from .line_profiler_integration import profile
import numpy as np

MAX_NB = 8
Expand Down Expand Up @@ -265,7 +260,6 @@ def decode(self, mol_vec):
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
wid = udata['wid']
x = udata['x']
h = udata['h']

Expand Down
7 changes: 1 addition & 6 deletions examples/pytorch/jtnn/jtnn/jtnn_enc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import torch
import torch.nn as nn
from collections import deque
from .mol_tree import Vocab
from .nnutils import GRUUpdate, cuda
import itertools
import networkx as nx
from dgl import batch, unbatch, bfs_edges_generator
from dgl import batch, bfs_edges_generator
import dgl.function as DGLF
from .line_profiler_integration import profile
import numpy as np

MAX_NB = 8
Expand Down
8 changes: 1 addition & 7 deletions examples/pytorch/jtnn/jtnn/jtnn_vae.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree import Vocab
from .nnutils import cuda, move_dgl_to_cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
Expand All @@ -11,13 +10,9 @@
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .line_profiler_integration import profile

import rdkit
import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
import copy, math
import copy

from dgl import batch, unbatch

Expand Down Expand Up @@ -102,7 +97,6 @@ def forward(self, mol_batch, beta=0, e1=None, e2=None):
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss

return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
Expand Down
1 change: 0 additions & 1 deletion examples/pytorch/jtnn/jtnn/mol_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import rdkit
import rdkit.Chem as Chem
import copy

Expand Down
1 change: 0 additions & 1 deletion examples/pytorch/jtnn/jtnn/mol_tree_nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
from .line_profiler_integration import profile

class DGLMolTree(DGLGraph):
def __init__(self, smiles):
Expand Down
10 changes: 3 additions & 7 deletions examples/pytorch/jtnn/jtnn/mpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .nnutils import *
from .chemutils import get_mol
from networkx import Graph, DiGraph, convert_node_labels_to_integers
from dgl import DGLGraph, batch, unbatch, mean_nodes
from dgl import DGLGraph, mean_nodes
import dgl.function as DGLF
from functools import partial
from .line_profiler_integration import profile
import numpy as np

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
Expand Down
1 change: 0 additions & 1 deletion examples/pytorch/jtnn/jtnn/nnutils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import os


Expand Down
4 changes: 3 additions & 1 deletion examples/pytorch/model_zoo/chem/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ are also two accompanying review papers that are well written [7], [8].
### Models
- **Deep Generative Models of Graphs (DGMG)** [11]: A very general framework for graph distribution learning by
progressively adding atoms and bonds.
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]:
- **Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN)** [13]: JTNNs are able to incrementally
expand molecules while maintaining chemical valency at every step. They can be used for both molecule generation and
optimization.

### Example Usage of Pre-trained Models

Expand Down
54 changes: 21 additions & 33 deletions examples/pytorch/model_zoo/chem/generative_models/jtnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,16 @@ encoded nodes(atoms) and edges(bonds), and other information for model to use.
To start training, use `python train.py`. By default, the script will use the ZINC dataset
with a preprocessed vocabulary, and save model checkpoints in the current working directory.
```
-s SAVE_PATH, --save_dir SAVE_PATH
Path to save checkpoint models, default to be current
working directory (default: ./)
-m MODEL_PATH, --model MODEL_PATH
Path to load pre-trained model (default: None)
-b BATCH_SIZE, --batch BATCH_SIZE
Batch size (default: 40)
-w HIDDEN_SIZE, --hidden HIDDEN_SIZE
Size of representation vectors (default: 200)
-l LATENT_SIZE, --latent LATENT_SIZE
Latent Size of node(atom) features and edge(atom)
features (default: 56)
-d DEPTH, --depth DEPTH
Depth of message passing hops (default: 3)
-z BETA, --beta BETA Coefficient of KL Divergence term (default: 1.0)
-q LR, --lr LR Learning Rate (default: 0.001)
-T, --test Add this flag to run test mode (default: False)
-s SAVE_PATH, Path to save checkpoint models, default to be current
working directory (default: ./)
-m MODEL_PATH, Path to load pre-trained model (default: None)
-b BATCH_SIZE, Batch size (default: 40)
-w HIDDEN_SIZE, Size of representation vectors (default: 200)
-l LATENT_SIZE, Latent Size of node(atom) features and edge(bond)
features (default: 56)
-d DEPTH, Depth of message passing hops (default: 3)
-z BETA, Coefficient of KL Divergence term (default: 1.0)
-q LR, Learning Rate (default: 0.001)
```

Model will be saved periodically.
Expand All @@ -70,22 +63,17 @@ If you want to use your own dataset, please create a file contains one SMILES a

To start evaluation, use `python reconstruct_eval.py` with the following arguments:
```
-t TRAIN, --train TRAIN
Training file name (default: test)
-m MODEL_PATH, --model MODEL_PATH
Pre-trained model to be loaded for evaluation. If not
specified, would use pre-trained model from model zoo
(default: None)
-w HIDDEN_SIZE, --hidden HIDDEN_SIZE
Hidden size of representation vector, should be
consistent with pre-trained model (default: 450)
-l LATENT_SIZE, --latent LATENT_SIZE
Latent Size of node(atom) features and edge(atom)
features, should be consistent with pre-trained model
(default: 56)
-d DEPTH, --depth DEPTH
Depth of message passing hops, should be consistent
with pre-trained model (default: 3)
-t TRAIN, Training file name (default: test)
-m MODEL_PATH, Pre-trained model to be loaded for evaluation. If not
specified, the pre-trained model from the model zoo is used
(default: None)
-w HIDDEN_SIZE, Hidden size of representation vector, should be
consistent with pre-trained model (default: 450)
-l LATENT_SIZE, Latent Size of node(atom) features and edge(bond)
features, should be consistent with pre-trained model
(default: 56)
-d DEPTH, Depth of message passing hops, should be consistent
with pre-trained model (default: 3)
```

The script then prints out the success rate of reconstructing the same molecules.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import rdkit
import rdkit.Chem as Chem
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from dgl import DGLGraph

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
Expand Down Expand Up @@ -406,16 +405,12 @@ def bond_features(bond):
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
n_edges = 0
atom_x = []
bond_x = []

for mol, mol_tree, ctr_node_id in cand_batch:
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()

ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()

for i, atom in enumerate(mol.GetAtoms()):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .mol_tree import Vocab, DGLMolTree
from .chemutils import mol2dgl_dec, mol2dgl_enc

ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']

ATOM_FDIM_DEC = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM_DEC = 5
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import rdkit
import rdkit.Chem as Chem
import copy
import numpy as np
from dgl import DGLGraph
Expand Down
34 changes: 2 additions & 32 deletions examples/pytorch/model_zoo/chem/generative_models/jtnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from dgl import model_zoo
from torch.utils.data import DataLoader

import math, random, sys
import sys
import argparse
from collections import deque
import rdkit

from jtnn import *
Expand Down Expand Up @@ -42,8 +41,6 @@ def worker_init_fn(id_):
help="Coefficient of KL Divergence term")
parser.add_argument("-q", "--lr", dest="lr", default=1e-3,
help="Learning Rate")
parser.add_argument("-T", "--test", dest="test", action="store_true",
help="Add this flag to run test mode")
args = parser.parse_args()

dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
Expand Down Expand Up @@ -131,35 +128,8 @@ def train():
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), args.save_path + "/model.iter-" + str(epoch))


def test():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)

# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
print(gt_smiles)
model.move_to_cuda(batch)
_, tree_vec, mol_vec = model.encode(batch)
tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
smiles = model.decode(tree_vec, mol_vec)
print(smiles)


if __name__ == '__main__':
if args.test:
test()
else:
train()
train()

print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
Expand Down
4 changes: 1 addition & 3 deletions python/dgl/model_zoo/chem/jtnn/chemutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# pylint: disable=C0111, C0103, E1101, W0611, W0612, W0703, C0200, R1710
from collections import defaultdict

import rdkit
import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import (EnumerateStereoisomers,
StereoEnumerationOptions)
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

Expand Down
Loading

0 comments on commit 9df8cd3

Please sign in to comment.