Skip to content

Commit

Permalink
Merge pull request deepchem#523 from rbharath/nbr_list_layer
Browse files Browse the repository at this point in the history
Callable Layers and NeighborList layers
  • Loading branch information
rbharath authored Apr 26, 2017
2 parents 3121739 + 3a666fd commit 6bbe6a4
Show file tree
Hide file tree
Showing 10 changed files with 1,662 additions and 146 deletions.
7 changes: 5 additions & 2 deletions deepchem/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,8 +1049,11 @@ class Databag(object):
A utility class to iterate through multiple datasets together.
"""

def __init__(self):
self.datasets = dict()
def __init__(self, datasets=None):
if datasets is None:
self.datasets = dict()
else:
self.datasets = datasets

def add_dataset(self, key, dataset):
self.datasets[key] = dataset
Expand Down
99 changes: 67 additions & 32 deletions deepchem/models/tensorgraph/graph_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@
from deepchem.nn import model_ops

from deepchem.models.tensorgraph.layers import Layer
from deepchem.models.tensorgraph.layers import convert_to_layers


class Combine_AP(Layer):

def __init__(self, **kwargs):
super(Combine_AP, self).__init__(**kwargs)

def _create_tensor(self):
A = self.in_layers[0].out_tensor
P = self.in_layers[1].out_tensor
def create_tensor(self, in_layers=None):
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)
A = in_layers[0].out_tensor
P = in_layers[1].out_tensor
self.out_tensor = [A, P]


Expand All @@ -35,8 +39,11 @@ class Separate_AP(Layer):
def __init__(self, **kwargs):
super(Separate_AP, self).__init__(**kwargs)

def _create_tensor(self):
self.out_tensor = self.in_layers[0].out_tensor[0]
def create_tensor(self, in_layers=None):
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)
self.out_tensor = in_layers[0].out_tensor[0]


class WeaveLayer(Layer):
Expand Down Expand Up @@ -140,17 +147,21 @@ def build(self):
self.trainable_weights.extend(
[self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P])

def _create_tensor(self):
def create_tensor(self, in_layers=None):
""" description and explanation refer to deepchem.nn.WeaveLayer
parent layers: [atom_features, pair_features], pair_split, atom_to_pair
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

self.build()

atom_features = self.in_layers[0].out_tensor[0]
pair_features = self.in_layers[0].out_tensor[1]
atom_features = in_layers[0].out_tensor[0]
pair_features = in_layers[0].out_tensor[1]

pair_split = self.in_layers[1].out_tensor
atom_to_pair = self.in_layers[2].out_tensor
pair_split = in_layers[1].out_tensor
atom_to_pair = in_layers[2].out_tensor

AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
AA = self.activation(AA)
Expand Down Expand Up @@ -230,13 +241,17 @@ def build(self):
else:
self.trainable_weights = None

def _create_tensor(self):
def create_tensor(self, in_layers=None):
""" description and explanation refer to deepchem.nn.WeaveGather
parent layers: atom_features, atom_split
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

self.build()
outputs = self.in_layers[0].out_tensor
atom_split = self.in_layers[1].out_tensor
outputs = in_layers[0].out_tensor
atom_split = in_layers[1].out_tensor

if self.gaussian_expand:
outputs = self.gaussian_histogram(outputs)
Expand Down Expand Up @@ -297,12 +312,16 @@ def build(self):
[self.periodic_table_length, self.n_embedding])
self.trainable_weights = [self.embedding_list]

def _create_tensor(self):
def create_tensor(self, in_layers=None):
"""description and explanation refer to deepchem.nn.DTNNEmbedding
parent layers: atom_number
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

self.build()
atom_number = self.in_layers[0].out_tensor
atom_number = in_layers[0].out_tensor
atom_features = tf.nn.embedding_lookup(self.embedding_list, atom_number)
self.out_tensor = atom_features

Expand Down Expand Up @@ -356,15 +375,19 @@ def build(self):
self.W_cf, self.W_df, self.W_fc, self.b_cf, self.b_df
]

def _create_tensor(self):
def create_tensor(self, in_layers=None):
"""description and explanation refer to deepchem.nn.DTNNStep
parent layers: atom_features, distance, distance_membership_i, distance_membership_j
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

self.build()
atom_features = self.in_layers[0].out_tensor
distance = self.in_layers[1].out_tensor
distance_membership_i = self.in_layers[2].out_tensor
distance_membership_j = self.in_layers[3].out_tensor
atom_features = in_layers[0].out_tensor
distance = in_layers[1].out_tensor
distance_membership_i = in_layers[2].out_tensor
distance_membership_j = in_layers[3].out_tensor
distance_hidden = tf.matmul(distance, self.W_df) + self.b_df
atom_features_hidden = tf.matmul(atom_features, self.W_cf) + self.b_cf
outputs = tf.multiply(distance_hidden,
Expand Down Expand Up @@ -438,13 +461,17 @@ def build(self):

self.trainable_weights = self.W_list + self.b_list

def _create_tensor(self):
def create_tensor(self, in_layers=None):
"""description and explanation refer to deepchem.nn.DTNNGather
parent layers: atom_features, atom_membership
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

self.build()
output = self.in_layers[0].out_tensor
atom_membership = self.in_layers[1].out_tensor
output = in_layers[0].out_tensor
atom_membership = in_layers[1].out_tensor
for i, W in enumerate(self.W_list):
output = tf.matmul(output, W) + self.b_list[i]
output = self.activation(output)
Expand Down Expand Up @@ -521,22 +548,26 @@ def build(self):

self.trainable_weights = self.W_list + self.b_list

def _create_tensor(self):
def create_tensor(self, in_layers=None):
"""description and explanation refer to deepchem.nn.DAGLayer
parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

# Add trainable weights
self.build()

atom_features = self.in_layers[0].out_tensor
atom_features = in_layers[0].out_tensor
# each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
# each gragh include `max_atoms` of steps(corresponding to rows) of calculating graph features
parents = self.in_layers[1].out_tensor
parents = in_layers[1].out_tensor
# target atoms for each step: (batch_size*max_atoms) * max_atoms
calculation_orders = self.in_layers[2].out_tensor
calculation_masks = self.in_layers[3].out_tensor
calculation_orders = in_layers[2].out_tensor
calculation_masks = in_layers[3].out_tensor

n_atoms = self.in_layers[4].out_tensor
n_atoms = in_layers[4].out_tensor
# initialize graph features for each graph
graph_features_initial = tf.zeros((self.max_atoms * self.batch_size,
self.max_atoms + 1, self.n_graph_feat))
Expand Down Expand Up @@ -655,16 +686,20 @@ def build(self):

self.trainable_weights = self.W_list + self.b_list

def _create_tensor(self):
def create_tensor(self, in_layers=None):
"""description and explanation refer to deepchem.nn.DAGGather
parent layers: atom_features, membership
"""
if in_layers is None:
in_layers = self.in_layers
in_layers = convert_to_layers(in_layers)

# Add trainable weights
self.build()

# Extract atom_features
atom_features = self.in_layers[0].out_tensor
membership = self.in_layers[1].out_tensor
atom_features = in_layers[0].out_tensor
membership = in_layers[1].out_tensor
# Extract atom_features
graph_features = tf.segment_sum(atom_features, membership)
# sum all graph outputs
Expand Down
Loading

0 comments on commit 6bbe6a4

Please sign in to comment.