From 144855b38555ec83527fd5dd51ffd6db5f82715c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 May 2016 05:55:50 -0800 Subject: [PATCH] Added model_variable and helpers to manage variables. Change: 122727793 --- tensorflow/contrib/framework/__init__.py | 31 +- .../framework/python/framework/tensor_util.py | 30 +- .../contrib/framework/python/ops/arg_scope.py | 11 +- .../contrib/framework/python/ops/variables.py | 322 +++++++++- .../framework/python/ops/variables_test.py | 566 +++++++++++++++++- 5 files changed, 906 insertions(+), 54 deletions(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 242218209eacca..fe039f20138516 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -18,13 +18,40 @@ @@assert_same_float_dtype @@assert_scalar_int @@convert_to_tensor_or_sparse_tensor -@@local_variable +@@get_graph_from_inputs +@@is_numeric_tensor +@@is_non_decreasing +@@is_strictly_increasing @@reduce_sum_n @@safe_embedding_lookup_sparse @@with_shape @@with_same_shape -@@get_graph_from_inputs + +## Arg_Scope +@@arg_scope +@@add_arg_scope +@@has_arg_scope +@@arg_scoped_arguments + +## Variables +@@add_model_variable +@@assert_global_step +@@assert_or_get_global_step +@@create_global_step +@@get_global_step +@@get_or_create_global_step +@@get_local_variables +@@get_model_variables +@@get_unique_variable +@@get_variables_by_name +@@get_variables_by_suffix +@@get_variables_to_restore +@@get_variables +@@local_variable +@@model_variable +@@variable +@@VariableDeviceChooser """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index febcae883b2ffa..3e8da8cf4e82b5 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -14,7 +14,6 @@ # ============================================================================== """Tensor utility functions.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -27,14 +26,16 @@ from tensorflow.python.ops import variables __all__ = [ - 'assert_same_float_dtype', 'assert_scalar_int', - 'convert_to_tensor_or_sparse_tensor', 'local_variable', 'reduce_sum_n', - 'with_shape', 'with_same_shape', -] + 'assert_same_float_dtype', + 'assert_scalar_int', + 'convert_to_tensor_or_sparse_tensor', + 'reduce_sum_n', + 'with_shape', + 'with_same_shape'] def _assert_same_base_type(items, expected_type=None): - """Asserts all items are of the same base type. + r"""Asserts all items are of the same base type. Args: items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, @@ -110,23 +111,6 @@ def assert_scalar_int(tensor): return tensor -# TODO(ptucker): Move to tf.variables? -def local_variable(initial_value, validate_shape=True, name=None): - """Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection. - - Args: - initial_value: See variables.Variable.__init__. - validate_shape: See variables.Variable.__init__. - name: See variables.Variable.__init__. - Returns: - New variable. - """ - return variables.Variable( - initial_value, trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=validate_shape, name=name) - - def reduce_sum_n(tensors, name=None): """Reduce tensors to a scalar sum. 
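Illustrative sketch (not from the original patch): a rough example of how the helpers exported above are expected to be used together, based on the functions and tests added in this change. It assumes the 2016-era tf.contrib.framework API that this patch introduces; the scope and variable names ('my_model', 'weights') are made up for illustration.

import tensorflow as tf

with tf.variable_scope('my_model'):
  # model_variable() registers the variable in GraphKeys.VARIABLES and in the
  # new MODEL_VARIABLES collection added by this patch.
  weights = tf.contrib.framework.model_variable(
      'weights', shape=[10, 10],
      initializer=tf.truncated_normal_initializer(stddev=0.01),
      regularizer=tf.nn.l2_loss)

# Query helpers exported above.
model_vars = tf.contrib.framework.get_model_variables('my_model')
to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['my_model'])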
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index aa6acd735a0bb1..587b7739c99d8d 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -47,11 +47,6 @@ @tf.contrib.add_arg_scope def conv2d(*args, **kwargs) - -@@arg_scope -@@add_arg_scope -@@has_arg_scope -@@arg_scoped_arguments """ from __future__ import absolute_import from __future__ import division @@ -59,8 +54,10 @@ def conv2d(*args, **kwargs) import contextlib import functools -__all__ = ['arg_scope', 'add_arg_scope', - 'has_arg_scope', 'arg_scoped_arguments'] +__all__ = ['arg_scope', + 'add_arg_scope', + 'has_arg_scope', + 'arg_scoped_arguments'] _ARGSTACK = [{}] diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index b040e3e9d1e806..43f1dd2944a1bc 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -14,24 +14,37 @@ # ============================================================================== """Variable functions. - -@@assert_global_step -@@create_global_step -@@get_global_step -@@assert_or_get_global_step -@@local_variable """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging -__all__ = [ - 'assert_global_step', 'create_global_step', 'get_global_step', - 'assert_or_get_global_step', 'local_variable'] + +__all__ = ['add_model_variable', + 'assert_global_step', + 'assert_or_get_global_step', + 'create_global_step', + 'get_global_step', + 'get_or_create_global_step', + 'get_local_variables', + 'get_model_variables', + 'get_unique_variable', + 'get_variables_by_name', + 'get_variables_by_suffix', + 'get_variables_to_restore', + 'get_variables', + 'local_variable', + 'model_variable', + 'variable', + 'VariableDeviceChooser'] def assert_global_step(global_step_tensor): @@ -125,7 +138,6 @@ def create_global_step(graph=None): Global step tensor. Raises: - TypeError: if `dtype` is invalid. ValueError: if global step key is already defined. """ graph = ops.get_default_graph() if graph is None else graph @@ -133,10 +145,27 @@ def create_global_step(graph=None): raise ValueError('"global_step" already exists.') # Create in proper graph and base name_scope. with graph.as_default() as g, g.name_scope(None): - result = variables.Variable( - 0, trainable=False, dtype=dtypes.int64, name=ops.GraphKeys.GLOBAL_STEP) - graph.add_to_collection(ops.GraphKeys.GLOBAL_STEP, result) - return result + collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP] + return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64, + initializer=init_ops.zeros_initializer, trainable=False, + collections=collections) + + +def get_or_create_global_step(graph=None): + """Returns and creates (if necessary) the global step variable. + + Args: + graph: The graph in which to create the global step. If missing, use the default + graph.
+ + Returns: + The tensor representing the global step variable. + """ + graph = ops.get_default_graph() if graph is None else graph + globalstep = get_global_step(graph) + if globalstep is None: + globalstep = create_global_step(graph) + return globalstep def local_variable(initial_value, validate_shape=True, name=None): @@ -154,3 +183,268 @@ def local_variable(initial_value, validate_shape=True, name=None): collections=[ops.GraphKeys.LOCAL_VARIABLES], validate_shape=validate_shape, name=name) + +@contrib_add_arg_scope +def variable(name, shape=None, dtype=dtypes.float32, initializer=None, + regularizer=None, trainable=True, collections=None, + caching_device=None, device=None): + """Gets an existing variable with these parameters or creates a new one. + + Args: + name: the name of the new or existing variable. + shape: shape of the new or existing variable. + dtype: type of the new or existing variable (defaults to `DT_FLOAT`). + initializer: initializer for the variable if one is created. + regularizer: a (Tensor -> Tensor or None) function; the result of + applying it on a newly created variable will be added to the collection + GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + collections: A list of collection names to which the Variable will be added. + If None, it defaults to tf.GraphKeys.VARIABLES. + caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. + device: Optional device to place the variable. It can be a string or a + function that is called to get the device for the variable. + + Returns: + The created or existing variable. + """ + collections = list(collections or [ops.GraphKeys.VARIABLES]) + + # Remove duplicates + collections = set(collections) + with ops.device(device or ''): + return variable_scope.get_variable(name, shape=shape, dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device) + +# TODO(sguada) move it to ops.GraphKeys or to contrib.framework.GraphKeys # Collection containing all the variables created using model_variables. MODEL_VARIABLES = '_model_variables_' + +@contrib_add_arg_scope +def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, + regularizer=None, trainable=True, collections=None, + caching_device=None, device=None): + """Gets an existing model variable with these parameters or creates a new one. + + Args: + name: the name of the new or existing variable. + shape: shape of the new or existing variable. + dtype: type of the new or existing variable (defaults to `DT_FLOAT`). + initializer: initializer for the variable if one is created. + regularizer: a (Tensor -> Tensor or None) function; the result of + applying it on a newly created variable will be added to the collection + GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + collections: A list of collection names to which the Variable will be added. + Note that the variable is always also added to the tf.GraphKeys.VARIABLES + and MODEL_VARIABLES collections. + caching_device: Optional device string or function describing where the + Variable should be cached for reading.
Defaults to the Variable's + device. + device: Optional device to place the variable. It can be a string or a + function that is called to get the device for the variable. + + Returns: + The created or existing variable. + """ + collections = list(collections or []) + + # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES + collections += [ops.GraphKeys.VARIABLES, MODEL_VARIABLES] + return variable(name, shape=shape, dtype=dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, collections=collections, + caching_device=caching_device, device=device) + + +def add_model_variable(var): + """Adds a variable to the MODEL_VARIABLES collection. + + Args: + var: a variable. + """ + if var not in ops.get_collection(MODEL_VARIABLES): + ops.add_to_collection(MODEL_VARIABLES, var) + + +def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES): + """Gets the list of variables, filtered by scope and/or suffix. + + Args: + scope: an optional scope for filtering the variables to return. + suffix: an optional suffix for filtering the variables to return. + collection: the collection to search in. Defaults to GraphKeys.VARIABLES. + + Returns: + a list of variables in the collection with scope and suffix. + """ + if suffix is not None: + if ':' not in suffix: + suffix += ':' + scope = (scope or '') + '.*' + suffix + return ops.get_collection(collection, scope) + + +def get_model_variables(scope=None, suffix=None): + """Gets the list of model variables, filtered by scope and/or suffix. + + Args: + scope: an optional scope for filtering the variables to return. + suffix: an optional suffix for filtering the variables to return. + + Returns: + a list of variables in the collection with scope and suffix. + """ + return get_variables(scope, suffix, MODEL_VARIABLES) + + +def get_local_variables(scope=None, suffix=None): + """Gets the list of local variables, filtered by scope and/or suffix. + + Args: + scope: an optional scope for filtering the variables to return. + suffix: an optional suffix for filtering the variables to return. + + Returns: + a list of variables in the collection with scope and suffix. + """ + return get_variables(scope, suffix, ops.GraphKeys.LOCAL_VARIABLES) + + +def get_variables_to_restore(include=None, exclude=None): + """Gets the list of the variables to restore. + + Args: + include: an optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to include. If None, all + the variables are included. + exclude: an optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to exclude. If None, no + variables are excluded. + + Returns: + a list of variables to restore. + + Raises: + TypeError: include or exclude is provided but is not a list or a tuple. + """ + if include is None: + # Include all variables.
+ vars_to_include = get_variables() + else: + if not isinstance(include, (list, tuple)): + raise TypeError('include is provided but is not a list or a tuple.') + vars_to_include = [] + for scope in include: + vars_to_include += get_variables(scope) + vars_to_exclude = set() + if exclude is not None: + if not isinstance(exclude, (list, tuple)): + raise TypeError('exclude is provided but is not a list or a tuple.') + for scope in exclude: + vars_to_exclude |= set(get_variables(scope)) + # Exclude the variables in vars_to_exclude + return [v for v in vars_to_include if v not in vars_to_exclude] + + +def get_variables_by_suffix(suffix, scope=None): + """Gets the list of variables that end with the given suffix. + + Args: + suffix: suffix for filtering the variables to return. + scope: an optional scope for filtering the variables to return. + + Returns: + a copied list of variables with the given suffix and scope. + """ + return get_variables(scope=scope, suffix=suffix) + + +def get_variables_by_name(given_name, scope=None): + """Gets the list of variables that were given that name. + + Args: + given_name: name given to the variable without any scope. + scope: an optional scope for filtering the variables to return. + + Returns: + a copied list of variables with the given name and scope. + """ + suffix = '/' + given_name + ':|^' + given_name + ':' + return get_variables(scope=scope, suffix=suffix) + + +def get_unique_variable(var_op_name): + """Gets the variable uniquely identified by that var_op_name. + + Args: + var_op_name: the full name of the variable op, including the scope. + + Returns: + a tensorflow variable. + + Raises: + ValueError: if no variable uniquely identified by the name exists. + """ + candidates = get_variables(scope=var_op_name) + if not candidates: + raise ValueError('Could not find variable %s' % var_op_name) + + for candidate in candidates: + if candidate.op.name == var_op_name: + return candidate + raise ValueError('Variable %s does not uniquely identify a variable' % + var_op_name) + + +class VariableDeviceChooser(object): + """Device chooser for variables. + + When using a parameter server it will assign variables to parameter-server + tasks in a round-robin fashion. + When not using a parameter server it allows GPU or CPU placement. + """ + + def __init__(self, + num_tasks=0, + device_type='CPU', + device_index=0): + """Initialize VariableDeviceChooser. + + Usage: + To use with 2 parameter servers: + VariableDeviceChooser(2) + + To use without parameter servers: + VariableDeviceChooser() + VariableDeviceChooser(device_type='GPU') # For GPU placement + + Args: + num_tasks: number of parameter server tasks; 0 means no parameter servers. + device_type: Optional device type string (e.g. "CPU" or "GPU"). + device_index: int. Optional device index. If left + unspecified, device represents 'any' device_index.
+ """ + self._job_name = 'ps' if num_tasks > 0 else None + self._device_type = device_type + self._device_index = device_index + self._num_tasks = num_tasks + self._next_task_id = 0 + + def __call__(self, op): + device_spec = tf_device.DeviceSpec(job=self._job_name, + device_type=self._device_type, + device_index=self._device_index) + if self._num_tasks > 0: + task_id = self._next_task_id + self._next_task_id = (self._next_task_id + 1) % self._num_tasks + device_spec.task = task_id + return device_spec.to_string() diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index af2e36c9ec4938..0f42d1ea59ec97 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -37,11 +37,57 @@ def test_local_variable(self): tf.initialize_variables(variables).run() self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) + def testLocalVariableNameAndShape(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.local_variable([1, 1, 1, 1, 1], name='a') + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + self.assertListEqual([a], tf.contrib.framework.get_local_variables()) + + def testLocalVariableNotInAllVariables(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.local_variable(0) + self.assertFalse(a in tf.all_variables()) + self.assertTrue(a in tf.local_variables()) + + def testLocalVariableNotInVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.local_variable(0) + self.assertFalse(a in tf.contrib.framework.get_variables_to_restore()) + self.assertTrue(a in tf.local_variables()) + + def testGetVariablesDontReturnsTransients(self): + with self.test_session(): + with tf.variable_scope('A'): + tf.contrib.framework.local_variable(0) + with tf.variable_scope('B'): + tf.contrib.framework.local_variable(0) + self.assertEquals([], tf.contrib.framework.get_variables('A')) + self.assertEquals([], tf.contrib.framework.get_variables('B')) + + def testGetLocalVariablesReturnsTransients(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.local_variable(0) + with tf.variable_scope('B'): + b = tf.contrib.framework.local_variable(0) + self.assertEquals([a], tf.contrib.framework.get_local_variables('A')) + self.assertEquals([b], tf.contrib.framework.get_local_variables('B')) + + def testInitializedVariableValue(self): + with self.test_session() as sess: + a = tf.contrib.framework.local_variable([0, 0, 0, 0, 0], name='a') + sess.run(tf.initialize_local_variables()) + self.assertAllEqual(a.eval(), [0]*5) + class GlobalStepTest(tf.test.TestCase): def _assert_global_step(self, global_step, expected_dtype=tf.int64): - self.assertEquals("%s:0" % tf.GraphKeys.GLOBAL_STEP, global_step.name) + self.assertEquals('%s:0' % tf.GraphKeys.GLOBAL_STEP, global_step.name) self.assertEquals(expected_dtype, global_step.dtype.base_dtype) self.assertEquals([], global_step.get_shape().as_list()) @@ -51,10 +97,10 @@ def test_invalid_dtype(self): tf.Variable( 0.0, trainable=False, dtype=tf.float32, name=tf.GraphKeys.GLOBAL_STEP) self.assertRaisesRegexp( - TypeError, "does not have integer type", + TypeError, 'does not have integer type', tf.contrib.framework.get_global_step) self.assertRaisesRegexp( - TypeError, "does not have integer type", + TypeError, 'does 
not have integer type', tf.contrib.framework.get_global_step, g) def test_invalid_shape(self): @@ -63,10 +109,10 @@ def test_invalid_shape(self): tf.Variable( [0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP) self.assertRaisesRegexp( - TypeError, "not scalar", + TypeError, 'not scalar', tf.contrib.framework.get_global_step) self.assertRaisesRegexp( - TypeError, "not scalar", + TypeError, 'not scalar', tf.contrib.framework.get_global_step, g) def test_create_global_step(self): @@ -75,9 +121,9 @@ def test_create_global_step(self): global_step = tf.contrib.framework.create_global_step() self._assert_global_step(global_step) self.assertRaisesRegexp( - ValueError, "already exists", tf.contrib.framework.create_global_step) + ValueError, 'already exists', tf.contrib.framework.create_global_step) self.assertRaisesRegexp( - ValueError, "already exists", tf.contrib.framework.create_global_step, + ValueError, 'already exists', tf.contrib.framework.create_global_step, g) self._assert_global_step( tf.contrib.framework.create_global_step(tf.Graph())) @@ -92,6 +138,510 @@ def test_get_global_step(self): self._assert_global_step( tf.contrib.framework.get_global_step(g), expected_dtype=tf.int32) + def test_get_or_create_global_step(self): + with tf.Graph().as_default() as g: + self.assertEquals(None, tf.contrib.framework.get_global_step()) + self._assert_global_step( + tf.contrib.framework.get_or_create_global_step()) + self._assert_global_step( + tf.contrib.framework.get_or_create_global_step(g)) + + +class VariablesTest(tf.test.TestCase): + + def testCreateVariable(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + + def testGetVariables(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.variable('a', [5]) + self.assertEquals([a, b], tf.contrib.framework.get_variables()) + self.assertEquals([a], tf.contrib.framework.get_variables('A')) + self.assertEquals([b], tf.contrib.framework.get_variables('B')) + + def testGetVariablesSuffix(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('A'): + b = tf.contrib.framework.variable('b', [5]) + self.assertEquals([a], tf.contrib.framework.get_variables(suffix='a')) + self.assertEquals([b], tf.contrib.framework.get_variables(suffix='b')) + + def testGetVariableWithSingleVar(self): + with self.test_session(): + with tf.variable_scope('parent'): + a = tf.contrib.framework.variable('child', [5]) + self.assertEquals( + a, tf.contrib.framework.get_unique_variable('parent/child')) + + def testGetVariableWithDistractors(self): + with self.test_session(): + with tf.variable_scope('parent'): + a = tf.contrib.framework.variable('child', [5]) + with tf.variable_scope('child'): + tf.contrib.framework.variable('grandchild1', [7]) + tf.contrib.framework.variable('grandchild2', [9]) + self.assertEquals( + a, tf.contrib.framework.get_unique_variable('parent/child')) + + def testGetVariableThrowsExceptionWithNoMatch(self): + var_name = 'cant_find_me' + with self.test_session(): + with self.assertRaises(ValueError): + tf.contrib.framework.get_unique_variable(var_name) + + def testGetThrowsExceptionWithChildrenButNoMatch(self): + var_name = 'parent/child' + with self.test_session(): + with 
tf.variable_scope(var_name): + tf.contrib.framework.variable('grandchild1', [7]) + tf.contrib.framework.variable('grandchild2', [9]) + with self.assertRaises(ValueError): + tf.contrib.framework.get_unique_variable(var_name) + + def testGetVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.variable('a', [5]) + self.assertEquals([a, b], + tf.contrib.framework.get_variables_to_restore()) + + def testIncludeGetVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.variable('a', [5]) + self.assertEquals([a, b], tf.contrib.framework.get_variables()) + self.assertEquals([a], + tf.contrib.framework.get_variables_to_restore(['A'])) + + def testExcludeGetVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.variable('a', [5]) + self.assertEquals([a, b], tf.contrib.framework.get_variables()) + self.assertEquals([a], + tf.contrib.framework.get_variables_to_restore( + exclude=['B'])) + + def testWrongIncludeGetVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.variable('a', [5]) + self.assertEquals([a, b], tf.contrib.framework.get_variables()) + self.assertEquals([], + tf.contrib.framework.get_variables_to_restore(['a'])) + + def testGetMixedVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + b = tf.contrib.framework.variable('b', [5]) + with tf.variable_scope('B'): + c = tf.contrib.framework.variable('c', [5]) + d = tf.contrib.framework.variable('d', [5]) + self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables()) + self.assertEquals([a, c], + tf.contrib.framework.get_variables_to_restore( + include=['A/a', 'B/c'])) + + def testExcludeGetMixedVariablesToRestore(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + b = tf.contrib.framework.variable('b', [5]) + with tf.variable_scope('B'): + c = tf.contrib.framework.variable('c', [5]) + d = tf.contrib.framework.variable('d', [5]) + self.assertEquals([a, b, c, d], tf.contrib.framework.get_variables()) + self.assertEquals([b, d], + tf.contrib.framework.get_variables_to_restore( + exclude=['A/a', 'B/c'])) + + def testReuseVariable(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', []) + with tf.variable_scope('A', reuse=True): + b = tf.contrib.framework.variable('a', []) + self.assertEquals(a, b) + self.assertListEqual([a], tf.contrib.framework.get_variables()) + + def testVariableWithRegularizer(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [], regularizer=tf.nn.l2_loss) + loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0] + self.assertDeviceEqual(loss.device, a.device) + + def testVariableWithRegularizerColocate(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [], device='gpu:0', + regularizer=tf.nn.l2_loss) + loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0] + 
self.assertDeviceEqual(loss.device, a.device) + + def testVariableWithDevice(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [], device='cpu:0') + b = tf.contrib.framework.variable('b', [], device='cpu:1') + self.assertDeviceEqual(a.device, 'cpu:0') + self.assertDeviceEqual(b.device, 'cpu:1') + + def testVariableWithDeviceFromScope(self): + with self.test_session(): + with tf.device('/cpu:0'): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', [], device='cpu:1') + self.assertDeviceEqual(a.device, 'cpu:0') + self.assertDeviceEqual(b.device, 'cpu:1') + + def testVariableWithDeviceFunction(self): + class DevFn(object): + + def __init__(self): + self.counter = -1 + + def __call__(self, op): + self.counter += 1 + return 'cpu:%d' % self.counter + + with self.test_session(): + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + device=DevFn()): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', []) + c = tf.contrib.framework.variable('c', [], device='cpu:12') + d = tf.contrib.framework.variable('d', []) + with tf.device('cpu:99'): + e_init = tf.constant(12) + e = tf.contrib.framework.variable('e', initializer=e_init) + self.assertDeviceEqual(a.device, 'cpu:0') + self.assertDeviceEqual(a.initial_value.device, 'cpu:0') + self.assertDeviceEqual(b.device, 'cpu:1') + self.assertDeviceEqual(b.initial_value.device, 'cpu:1') + self.assertDeviceEqual(c.device, 'cpu:12') + self.assertDeviceEqual(c.initial_value.device, 'cpu:12') + self.assertDeviceEqual(d.device, 'cpu:2') + self.assertDeviceEqual(d.initial_value.device, 'cpu:2') + self.assertDeviceEqual(e.device, 'cpu:3') + self.assertDeviceEqual(e.initial_value.device, 'cpu:99') + + def testVariableWithReplicaDeviceSetter(self): + with self.test_session(): + with tf.device(tf.train.replica_device_setter(ps_tasks=2)): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', []) + c = tf.contrib.framework.variable('c', [], device='cpu:12') + d = tf.contrib.framework.variable('d', []) + with tf.device('cpu:99'): + e_init = tf.constant(12) + e = tf.contrib.framework.variable('e', initializer=e_init) + # The values below highlight how the replica_device_setter puts initial + # values on the worker job, and how it merges explicit devices. 
+ self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0') + self.assertDeviceEqual(a.initial_value.device, a.device) + self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0') + self.assertDeviceEqual(b.initial_value.device, b.device) + self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12') + self.assertDeviceEqual(c.initial_value.device, c.device) + self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0') + self.assertDeviceEqual(d.initial_value.device, d.device) + self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99') + + def testVariableWithVariableDeviceChooser(self): + + with tf.Graph().as_default(): + device_fn = tf.contrib.framework.VariableDeviceChooser(num_tasks=2) + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + device=device_fn): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', []) + c = tf.contrib.framework.variable('c', [], device='cpu:12') + d = tf.contrib.framework.variable('d', []) + with tf.device('cpu:99'): + e_init = tf.constant(12) + e = tf.contrib.framework.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. + self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0') + self.assertDeviceEqual(a.initial_value.device, a.device) + self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0') + self.assertDeviceEqual(b.initial_value.device, b.device) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertDeviceEqual(c.initial_value.device, c.device) + self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0') + self.assertDeviceEqual(d.initial_value.device, d.device) + self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + + def testVariableGPUPlacement(self): + + with tf.Graph().as_default(): + device_fn = tf.contrib.framework.VariableDeviceChooser(device_type='GPU') + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + device=device_fn): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', []) + c = tf.contrib.framework.variable('c', [], device='cpu:12') + d = tf.contrib.framework.variable('d', []) + with tf.device('cpu:99'): + e_init = tf.constant(12) + e = tf.contrib.framework.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. 
+ self.assertDeviceEqual(a.device, '/gpu:0') + self.assertDeviceEqual(a.initial_value.device, a.device) + self.assertDeviceEqual(b.device, '/gpu:0') + self.assertDeviceEqual(b.initial_value.device, b.device) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertDeviceEqual(c.initial_value.device, c.device) + self.assertDeviceEqual(d.device, '/gpu:0') + self.assertDeviceEqual(d.initial_value.device, d.device) + self.assertDeviceEqual(e.device, '/gpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + + +class ModelVariablesTest(tf.test.TestCase): + + def testNameAndShape(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.model_variable('a', [5]) + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + self.assertListEqual([a], tf.contrib.framework.get_model_variables('A')) + + def testNotInLocalVariables(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.model_variable('a', [5]) + self.assertTrue(a in tf.all_variables()) + self.assertFalse(a in tf.local_variables()) + + def testGetVariablesReturns(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.model_variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.model_variable('a', [5]) + self.assertEquals([a], tf.contrib.framework.get_variables('A')) + self.assertEquals([b], tf.contrib.framework.get_variables('B')) + + def testGetModelVariables(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.model_variable('a', [5]) + with tf.variable_scope('B'): + b = tf.contrib.framework.model_variable('a', [5]) + self.assertEquals([a], tf.contrib.framework.get_model_variables('A')) + self.assertEquals([b], tf.contrib.framework.get_model_variables('B')) + + def testGetLocalVariables(self): + with self.test_session(): + with tf.variable_scope('A'): + _ = tf.contrib.framework.model_variable('a', [5]) + with tf.variable_scope('B'): + _ = tf.contrib.framework.model_variable('a', [5]) + self.assertEquals([], tf.contrib.framework.get_local_variables('A')) + self.assertEquals([], tf.contrib.framework.get_local_variables('B')) + + def testInitializedVariableValue(self): + with self.test_session() as sess: + a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones) + sess.run(tf.initialize_all_variables()) + self.assertAllEqual(a.eval(), [1]*5) + + def testDeviceFn(self): + class DevFn(object): + + def __init__(self): + self.counter = -1 + + def __call__(self, op): + self.counter += 1 + return '/cpu:%d' % self.counter + + with tf.Graph().as_default(): + with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable], + device=DevFn()): + a = tf.contrib.framework.model_variable('a', [5]) + b = tf.contrib.framework.model_variable('b', [20]) + self.assertDeviceEqual(a.device, '/cpu:0') + self.assertDeviceEqual(a.initial_value.device, '/cpu:0') + self.assertDeviceEqual(b.device, '/cpu:1') + self.assertDeviceEqual(b.initial_value.device, '/cpu:1') + + def testVariableWithVariableDeviceChooser(self): + + with tf.Graph().as_default(): + device_fn = tf.contrib.framework.VariableDeviceChooser() + with tf.contrib.framework.arg_scope([tf.contrib.framework.model_variable], + device=device_fn): + a = tf.contrib.framework.model_variable('a', [5]) + b = tf.contrib.framework.model_variable('b', [20]) + self.assertDeviceEqual(a.device, 'cpu:0') + self.assertDeviceEqual(a.initial_value.device, a.device) + 
self.assertDeviceEqual(b.device, 'cpu:0') + self.assertDeviceEqual(b.initial_value.device, b.device) + + +class GetVariablesCollections(tf.test.TestCase): + + def testVariableCollection(self): + with self.test_session(): + a = tf.contrib.framework.variable('a', [], collections='A') + b = tf.contrib.framework.variable('b', [], collections='B') + self.assertEquals(a, tf.get_collection('A')[0]) + self.assertEquals(b, tf.get_collection('B')[0]) + + def testVariableCollections(self): + with self.test_session(): + a = tf.contrib.framework.variable('a', [], collections=['A', 'C']) + b = tf.contrib.framework.variable('b', [], collections=['B', 'C']) + self.assertEquals(a, tf.get_collection('A')[0]) + self.assertEquals(b, tf.get_collection('B')[0]) + self.assertListEqual([a, b], tf.get_collection('C')) + + def testVariableCollectionsWithArgScope(self): + with self.test_session(): + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + collections='A'): + a = tf.contrib.framework.variable('a', []) + b = tf.contrib.framework.variable('b', []) + self.assertListEqual([a, b], tf.get_collection('A')) + + def testVariableCollectionsWithArgScopeNested(self): + with self.test_session(): + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + collections='A'): + a = tf.contrib.framework.variable('a', []) + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + collections='B'): + b = tf.contrib.framework.variable('b', []) + self.assertEquals(a, tf.get_collection('A')[0]) + self.assertEquals(b, tf.get_collection('B')[0]) + + def testVariableCollectionsWithArgScopeNonNested(self): + with self.test_session(): + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + collections='A'): + a = tf.contrib.framework.variable('a', []) + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + collections='B'): + b = tf.contrib.framework.variable('b', []) + tf.contrib.framework.variable('c', []) + self.assertListEqual([a], tf.get_collection('A')) + self.assertListEqual([b], tf.get_collection('B')) + + def testVariableRestoreWithArgScopeNested(self): + with self.test_session(): + a = tf.contrib.framework.variable('a', []) + with tf.contrib.framework.arg_scope([tf.contrib.framework.variable], + trainable=False, + collections=['A', 'B']): + b = tf.contrib.framework.variable('b', []) + c = tf.contrib.framework.variable('c', [], trainable=False) + self.assertEquals([a, c], tf.contrib.framework.get_variables_to_restore()) + self.assertEquals([a], tf.trainable_variables()) + self.assertEquals([b], tf.get_collection('A')) + self.assertEquals([b], tf.get_collection('B')) + + +class GetVariablesBySuffixTest(tf.test.TestCase): + + def testGetVariableGivenNameScoped(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + b = tf.contrib.framework.variable('b', [5]) + self.assertEquals([a], + tf.contrib.framework.get_variables_by_suffix('a')) + self.assertEquals([b], + tf.contrib.framework.get_variables_by_suffix('b')) + + def testGetVariableWithScope(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + fooa = tf.contrib.framework.variable('fooa', [5]) + with tf.variable_scope('B'): + a2 = tf.contrib.framework.variable('a', [5]) + matched_variables = tf.contrib.framework.get_variables_by_suffix('a') + self.assertEquals([a, fooa, a2], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_suffix('/a') + 
self.assertEquals([a, a2], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_suffix( + 'a', scope='A') + self.assertEquals([a, fooa], matched_variables) + + def testGetVariableWithoutScope(self): + with self.test_session(): + a = tf.contrib.framework.variable('a', [5]) + fooa = tf.contrib.framework.variable('fooa', [5]) + b_a = tf.contrib.framework.variable('B/a', [5]) + matched_variables = tf.contrib.framework.get_variables_by_suffix('a') + self.assertEquals([a, fooa, b_a], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_suffix('fooa') + self.assertEquals([fooa], matched_variables) + + +class GetVariablesByNameTest(tf.test.TestCase): + + def testGetVariableGivenNameScoped(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + b = tf.contrib.framework.variable('b', [5]) + self.assertEquals([a], tf.contrib.framework.get_variables_by_name('a')) + self.assertEquals([b], tf.contrib.framework.get_variables_by_name('b')) + + def testGetVariableWithScope(self): + with self.test_session(): + with tf.variable_scope('A'): + a = tf.contrib.framework.variable('a', [5]) + fooa = tf.contrib.framework.variable('fooa', [5]) + with tf.variable_scope('B'): + a2 = tf.contrib.framework.variable('a', [5]) + matched_variables = tf.contrib.framework.get_variables_by_name('a') + self.assertEquals([a, a2], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_name('fooa') + self.assertEquals([fooa], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_name('/a') + self.assertEquals([], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_name('a', + scope='A') + self.assertEquals([a], matched_variables) + + def testGetVariableWithoutScope(self): + with self.test_session(): + a = tf.contrib.framework.variable('a', [5]) + fooa = tf.contrib.framework.variable('fooa', [5]) + b_a = tf.contrib.framework.variable('B/a', [5]) + matched_variables = tf.contrib.framework.get_variables_by_name('a') + self.assertEquals([a, b_a], matched_variables) + matched_variables = tf.contrib.framework.get_variables_by_name('fooa') + self.assertEquals([fooa], matched_variables) -if __name__ == "__main__": +if __name__ == '__main__': tf.test.main()
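Illustrative sketch (not from the original patch): how the VariableDeviceChooser and global-step helpers added above are meant to be combined, following the behavior asserted in the tests. It assumes the 2016-era tf.contrib.framework API from this change; the graph contents are made up for illustration.

import tensorflow as tf

with tf.Graph().as_default():
  # Round-robin placement of variables across two parameter-server tasks,
  # as in testVariableWithVariableDeviceChooser.
  device_fn = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
  with tf.contrib.framework.arg_scope([tf.contrib.framework.variable],
                                      device=device_fn):
    a = tf.contrib.framework.variable('a', [])  # placed on /job:ps/task:0/cpu:0
    b = tf.contrib.framework.variable('b', [])  # placed on /job:ps/task:1/cpu:0

  # Create the scalar int64 global step on first use; a later call finds the
  # existing variable in the GLOBAL_STEP collection instead of creating one.
  step = tf.contrib.framework.get_or_create_global_step()
  step_again = tf.contrib.framework.get_or_create_global_step()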