Merge pull request deepchem#850 from peastman/copy

Changes to support copying pieces of models

rbharath authored Sep 26, 2017
2 parents 5019ad5 + fe274ab commit a260b0a
Showing 5 changed files with 189 additions and 22 deletions.
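In rough terms, this change lets a piece of one trained TensorGraph be cloned into another graph while carrying over its trained weights. A minimal usage sketch, loosely adapted from the test_copy_layers test added in this commit (the model and layer names below are illustrative, not part of the diff):

import deepchem as dc
from deepchem.models.tensorgraph.layers import Dense, Feature

# Build and fit a small source model.
source = dc.models.TensorGraph(use_queue=False)
features = Feature(shape=(None, 10))
dense = Dense(10, in_layers=features)
source.add_output(dense)
source.set_loss(dense)
source.fit_generator([])  # stands in for real training

# Clone the dense layer, rewiring it to a new input. Passing the source
# graph as variables_graph records the trained weights as the clone's
# initial values.
new_input = Feature(shape=(None, 10))
dense_copy = dense.copy({features: new_input}, variables_graph=source)

# The clone can now be added to a different TensorGraph.
dest = dc.models.TensorGraph(use_queue=False)
dest.add_output(dense_copy)
dest.set_loss(dense_copy)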
129 changes: 116 additions & 13 deletions deepchem/models/tensorgraph/layers.py
@@ -1,6 +1,7 @@
import random
import string
from collections import Sequence
from copy import deepcopy

import tensorflow as tf
import numpy as np
@@ -23,6 +24,7 @@ def __init__(self, in_layers=None, **kwargs):
self.in_layers = in_layers
self.op_type = "gpu"
self.variable_scope = ''
self.variable_values = None
self.rnn_initial_states = []
self.rnn_final_states = []
self.rnn_zero_states = []
@@ -115,9 +117,22 @@ def _record_variable_scope(self, local_scope):
else:
self.variable_scope = local_scope

def set_variable_initial_values(self, values):
"""Set the initial values of all variables.
This takes a list containing the initial values to use for all of this
layer's variables (in the same order returned by
TensorGraph.get_layer_variables()). When this layer is used in a
TensorGraph, it will automatically initialize each variable to the value
specified in the list. Note that some layers also have separate mechanisms
for specifying variable initializers; this method overrides them. The
purpose of this method is to let a Layer object represent a pre-trained
layer, complete with trained values for its variables."""
self.variable_values = values

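A minimal sketch of how this hook is used, mirroring the test_initialize_variable test added in this commit (the Variable layer and the value 15.0 are purely illustrative):

var = Variable([10.0])                     # default initial value is 10.0
var.set_variable_initial_values([[15.0]])  # override: start the variable at 15.0
# When the layer is later built inside a TensorGraph, fitting the graph
# assigns 15.0 to the variable instead of the value from its initializer.
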
def set_summary(self, summary_op, summary_description=None, collections=None):
"""Annotates a tensor with a tf.summary operation
Collects data from self.out_tensor by default but can be changed by setting
self.tb_input to another tensor in create_tensor
@@ -156,6 +171,64 @@ def add_summary_to_tg(self):
elif self.summary_op == 'histogram':
tf.summary.histogram(self.name, self.tb_input, self.collections)

def copy(self, replacements={}, variables_graph=None):
"""Duplicate this Layer and all its inputs.
This creates and returns a clone of this layer. It also recursively calls
copy() on all of this layer's inputs to clone the entire hierarchy of layers.
In the process, you can optionally tell it to replace particular layers with
specific existing ones. For example, you can clone a stack of layers, while
connecting the topmost ones to different inputs.
For example, consider a stack of dense layers that depend on an input:
>>> input = Feature(shape=(None, 100))
>>> dense1 = Dense(100, in_layers=input)
>>> dense2 = Dense(100, in_layers=dense1)
>>> dense3 = Dense(100, in_layers=dense2)
The following will clone all three dense layers, but not the input layer.
Instead, the input to the first dense layer will be a different layer
specified in the replacements map.
>>> replacements = {input: new_input}
>>> dense3_copy = dense3.copy(replacements)
Parameters
----------
replacements: map
specifies existing layers, and the layers to replace them with (instead of
cloning them). This argument serves two purposes. First, you can pass in
a map of replacements to control which layers get cloned. In addition,
as each layer is cloned, it is added to this map. On exit, it therefore
contains a complete record of all layers that were copied, and a reference
to the copy of each one.
variables_graph: TensorGraph
an optional TensorGraph from which to take variables. If this is specified,
the current value of each variable in each layer is recorded, and the copy
has that value specified as its initial value. This allows a piece of a
pre-trained model to be copied to another model.
"""
if self in replacements:
return replacements[self]
copied_inputs = [
layer.copy(replacements, variables_graph) for layer in self.in_layers
]
saved_inputs = self.in_layers
self.in_layers = []
saved_tensors = self.none_tensors()
copy = deepcopy(self)
self.in_layers = saved_inputs
self.set_tensors(saved_tensors)
copy.in_layers = copied_inputs
if variables_graph is not None:
variables = variables_graph.get_layer_variables(self)
if len(variables) > 0:
with variables_graph._get_tf("Graph").as_default():
values = variables_graph.session.run(variables)
copy.set_variable_initial_values(values)
return copy

def _as_graph_element(self):
if '_as_graph_element' in dir(self.out_tensor):
return self.out_tensor._as_graph_element()
@@ -918,6 +991,36 @@ def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
return out_tensor


class StopGradient(Layer):
"""Block the flow of gradients.
This layer copies its input directly to its output, but reports that all
gradients of its output are zero. This means, for example, that optimizers
will not try to optimize anything "upstream" of this layer.
For example, suppose you have pre-trained a stack of layers to perform a
calculation. You want to use the result of that calculation as the input to
another layer, but because they are already pre-trained, you do not want the
optimizer to modify them. You can wrap the output in a StopGradient layer,
then use that as the input to the next layer."""

def __init__(self, in_layers=None, **kwargs):
super(StopGradient, self).__init__(in_layers, **kwargs)
try:
self._shape = tuple(self.in_layers[0].shape)
except:
pass

def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
inputs = self._get_input_tensors(in_layers)
if len(inputs) > 1:
raise ValueError("Only one layer supported.")
out_tensor = tf.stop_gradient(inputs[0])
if set_tensors:
self.out_tensor = out_tensor
return out_tensor

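A minimal sketch of the freezing pattern this docstring describes, assuming pretrained is the output layer of an already-trained stack (the names are illustrative):

frozen = StopGradient(pretrained)
head = Dense(1, in_layers=frozen)  # gradients reach this head but nothing upstream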

def _max_dimension(x, y):
if x is None:
return y
@@ -1724,14 +1827,14 @@ def __init__(self,
Dimensionality of output vectors.
input_dim: int
Dimensionality of input vectors.
init_fn: object
TensorFlow initialization to use for W.
inner_init_fn: object
TensorFlow initialization to use for U.
activation_fn: object
TensorFlow activation to use for output.
inner_activation_fn: object
TensorFlow activation to use for inner steps.
"""

super(LSTMStep, self).__init__(**kwargs)
@@ -1787,7 +1890,7 @@ def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
Returns
-------
list
Returns h, [h + c]
"""
activation = self.activation
inner_activation = self.inner_activation
@@ -1826,7 +1929,7 @@ def _cosine_dist(x, y):
x: tf.Tensor
Input Tensor
y: tf.Tensor
Input Tensor
"""
denom = (
model_ops.sqrt(model_ops.sum(tf.square(x)) * model_ops.sum(tf.square(y)))
@@ -2911,7 +3014,7 @@ def distance_matrix(self, D):
def AlphaShare(in_layers=None, **kwargs):
"""
This method should be used when constructing AlphaShare layers from Sluice Networks
Parameters
----------
in_layers: list of Layers or tensors
@@ -2950,7 +3053,7 @@ class AlphaShareLayer(Layer):
Returns
-------
out_tensor: a tensor with shape [len(in_layers), x, y] where x, y were the original layer dimensions
out_tensor should be fed into LayerSplitter
Distance matrix.
"""

19 changes: 13 additions & 6 deletions deepchem/models/tensorgraph/tensor_graph.py
@@ -180,6 +180,13 @@ def create_feed_dict():
self.session.run(tf.global_variables_initializer())
if restore:
self.restore()
else:
# Initialize variables that have pre-trained values.
for layer in self.layers.values():
if layer.variable_values is not None:
variables = self.get_layer_variables(layer)
for var, val in zip(variables, layer.variable_values):
self.session.run(var.assign(val))
avg_loss, n_batches = 0.0, 0.0
coord = tf.train.Coordinator()
n_samples = 0
@@ -330,11 +337,11 @@ def predict_on_batch(self, X, transformers=[], outputs=None):
"""Generates predictions for input samples, processing samples in a batch.
Parameters
----------
X: ndarray
the input data, as a Numpy array.
transformers: List
List of dc.trans.Transformers
Returns
-------
@@ -348,11 +355,11 @@ def predict_proba_on_batch(self, X, transformers=[], outputs=None):
"""Generates predictions for input samples, processing samples in a batch.
Parameters
----------
X: ndarray
the input data, as a Numpy array.
transformers: List
List of dc.trans.Transformers
Returns
-------
@@ -370,7 +377,7 @@ def predict(self, dataset, transformers=[], outputs=None):
Dataset to make prediction on
transformers: list
List of dc.trans.Transformers.
outputs: object
If outputs is None, then will assume outputs = self.outputs[0] (single
output). If outputs is a Layer/Tensor, then will evaluate and return as a
single ndarray. If outputs is a list of Layers/Tensors, will return a list
@@ -391,7 +398,7 @@ def predict_proba(self, dataset, transformers=[], outputs=None):
Dataset to make prediction on
transformers: list
List of dc.trans.Transformers.
outputs: object
If outputs is None, then will assume outputs = self.outputs[0] (single
output). If outputs is a Layer/Tensor, then will evaluate and return as a
single ndarray. If outputs is a list of Layers/Tensors, will return a list
11 changes: 11 additions & 0 deletions deepchem/models/tensorgraph/tests/test_layers.py
@@ -39,6 +39,7 @@
from deepchem.models.tensorgraph.layers import SluiceLoss
from deepchem.models.tensorgraph.layers import SoftMax
from deepchem.models.tensorgraph.layers import SoftMaxCrossEntropy
from deepchem.models.tensorgraph.layers import StopGradient
from deepchem.models.tensorgraph.layers import TensorWrapper
from deepchem.models.tensorgraph.layers import TimeSeriesDense
from deepchem.models.tensorgraph.layers import ToFloat
@@ -241,6 +242,16 @@ def test_variable(self):
sess.run(tf.global_variables_initializer())
assert np.array_equal(value, out_tensor.eval())

def test_stop_gradient(self):
"""Test that StopGradient can be invoked."""
batch_size = 10
n_features = 5
in_tensor = np.random.rand(batch_size, n_features)
with self.test_session() as sess:
in_tensor = tf.convert_to_tensor(in_tensor, dtype=tf.float32)
out_tensor = StopGradient()(in_tensor)
assert np.array_equal(in_tensor.eval(), out_tensor.eval())

def test_add(self):
"""Test that Add can be invoked."""
value1 = np.random.uniform(size=(2, 3)).astype(np.float32)
12 changes: 11 additions & 1 deletion deepchem/models/tensorgraph/tests/test_layers_pickle.py
@@ -7,7 +7,7 @@
DTNNExtract, DAGLayer, DAGGather, MessagePassing, SetGather
from deepchem.models.tensorgraph.layers import Feature, Conv1D, Dense, Flatten, Reshape, Squeeze, Transpose, \
CombineMeanStd, Repeat, Gather, GRU, L2Loss, Concat, SoftMax, \
Constant, Variable, Add, Multiply, Log, Exp, InteratomicL2Distances, \
Constant, Variable, StopGradient, Add, Multiply, Log, Exp, InteratomicL2Distances, \
SoftMaxCrossEntropy, ReduceMean, ToFloat, ReduceSquareDifference, Conv2D, MaxPool2D, ReduceSum, GraphConv, GraphPool, \
GraphGather, BatchNorm, WeightedError, \
Conv3D, MaxPool3D, \
@@ -167,6 +167,16 @@ def test_Variable_pickle():
tg.save()


def test_StopGradient_pickle():
tg = TensorGraph()
feature = Feature(shape=(tg.batch_size, 1))
output = StopGradient(feature)
tg.add_output(output)
tg.set_loss(output)
tg.build()
tg.save()


def test_Log_pickle():
tg = TensorGraph()
feature = Feature(shape=(tg.batch_size, 1))
40 changes: 38 additions & 2 deletions deepchem/models/tensorgraph/tests/test_tensor_graph.py
@@ -9,9 +9,9 @@
import deepchem as dc
from deepchem.data import NumpyDataset
from deepchem.data.datasets import Databag
from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax, Constant
from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax, Constant, Variable
from deepchem.models.tensorgraph.layers import Feature, Label
from deepchem.models.tensorgraph.layers import ReduceSquareDifference
from deepchem.models.tensorgraph.layers import ReduceSquareDifference, Add
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.optimizers import GradientDescent, ExponentialDecay

@@ -314,3 +314,39 @@ def test_operators(self):
for o, e in zip(tg.outputs, expected):
value = tg.predict_on_batch(np.array([0]), outputs=o)
assert np.array_equal(e, value)

def test_initialize_variable(self):
"""Test methods for initializing a variable."""
tg = dc.models.TensorGraph(use_queue=False)
features = Feature(shape=(None, 1))
tg.set_loss(Dense(1, in_layers=features))
var = Variable([10.0])
tg.add_output(var)
tg.fit_generator([])
assert tg.predict_on_batch(np.zeros((1, 1))) == [10.0]
var.set_variable_initial_values([[15.0]])
tg.fit_generator([])
assert tg.predict_on_batch(np.zeros((1, 1))) == [15.0]

def test_copy_layers(self):
"""Test copying layers."""
tg = dc.models.TensorGraph()
features = Feature(shape=(None, 10))
dense = Dense(
10, in_layers=features, biases_initializer=tf.random_normal_initializer)
constant = Constant(10.0)
output = dense + constant
tg.add_output(output)
tg.set_loss(output)
tg.fit_generator([])
replacements = {constant: Constant(20.0)}
copy = output.copy(replacements, tg)
assert isinstance(copy, Add)
assert isinstance(copy.in_layers[0], Dense)
assert isinstance(copy.in_layers[0].in_layers[0], Feature)
assert copy.in_layers[1] == replacements[constant]
variables = tg.get_layer_variables(dense)
with tg._get_tf("Graph").as_default():
values = tg.session.run(variables)
for v1, v2 in zip(values, copy.in_layers[0].variable_values):
assert np.array_equal(v1, v2)
