
Commit

rebuild DAG
miaecle committed Mar 28, 2017
1 parent 8233320 commit 06ecfb8
Showing 7 changed files with 632 additions and 1 deletion.
44 changes: 44 additions & 0 deletions deepchem/models/tests/test_overfit.py
@@ -706,6 +706,50 @@ def test_DTNN_multitask_regression_overfit(self):

assert scores[regression_metric.name] > .9

def test_DAG_singletask_regression_overfit(self):
"""Test DAG regressor multitask overfits tiny data."""
np.random.seed(123)
tf.set_random_seed(123)
n_tasks = 1

# Load mini log-solubility dataset.
featurizer = dc.feat.ConvMolFeaturizer()
tasks = ["outcome"]
input_file = os.path.join(self.current_dir, "example_regression.csv")
loader = dc.data.CSVLoader(
tasks=tasks, smiles_field="smiles", featurizer=featurizer)
dataset = loader.featurize(input_file)

regression_metric = dc.metrics.Metric(
dc.metrics.pearson_r2_score, task_averager=np.mean)

n_feat = 75
batch_size = 10

graph = dc.nn.SequentialDAGGraph(
n_feat, batch_size=batch_size, max_atoms=50)
graph.add(dc.nn.DAGLayer(30, n_feat, max_atoms=50))
graph.add(dc.nn.DAGGather(max_atoms=50))

model = dc.models.MultitaskGraphRegressor(
graph,
n_tasks,
n_feat,
batch_size=batch_size,
learning_rate=0.005,
learning_rate_decay_time=1000,
optimizer_type="adam",
beta1=.9,
beta2=.999)

# Fit trained model
model.fit(dataset, nb_epoch=50)
model.save()
# Eval model on train
scores = model.evaluate(dataset, [regression_metric])

assert scores[regression_metric.name] > .9

def test_siamese_singletask_classification_overfit(self):
"""Test siamese singletask model overfits tiny data."""
np.random.seed(123)
35 changes: 34 additions & 1 deletion deepchem/models/tf_new_models/graph_models.py
@@ -11,7 +11,7 @@

import tensorflow as tf
from deepchem.nn.layers import GraphGather
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology


class SequentialGraph(object):
@@ -129,6 +129,39 @@ def add(self, layer):
self.layers.append(layer)


class SequentialDAGGraph(SequentialGraph):
"""SequentialGraph for DAG models
"""

def __init__(self, n_feat, batch_size=50, max_atoms=50):
"""
Parameters
----------
n_feat: int
Number of features per atom.
batch_size: int, optional(default=50)
Number of molecules in a batch
max_atoms: int, optional(default=50)
Maximum number of atoms in a molecule; should be set based on the dataset
"""
self.graph = tf.Graph()
with self.graph.as_default():
self.graph_topology = DAGGraphTopology(
n_feat, batch_size, max_atoms=max_atoms)
self.output = self.graph_topology.get_atom_features_placeholder()
self.layers = []

def add(self, layer):
"""Adds a new layer to model."""
with self.graph.as_default():
if type(layer).__name__ in ['DAGLayer']:
self.output = layer([self.output] +
self.graph_topology.get_topology_placeholders())
else:
self.output = layer(self.output)
self.layers.append(layer)
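
# A minimal usage sketch (illustrative, mirroring the new overfit test; not
# part of this file): assuming deepchem is imported as dc, a DAG model is
# assembled by stacking DAGLayer (per-step atom updates) and DAGGather
# (readout) on this container, then handing the result to
# MultitaskGraphRegressor.
#
#   graph = dc.nn.SequentialDAGGraph(n_feat=75, batch_size=10, max_atoms=50)
#   graph.add(dc.nn.DAGLayer(30, 75, max_atoms=50))
#   graph.add(dc.nn.DAGGather(max_atoms=50))
#   model = dc.models.MultitaskGraphRegressor(graph, 1, 75, batch_size=10)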


class SequentialSupportGraph(object):
"""An analog of Keras Sequential model for test/support models."""

173 changes: 173 additions & 0 deletions deepchem/models/tf_new_models/graph_topology.py
@@ -258,3 +258,176 @@ def gauss_expand(distance, n_distance, distance_min, distance_max):
steps = np.array([distance_min + i * step_size for i in range(n_distance)])
distance_vector = np.exp(-np.square(distance - steps) / (2 * step_size**2))
return distance_vector


class DAGGraphTopology(GraphTopology):
"""GraphTopology for DAG models
"""

def __init__(self, n_feat, batch_size, name='topology', max_atoms=50):

self.n_feat = n_feat
self.name = name
self.max_atoms = max_atoms
self.batch_size = batch_size
self.atom_features_placeholder = tf.placeholder(
dtype='float32',
shape=(self.batch_size * self.max_atoms, self.n_feat),
name=self.name + '_atom_features')

self.parents_placeholder = tf.placeholder(
dtype='int32',
shape=(self.batch_size * self.max_atoms, self.max_atoms,
self.max_atoms),
# (molecule * atom) => calculation step => parent atom indices
name=self.name + '_parents')

self.calculation_orders_placeholder = tf.placeholder(
dtype='int32',
shape=(self.batch_size * self.max_atoms, self.max_atoms),
# molecule * atom(graph) => step
name=self.name + '_orders')

self.membership_placeholder = tf.placeholder(
dtype='int32',
shape=(self.batch_size * self.max_atoms),
name=self.name + '_membership')

# Define the list of tensors to be used as topology
self.topology = [
self.parents_placeholder, self.calculation_orders_placeholder,
self.membership_placeholder
]

self.inputs = [self.atom_features_placeholder]
self.inputs += self.topology
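
# Shape sketch for the placeholders above (illustrative, not part of this
# file): with batch_size=10, max_atoms=50, n_feat=75 a batch holds
# 10 * 50 = 500 per-atom DAGs, so the feeds are
#   atom_features:      (500, 75)
#   parents:            (500, 50, 50)  # DAG => step => parent atom indices
#   calculation_orders: (500, 50)      # DAG => step
#   membership:         (500,)         # 1 for real atoms, 0 for padding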

def get_parents_placeholder(self):
return self.parents_placeholder

def get_calculation_orders_placeholder(self):
return self.calculation_orders_placeholder

def batch_to_feed_dict(self, batch):
"""Converts the current batch of mol_graphs into tensorflow feed_dict.
Assigns the graph information in array of ConvMol objects to the
placeholders tensors for DAG models
params
------
batch : np.ndarray
Array of ConvMol objects
returns
-------
feed_dict : dict
Can be merged with other feed_dicts for input into tensorflow
"""
# Merge mol conv objects

atoms_per_mol = [mol.get_num_atoms() for mol in batch]
n_atom_features = batch[0].get_atom_features().shape[1]
membership = np.concatenate(
[
np.array([1] * n_atoms + [0] * (self.max_atoms - n_atoms))
for i, n_atoms in enumerate(atoms_per_mol)
],
axis=0)

atoms_all = []
parents_all = []
calculation_orders = []
for idm, mol in enumerate(batch):
atom_features_padded = np.concatenate(
[
mol.get_atom_features(), np.zeros(
(self.max_atoms - atoms_per_mol[idm], n_atom_features))
],
axis=0)
# pad each molecule's atom feature matrix with zero rows up to max_atoms
atoms_all.append(atom_features_padded)

parents = self.UG_to_DAG(mol)
# ConvMol objects input here should have gone through the DAG Transformer
assert len(parents) == atoms_per_mol[idm]
parents_all.extend(parents[:])
parents_all.extend([
self.max_atoms * np.ones((self.max_atoms, self.max_atoms), dtype=int)
for i in range(self.max_atoms - atoms_per_mol[idm])
])
# padding with max_atoms
for parent in parents:
calculation_orders.append(self.indice_changing(parent[:, 0], idm))
# shift indices from the current molecule's numbering to batch-level numbering
calculation_orders.extend([
self.batch_size * self.max_atoms * np.ones(
(self.max_atoms,), dtype=int)
for i in range(self.max_atoms - atoms_per_mol[idm])
])
# padding with batch_size * max_atoms

atoms_all = np.concatenate(atoms_all, axis=0)
parents_all = np.stack(parents_all, axis=0)
calculation_orders = np.stack(calculation_orders, axis=0)
atoms_dict = {
self.atom_features_placeholder: atoms_all,
self.membership_placeholder: membership,
self.parents_placeholder: parents_all,
self.calculation_orders_placeholder: calculation_orders
}

return atoms_dict
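
# Usage sketch (illustrative; the variable names below are assumptions, not
# part of this file): a model built on this topology would convert each
# padded batch of ConvMol objects into feeds via batch_to_feed_dict, e.g.
#
#   topology = DAGGraphTopology(n_feat=75, batch_size=10, max_atoms=50)
#   # X_b: np.ndarray of 10 ConvMol objects for one batch
#   feed_dict = topology.batch_to_feed_dict(X_b)
#   # merge with label/weight feeds before calling tf.Session.run(...)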

def indice_changing(self, indice, n_mol):
"""Shifts atom indices of molecule n_mol to their positions in the flattened batch; padding indices map to batch_size * max_atoms."""
output = np.zeros_like(indice)
for ide, element in enumerate(indice):
if element < self.max_atoms:
output[ide] = element + n_mol * self.max_atoms
else:
output[ide] = self.batch_size * self.max_atoms
return output
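
# Worked example for indice_changing (illustrative): with batch_size=10 and
# max_atoms=50, atom index 3 inside molecule n_mol=2 becomes 3 + 2 * 50 = 103
# in the flattened batch, while any padding index (>= max_atoms) maps to
# batch_size * max_atoms = 500, one past the last real atom row.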

def UG_to_DAG(self, sample):
"""Converts the undirected graph (UG) of a molecule into one DAG per atom, via breadth-first search starting from that atom."""
parents = []
UG = sample.get_adjacency_list()
n_atoms = sample.get_num_atoms()
max_atoms = self.max_atoms
for count in range(n_atoms):
DAG = []
parent = [[] for i in range(n_atoms)]
current_atoms = [count]
# first element is current atom
atoms_indicator = np.ones((n_atoms,))
# 1 if the atom has not yet been included in the DAG
atoms_indicator[count] = 0
radial = 0
while np.sum(atoms_indicator) > 0:
if radial > n_atoms:
break  # molecules with disconnected fragments (e.g. two separate ions) would get stuck here otherwise
next_atoms = []
for current_atom in current_atoms:
for atom_adj in UG[current_atom]:
# atoms connected to current_atom
if atoms_indicator[atom_adj] > 0:
DAG.append((current_atom, atom_adj))
atoms_indicator[atom_adj] = 0
# mark this atom as included
next_atoms.append(atom_adj)
current_atoms = next_atoms
# move to the next step: atoms connected by one more bond
radial = radial + 1
for edge in reversed(DAG):
parent[edge[0]].append(edge[1])
parent[edge[0]].extend(parent[edge[1]])
# collect, for each atom, the atoms its calculation depends on
for ids, atom in enumerate(parent):
parent[ids].insert(0, ids)
parent = sorted(parent, key=len)
for ids, atom in enumerate(parent):
n_par = len(atom)
parent[ids].extend([max_atoms for i in range(max_atoms - n_par)])
while len(parent) < max_atoms:
parent.insert(0, [max_atoms] * max_atoms)
parents.append(np.array(parent))
return parents
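
# Hand-traced example of UG_to_DAG (illustrative): for a 3-atom chain with
# adjacency list [[1], [0, 2], [1]], the DAG rooted at atom 0 has edges
# (0, 1) and (1, 2). After parent propagation and sorting by length, the
# rows (before padding with max_atoms) are
#   [2], [1, 2], [0, 1, 2]
# i.e. atoms farthest from the root come first, the root atom itself comes
# last, and each row lists an atom followed by the atoms it depends on.
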
4 changes: 4 additions & 0 deletions deepchem/nn/__init__.py
@@ -17,6 +17,8 @@
from deepchem.nn.layers import DTNNEmbedding
from deepchem.nn.layers import DTNNStep
from deepchem.nn.layers import DTNNGather
from deepchem.nn.layers import DAGLayer
from deepchem.nn.layers import DAGGather

from deepchem.nn.model_ops import weight_decay
from deepchem.nn.model_ops import optimizer
@@ -28,6 +30,8 @@

from deepchem.models.tf_new_models.graph_topology import GraphTopology
from deepchem.models.tf_new_models.graph_topology import DTNNGraphTopology
from deepchem.models.tf_new_models.graph_topology import DAGGraphTopology
from deepchem.models.tf_new_models.graph_models import SequentialGraph
from deepchem.models.tf_new_models.graph_models import SequentialDTNNGraph
from deepchem.models.tf_new_models.graph_models import SequentialDAGGraph
from deepchem.models.tf_new_models.graph_models import SequentialSupportGraph
