From 1656bad1d1a4d2361aec3fba0e437e86672f10ca Mon Sep 17 00:00:00 2001 From: peastman Date: Mon, 25 Sep 2017 17:01:34 -0700 Subject: [PATCH 1/3] Copy layers, initialize variables --- deepchem/models/tensorgraph/layers.py | 85 ++++++++++++++++--- deepchem/models/tensorgraph/tensor_graph.py | 19 +++-- .../tensorgraph/tests/test_tensor_graph.py | 40 ++++++++- 3 files changed, 123 insertions(+), 21 deletions(-) diff --git a/deepchem/models/tensorgraph/layers.py b/deepchem/models/tensorgraph/layers.py index 35306a867a..3bbec0589f 100644 --- a/deepchem/models/tensorgraph/layers.py +++ b/deepchem/models/tensorgraph/layers.py @@ -1,6 +1,7 @@ import random import string from collections import Sequence +from copy import deepcopy import tensorflow as tf import numpy as np @@ -23,6 +24,7 @@ def __init__(self, in_layers=None, **kwargs): self.in_layers = in_layers self.op_type = "gpu" self.variable_scope = '' + self.variable_values = None self.rnn_initial_states = [] self.rnn_final_states = [] self.rnn_zero_states = [] @@ -115,9 +117,22 @@ def _record_variable_scope(self, local_scope): else: self.variable_scope = local_scope + def set_variable_initial_values(self, values): + """Set the initial values of all variables. + + This takes a list, which contains the initial values to use for all of + this layer's values (in the same order retured by + TensorGraph.get_layer_variables()). When this layer is used in a + TensorGraph, it will automatically initialize each variable to the value + specified in the list. Note that some layers also have separate mechanisms + for specifying variable initializers; this method overrides them. The + purpose of this method is to let a Layer object represent a pre-trained + layer, complete with trained values for its variables.""" + self.variable_values = values + def set_summary(self, summary_op, summary_description=None, collections=None): """Annotates a tensor with a tf.summary operation - Collects data from self.out_tensor by default but can be changed by setting + Collects data from self.out_tensor by default but can be changed by setting self.tb_input to another tensor in create_tensor @@ -156,6 +171,50 @@ def add_summary_to_tg(self): elif self.summary_op == 'histogram': tf.summary.histogram(self.name, self.tb_input, self.collections) + def copy(self, replacements={}, variables_graph=None): + """Duplicate this Layer and all its inputs. + + This creates and returns a clone of this layer. It also recursively calls + copy() on all of this layer's inputs to clone the entire hierarchy of layers. + In the process, you can optionally tell it to replace particular layers with + specific existing ones. For example, you can clone a stack of layers, while + connecting the topmost ones to different inputs. + + Parameters + ---------- + replacements: map + specifies existing layers, and the layers to replace them with (instead of + cloning them). This argument serves two purposes. First, you can pass in + a list of replacements to control which layers get cloned. In addition, + as each layer is cloned, it is added to this map. On exit, it therefore + contains a complete record of all layers that were copied, and a reference + to the copy of each one. + variables_graph: TensorGraph + an optional TensorGraph from which to take variables. If this is specified, + the current value of each variable in each layer is recorded, and the copy + has that value specified as its initial value. This allows a piece of a + pre-trained model to be copied to another model. 
+ """ + if self in replacements: + return replacements[self] + copied_inputs = [ + layer.copy(replacements, variables_graph) for layer in self.in_layers + ] + saved_inputs = self.in_layers + self.in_layers = [] + saved_tensors = self.none_tensors() + copy = deepcopy(self) + self.in_layers = saved_inputs + self.set_tensors(saved_tensors) + copy.in_layers = copied_inputs + if variables_graph is not None: + variables = variables_graph.get_layer_variables(self) + if len(variables) > 0: + with variables_graph._get_tf("Graph").as_default(): + values = variables_graph.session.run(variables) + copy.set_variable_initial_values(values) + return copy + def _as_graph_element(self): if '_as_graph_element' in dir(self.out_tensor): return self.out_tensor._as_graph_element() @@ -1724,14 +1783,14 @@ def __init__(self, Dimensionality of output vectors. input_dim: int Dimensionality of input vectors. - init_fn: object - TensorFlow initialization to use for W. - inner_init_fn: object - TensorFlow initialization to use for U. - activation_fn: object - TensorFlow activation to use for output. - inner_activation_fn: object - TensorFlow activation to use for inner steps. + init_fn: object + TensorFlow initialization to use for W. + inner_init_fn: object + TensorFlow initialization to use for U. + activation_fn: object + TensorFlow activation to use for output. + inner_activation_fn: object + TensorFlow activation to use for inner steps. """ super(LSTMStep, self).__init__(**kwargs) @@ -1787,7 +1846,7 @@ def create_tensor(self, in_layers=None, set_tensors=True, **kwargs): Returns ------- list - Returns h, [h + c] + Returns h, [h + c] """ activation = self.activation inner_activation = self.inner_activation @@ -1826,7 +1885,7 @@ def _cosine_dist(x, y): x: tf.Tensor Input Tensor y: tf.Tensor - Input Tensor + Input Tensor """ denom = ( model_ops.sqrt(model_ops.sum(tf.square(x)) * model_ops.sum(tf.square(y))) @@ -2911,7 +2970,7 @@ def distance_matrix(self, D): def AlphaShare(in_layers=None, **kwargs): """ This method should be used when constructing AlphaShare layers from Sluice Networks - + Parameters ---------- in_layers: list of Layers or tensors @@ -2950,7 +3009,7 @@ class AlphaShareLayer(Layer): Returns ------- out_tensor: a tensor with shape [len(in_layers), x, y] where x, y were the original layer dimensions - out_tensor should be fed into LayerSplitter + out_tensor should be fed into LayerSplitter Distance matrix. """ diff --git a/deepchem/models/tensorgraph/tensor_graph.py b/deepchem/models/tensorgraph/tensor_graph.py index a9598b94df..79106aa1c4 100644 --- a/deepchem/models/tensorgraph/tensor_graph.py +++ b/deepchem/models/tensorgraph/tensor_graph.py @@ -180,6 +180,13 @@ def create_feed_dict(): self.session.run(tf.global_variables_initializer()) if restore: self.restore() + else: + # Initialize variables that have pre-trained values. + for layer in self.layers.values(): + if layer.variable_values is not None: + variables = self.get_layer_variables(layer) + for var, val in zip(variables, layer.variable_values): + self.session.run(var.assign(val)) avg_loss, n_batches = 0.0, 0.0 coord = tf.train.Coordinator() n_samples = 0 @@ -330,11 +337,11 @@ def predict_on_batch(self, X, transformers=[], outputs=None): """Generates predictions for input samples, processing samples in a batch. Parameters - ---------- + ---------- X: ndarray the input data, as a Numpy array. 
transformers: List - List of dc.trans.Transformers + List of dc.trans.Transformers Returns ------- @@ -348,11 +355,11 @@ def predict_proba_on_batch(self, X, transformers=[], outputs=None): """Generates predictions for input samples, processing samples in a batch. Parameters - ---------- + ---------- X: ndarray the input data, as a Numpy array. transformers: List - List of dc.trans.Transformers + List of dc.trans.Transformers Returns ------- @@ -370,7 +377,7 @@ def predict(self, dataset, transformers=[], outputs=None): Dataset to make prediction on transformers: list List of dc.trans.Transformers. - outputs: object + outputs: object If outputs is None, then will assume outputs = self.outputs[0] (single output). If outputs is a Layer/Tensor, then will evaluate and return as a single ndarray. If outputs is a list of Layers/Tensors, will return a list @@ -391,7 +398,7 @@ def predict_proba(self, dataset, transformers=[], outputs=None): Dataset to make prediction on transformers: list List of dc.trans.Transformers. - outputs: object + outputs: object If outputs is None, then will assume outputs = self.outputs[0] (single output). If outputs is a Layer/Tensor, then will evaluate and return as a single ndarray. If outputs is a list of Layers/Tensors, will return a list diff --git a/deepchem/models/tensorgraph/tests/test_tensor_graph.py b/deepchem/models/tensorgraph/tests/test_tensor_graph.py index fc0f269dfb..f52e5601fc 100644 --- a/deepchem/models/tensorgraph/tests/test_tensor_graph.py +++ b/deepchem/models/tensorgraph/tests/test_tensor_graph.py @@ -9,9 +9,9 @@ import deepchem as dc from deepchem.data import NumpyDataset from deepchem.data.datasets import Databag -from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax, Constant +from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax, Constant, Variable from deepchem.models.tensorgraph.layers import Feature, Label -from deepchem.models.tensorgraph.layers import ReduceSquareDifference +from deepchem.models.tensorgraph.layers import ReduceSquareDifference, Add from deepchem.models.tensorgraph.tensor_graph import TensorGraph from deepchem.models.tensorgraph.optimizers import GradientDescent, ExponentialDecay @@ -314,3 +314,39 @@ def test_operators(self): for o, e in zip(tg.outputs, expected): value = tg.predict_on_batch(np.array([0]), outputs=o) assert np.array_equal(e, value) + + def test_initialize_variable(self): + """Test methods for initializing a variable.""" + tg = dc.models.TensorGraph(use_queue=False) + features = Feature(shape=(None, 1)) + tg.set_loss(Dense(1, in_layers=features)) + var = Variable([10.0]) + tg.add_output(var) + tg.fit_generator([]) + assert tg.predict_on_batch(np.zeros((1, 1))) == [10.0] + var.set_variable_initial_values([[15.0]]) + tg.fit_generator([]) + assert tg.predict_on_batch(np.zeros((1, 1))) == [15.0] + + def test_copy_layers(self): + """Test copying layers.""" + tg = dc.models.TensorGraph() + features = Feature(shape=(None, 10)) + dense = Dense( + 10, in_layers=features, biases_initializer=tf.random_normal_initializer) + constant = Constant(10.0) + output = dense + constant + tg.add_output(output) + tg.set_loss(output) + tg.fit_generator([]) + replacements = {constant: Constant(20.0)} + copy = output.copy(replacements, tg) + assert isinstance(copy, Add) + assert isinstance(copy.in_layers[0], Dense) + assert isinstance(copy.in_layers[0].in_layers[0], Feature) + assert copy.in_layers[1] == replacements[constant] + variables = 
tg.get_layer_variables(dense) + with tg._get_tf("Graph").as_default(): + values = tg.session.run(variables) + for v1, v2 in zip(values, copy.in_layers[0].variable_values): + assert np.array_equal(v1, v2) From 322ef3b96a7dd2cb9bdc7e139c457454671659fd Mon Sep 17 00:00:00 2001 From: peastman Date: Mon, 25 Sep 2017 17:22:49 -0700 Subject: [PATCH 2/3] Implemented StopGradient --- deepchem/models/tensorgraph/layers.py | 24 +++++++++++++++++++ .../models/tensorgraph/tests/test_layers.py | 11 +++++++++ .../tensorgraph/tests/test_layers_pickle.py | 12 +++++++++- 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/deepchem/models/tensorgraph/layers.py b/deepchem/models/tensorgraph/layers.py index 3bbec0589f..ca86960d8f 100644 --- a/deepchem/models/tensorgraph/layers.py +++ b/deepchem/models/tensorgraph/layers.py @@ -977,6 +977,30 @@ def create_tensor(self, in_layers=None, set_tensors=True, **kwargs): return out_tensor +class StopGradient(Layer): + """Block the flow of gradients. + + This layer copies its input directly to its output, but reports that all + gradients of its output are zero. This means, for example, that optimizers + will not try to optimize anything "upstream" of this layer.""" + + def __init__(self, in_layers=None, **kwargs): + super(StopGradient, self).__init__(in_layers, **kwargs) + try: + self._shape = tuple(self.in_layers[0].shape) + except: + pass + + def create_tensor(self, in_layers=None, set_tensors=True, **kwargs): + inputs = self._get_input_tensors(in_layers) + if len(inputs) > 1: + raise ValueError("Only one layer supported.") + out_tensor = tf.stop_gradient(inputs[0]) + if set_tensors: + self.out_tensor = out_tensor + return out_tensor + + def _max_dimension(x, y): if x is None: return y diff --git a/deepchem/models/tensorgraph/tests/test_layers.py b/deepchem/models/tensorgraph/tests/test_layers.py index f19abbfeb2..c35d5aaa89 100644 --- a/deepchem/models/tensorgraph/tests/test_layers.py +++ b/deepchem/models/tensorgraph/tests/test_layers.py @@ -39,6 +39,7 @@ from deepchem.models.tensorgraph.layers import SluiceLoss from deepchem.models.tensorgraph.layers import SoftMax from deepchem.models.tensorgraph.layers import SoftMaxCrossEntropy +from deepchem.models.tensorgraph.layers import StopGradient from deepchem.models.tensorgraph.layers import TensorWrapper from deepchem.models.tensorgraph.layers import TimeSeriesDense from deepchem.models.tensorgraph.layers import ToFloat @@ -241,6 +242,16 @@ def test_variable(self): sess.run(tf.global_variables_initializer()) assert np.array_equal(value, out_tensor.eval()) + def test_stop_gradient(self): + """Test that StopGradient can be invoked.""" + batch_size = 10 + n_features = 5 + in_tensor = np.random.rand(batch_size, n_features) + with self.test_session() as sess: + in_tensor = tf.convert_to_tensor(in_tensor, dtype=tf.float32) + out_tensor = StopGradient()(in_tensor) + assert np.array_equal(in_tensor.eval(), out_tensor.eval()) + def test_add(self): """Test that Add can be invoked.""" value1 = np.random.uniform(size=(2, 3)).astype(np.float32) diff --git a/deepchem/models/tensorgraph/tests/test_layers_pickle.py b/deepchem/models/tensorgraph/tests/test_layers_pickle.py index a199f8353a..6236e0bdc2 100644 --- a/deepchem/models/tensorgraph/tests/test_layers_pickle.py +++ b/deepchem/models/tensorgraph/tests/test_layers_pickle.py @@ -7,7 +7,7 @@ DTNNExtract, DAGLayer, DAGGather, MessagePassing, SetGather from deepchem.models.tensorgraph.layers import Feature, Conv1D, Dense, Flatten, Reshape, Squeeze, Transpose, \ 
CombineMeanStd, Repeat, Gather, GRU, L2Loss, Concat, SoftMax, \ - Constant, Variable, Add, Multiply, Log, Exp, InteratomicL2Distances, \ + Constant, Variable, StopGradient, Add, Multiply, Log, Exp, InteratomicL2Distances, \ SoftMaxCrossEntropy, ReduceMean, ToFloat, ReduceSquareDifference, Conv2D, MaxPool2D, ReduceSum, GraphConv, GraphPool, \ GraphGather, BatchNorm, WeightedError, \ Conv3D, MaxPool3D, \ @@ -167,6 +167,16 @@ def test_Variable_pickle(): tg.save() +def test_StopGradient_pickle(): + tg = TensorGraph() + feature = Feature(shape=(tg.batch_size, 1)) + output = StopGradient(feature) + tg.add_output(output) + tg.set_loss(output) + tg.build() + tg.save() + + def test_Log_pickle(): tg = TensorGraph() feature = Feature(shape=(tg.batch_size, 1)) From fe274ab090148414328ca4943ff3929fff84156b Mon Sep 17 00:00:00 2001 From: peastman Date: Tue, 26 Sep 2017 10:06:31 -0700 Subject: [PATCH 3/3] Added more documentation --- deepchem/models/tensorgraph/layers.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/deepchem/models/tensorgraph/layers.py b/deepchem/models/tensorgraph/layers.py index ca86960d8f..bb8a466b52 100644 --- a/deepchem/models/tensorgraph/layers.py +++ b/deepchem/models/tensorgraph/layers.py @@ -180,6 +180,20 @@ def copy(self, replacements={}, variables_graph=None): specific existing ones. For example, you can clone a stack of layers, while connecting the topmost ones to different inputs. + For example, consider a stack of dense layers that depend on an input: + + >>> input = Feature(shape=(None, 100)) + >>> dense1 = Dense(100, in_layers=input) + >>> dense2 = Dense(100, in_layers=dense1) + >>> dense3 = Dense(100, in_layers=dense2) + + The following will clone all three dense layers, but not the input layer. + Instead, the input to the first dense layer will be a different layer + specified in the replacements map. + + >>> replacements = {input: new_input} + >>> dense3_copy = dense3.copy(replacements) + Parameters ---------- replacements: map @@ -982,7 +996,13 @@ class StopGradient(Layer): This layer copies its input directly to its output, but reports that all gradients of its output are zero. This means, for example, that optimizers - will not try to optimize anything "upstream" of this layer.""" + will not try to optimize anything "upstream" of this layer. + + For example, suppose you have pre-trained a stack of layers to perform a + calculation. You want to use the result of that calculation as the input to + another layer, but because they are already pre-trained, you do not want the + optimizer to modify them. You can wrap the output in a StopGradient layer, + then use that as the input to the next layer.""" def __init__(self, in_layers=None, **kwargs): super(StopGradient, self).__init__(in_layers, **kwargs)
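
The new set_variable_initial_values() hook is easiest to see in isolation. The sketch below mirrors test_initialize_variable from the first patch, so the shapes and values are purely illustrative; the point is that the list passed in overrides the layer's own initializer the next time the graph is built.

    import numpy as np
    import deepchem as dc
    from deepchem.models.tensorgraph.layers import Dense, Feature, Variable

    tg = dc.models.TensorGraph(use_queue=False)
    features = Feature(shape=(None, 1))
    tg.set_loss(Dense(1, in_layers=features))
    var = Variable([10.0])
    tg.add_output(var)
    # Override the layer's own initializer.  The list holds one entry per
    # variable, in the order returned by TensorGraph.get_layer_variables().
    var.set_variable_initial_values([[15.0]])
    tg.fit_generator([])  # builds the graph and applies the initial values
    print(tg.predict_on_batch(np.zeros((1, 1))))  # -> [15.0]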
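
Layer.copy() with the variables_graph argument is intended for transplanting a pre-trained stack of layers into a new model. A sketch of that workflow, adapted from test_copy_layers; the graph and layer names are illustrative, and fit_generator([]) is used here only to build and initialize the source graph (in practice it would be fit on real data).

    import deepchem as dc
    from deepchem.models.tensorgraph.layers import Dense, Feature

    # A source model whose dense stack we want to reuse.
    source = dc.models.TensorGraph()
    features = Feature(shape=(None, 10))
    dense1 = Dense(10, in_layers=features)
    dense2 = Dense(10, in_layers=dense1)
    source.add_output(dense2)
    source.set_loss(dense2)
    source.fit_generator([])

    # Clone the two dense layers, connect the clone to a new input, and record
    # the current variable values from `source` as the clone's initial values.
    new_features = Feature(shape=(None, 10))
    replacements = {features: new_features}
    dense2_copy = dense2.copy(replacements, source)

    # The clone can now be used in a second graph; when that graph is built,
    # its variables start from the values taken from `source`.
    target = dc.models.TensorGraph()
    target.add_output(dense2_copy)
    target.set_loss(dense2_copy)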
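
StopGradient complements the copy mechanism when the transplanted layers should stay frozen. A minimal sketch, in which the single Dense layer stands in for a copied, pre-trained stack:

    import deepchem as dc
    from deepchem.models.tensorgraph.layers import Dense, Feature, StopGradient

    tg = dc.models.TensorGraph()
    features = Feature(shape=(None, 10))
    pretrained = Dense(10, in_layers=features)  # stand-in for a copied, pre-trained stack
    frozen = StopGradient(pretrained)           # gradients stop flowing past this point
    head = Dense(1, in_layers=frozen)           # only this layer's variables are trained
    tg.add_output(head)
    tg.set_loss(head)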