Skip to content

Commit

Permalink
Merge branch 'master' into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
jermainewang committed Oct 18, 2018
2 parents b963191 + ec4216d commit 6cbdf37
Show file tree
Hide file tree
Showing 16 changed files with 1,697 additions and 10 deletions.
13 changes: 13 additions & 0 deletions examples/pytorch/jtnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Junction Tree VAE - example for training
===

This is a direct modification from https://github.com/wengong-jin/icml18-jtnn

You need to have RDKit installed.

To run the model, use
```
python3 vaetrain_dgl.py
```
The script will automatically download the data, which is the same as the one in the
original repository.
7 changes: 7 additions & 0 deletions examples/pytorch/jtnn/jtnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .mol_tree import Vocab
from .jtnn_vae import DGLJTNNVAE
from .mpn import DGLMPN, mol2dgl
from .nnutils import create_var
from .datautils import JTNNDataset
from .chemutils import decode_stereo
from .line_profiler_integration import profile
334 changes: 334 additions & 0 deletions examples/pytorch/jtnn/jtnn/chemutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
import rdkit
import rdkit.Chem as Chem
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from collections import defaultdict
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000

def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)

def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.Kekulize(mol)
return mol

def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)

def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))

dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]

chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))

return smiles3D

def sanitize(mol):
try:
smiles = get_smiles(mol)
mol = get_mol(smiles)
except Exception as e:
return None
return mol

def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom

def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
bt = bond.GetBondType()
new_mol.AddBond(a1, a2, bt)
return new_mol

def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) #We assume this is not None
return new_mol

def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
return [[0]], []

cliques = []
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1,a2])

ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)

nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)

#Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2: continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2: continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
cliques[i] = list(set(cliques[i]))
cliques[j] = []

cliques = [c for c in cliques if len(c) > 0]
nei_list = [[] for i in range(n_atoms)]
for i in range(len(cliques)):
for atom in cliques[i]:
nei_list[atom].append(i)

#Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
continue
cnei = nei_list[atom]
bonds = [c for c in cnei if len(cliques[c]) == 2]
rings = [c for c in cnei if len(cliques[c]) > 4]
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = 1
elif len(rings) > 2: #Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1,c2 = cnei[i],cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1,c2)] < len(inter):
edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction

edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()]
if len(edges) == 0:
return cliques, edges

#Compute Maximum Spanning Tree
row,col,data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
junc_tree = minimum_spanning_tree(clique_graph)
row,col = junc_tree.nonzero()
edges = [(row[i],col[i]) for i in range(len(row))]
return (cliques, edges)

def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()

#Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
else:
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])

def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
new_atom = copy_atom(atom)
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

if nei_mol.GetNumBonds() == 0:
nei_atom = nei_mol.GetAtomWithIdx(0)
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
else:
for bond in nei_mol.GetBonds():
a1 = amap[bond.GetBeginAtom().GetIdx()]
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: #father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol

def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}

for nei_id,ctr_atom,nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom

ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()

#This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol,nei_idx = nei_node['mol'], nei_node['nid']
att_confs = []
black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]

if nei_mol.GetNumBonds() == 0: #neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _,atom_idx,_ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append( new_amap )

elif nei_mol.GetNumBonds() == 1: #neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom()

for atom in ctr_atoms:
#Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append( new_amap )
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append( new_amap )
else:
#intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
#Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append( new_amap )

#intersection is an bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append( new_amap )

if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append( new_amap )
return att_confs

#Try rings first: Speed-Up
def enum_assemble_nx(graph, node_idx, neighbors, prev_nodes=[], prev_amap=[]):
node = graph.nodes[node_idx]
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]

def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
return
if depth == len(neighbors):
all_attach_confs.append(cur_amap)
return

nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
smiles = get_smiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
candidates.append(amap)

if len(candidates) == 0:
return []

for new_amap in candidates:
search(new_amap, depth + 1)

search(prev_amap, 0)
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append( (smiles,cand_mol,amap) )

return candidates

#Only used for debugging purpose
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
cur_node = graph.nodes[cur_node_id]
fa_node = graph.nodes[fa_node_id] if fa_node_id is not None else None

fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []

children_id = [nei for nei in graph[cur_node_id] if graph.nodes[nei]['nid'] != fa_nid]
children = [graph.nodes[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors

cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(graph, cur_node_id, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return

cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_amap = cand_amap[label_idx]

for nei_id,ctr_atom,nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]

cur_mol = attach_mols_nx(cur_mol, children, [], global_amap) #father is already attached
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap, label_amap, nei_node_id, cur_node_id)
33 changes: 33 additions & 0 deletions examples/pytorch/jtnn/jtnn/datautils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from torch.utils.data import Dataset
import numpy as np

import dgl
from dgl.data.utils import download, extract_archive, get_download_dir

_url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'

class JTNNDataset(Dataset):
def __init__(self, data, vocab):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
download(_url, path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
from .mol_tree_nx import DGLMolTree
smiles = self.data[idx]
mol_tree = DGLMolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
return mol_tree
Loading

0 comments on commit 6cbdf37

Please sign in to comment.