Skip to content

Commit

Permalink
[Model] Junction Tree VAE update (dmlc#157)
Browse files Browse the repository at this point in the history
* cherry picking optimization from jtnn

* adding official code.  TODO: fix DGLMolTree

* updating to current api.  vae test still failing

* reverting to list stacking

* reverting to list stacking

* cleaning x flags (stupid windows)

* cleaning x flags (stupid windows)

* adding stats

* optimization

* updating dgl stats

* update again

* more optimization

* looks like computation is faster

* removing profiling code

* cleaning obsolete code

* remove comparison warning

* readme update

* official implementation got a lot faster

* minor fixes

* unbatch by slicing frames

* working around unbatch

* reduce pack

* oops

* support frame read/write with slices

* reverting back to readout as unbatch-by-slicing slows down backward

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* cherry picking optimization from jtnn

* unbatch by slicing frames

* reduce pack

* oops

* support frame read/write with slices

* reverting to unbatch by splitting; slicing is unfriendly to backward

* replacing lru cache with static object factory

* replacing Scheme object with namedtuple

* forgot the find edges interface

* subclassing namedtuple

* updating to the latest api spec

* bugfix

* bfs with edges

* dfs toy test case

* clean up

* style fix

* bugfix

* update to latest api; include traversal

* replacing with readout

* simplify decoder

* oops

* cleanup

* reducing number of sets

* more speed up

* profile results

* random fixes

* fixing tvmarray handling incontiguous dlpack input

* fancier dataloader

* fix a potential context mismatch

* todo: support pickling or using scipy in multiprocessing load

* pickling support

* resorting to suggested way of pickling

* custom attribute pickling check

* working around a weird pytorch pickling bug

* including partial frame case

* enabling multiprocessing dataloader

* pickling everything now

* really works

* oops

* updated profiling results

* cleanup

* fix as requested

* cleaning random blank lines

* removing profiler outputs

* starting decoding

* testing, WIP

* tree decoding

* graph decoding, WIP

* graph decoding works

* oops

* fixing legacy apis

* trimming number of candidate structures

* sampling cleanups

* removing comparison test

* updated description
  • Loading branch information
BarclayII authored Dec 2, 2018
1 parent 4682b76 commit ac932c6
Show file tree
Hide file tree
Showing 12 changed files with 937 additions and 419 deletions.
13 changes: 13 additions & 0 deletions examples/pytorch/jtnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,16 @@ python3 vaetrain_dgl.py
```
The script will automatically download the data, which is the same as the one in the
original repository.

To disable CUDA, run with `NOCUDA` variable set:
```
NOCUDA=1 python3 vaetrain_dgl.py
```

To decode for new molecules, run
```
python3 vaetrain_dgl.py -T
```

Currently, decoding involves encoding a training example, sampling from the posterior
distribution, and decoding a molecule from that.
7 changes: 3 additions & 4 deletions examples/pytorch/jtnn/jtnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .mol_tree import Vocab
from .jtnn_vae import DGLJTNNVAE
from .mpn import DGLMPN, mol2dgl
from .nnutils import create_var
from .datautils import JTNNDataset
from .mpn import DGLMPN
from .nnutils import create_var, cuda
from .datautils import JTNNDataset, JTNNCollator
from .chemutils import decode_stereo
from .line_profiler_integration import profile
13 changes: 6 additions & 7 deletions examples/pytorch/jtnn/jtnn/chemutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
return att_confs

#Try rings first: Speed-Up
def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
node = graph.nodes[node_idx]
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]

Expand Down Expand Up @@ -301,21 +300,21 @@ def search(cur_amap, depth):

#Only used for debugging purpose
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes[cur_node_id]
fa_node = graph.nodes[fa_node_id] if fa_node_id is not None else None
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None

fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []

children_id = [nei for nei in graph[cur_node_id] if graph.nodes[nei]['nid'] != fa_nid]
children = [graph.nodes[nei] for nei in children_id]
children_id = [nei for nei in graph[cur_node_id] if graph.nodes_dict[nei]['nid'] != fa_nid]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors

cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(graph, cur_node_id, neighbors, prev_nodes, cur_amap)
cands = enum_assemble_nx(graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return

Expand Down
195 changes: 192 additions & 3 deletions examples/pytorch/jtnn/jtnn/datautils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
import torch
from torch.utils.data import Dataset
import numpy as np

import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
from .mol_tree_nx import DGLMolTree
from .mol_tree import Vocab

from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import mol2dgl_single as mol2dgl_dec

_url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'

def _unpack_field(examples, field):
return [e[field] for e in examples]

def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))

return wid

class JTNNDataset(Dataset):
def __init__(self, data, vocab):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
Expand All @@ -20,14 +37,186 @@ def __init__(self, data, vocab):
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
from .mol_tree_nx import DGLMolTree
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
return mol_tree

wid = _set_node_id(mol_tree, self.vocab)

# prebuild the molecule graph
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)

result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
}

if not self.training:
return result

# prebuild the candidate graph list
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])

if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, atom_x_enc.shape[1])
bond_x_dec = torch.zeros(0, bond_x_enc.shape[1])
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0, 2).long()

# prebuild the stereoisomers
cands = mol_tree.stereo_cands
if len(cands) > 1:
if mol_tree.smiles3D not in cands:
cands.append(mol_tree.smiles3D)

stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
else:
stereo_cand_graphs = []
stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []

result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})

return result

class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
self.training = training

@staticmethod
def _batch_and_set(graphs, atom_x, bond_x, flatten):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
return graph_batch

def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)

# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)

result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
}

if not self.training:
return result

# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')

n_graph_nodes = 0
n_tree_nodes = 0
for i in range(len(cand_graphs)):
tree_mess_tgt_e[i] += n_graph_nodes
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)

cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)

# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))

if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []

result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})

return result
Loading

0 comments on commit ac932c6

Please sign in to comment.