Skip to content

Commit

Permalink
Added model_variable and helpers to manage variables.
Browse files Browse the repository at this point in the history
Change: 122727793
  • Loading branch information
A. Unique TensorFlower authored and tensorflower-gardener committed May 19, 2016
1 parent 63c29c8 commit 144855b
Show file tree
Hide file tree
Showing 5 changed files with 906 additions and 54 deletions.
31 changes: 29 additions & 2 deletions tensorflow/contrib/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,40 @@
@@assert_same_float_dtype
@@assert_scalar_int
@@convert_to_tensor_or_sparse_tensor
@@local_variable
@@get_graph_from_inputs
@@is_numeric_tensor
@@is_non_decreasing
@@is_strictly_increasing
@@reduce_sum_n
@@safe_embedding_lookup_sparse
@@with_shape
@@with_same_shape
@@get_graph_from_inputs
## Arg_Scope
@@arg_scope
@@add_arg_scope
@@has_arg_scope
@@arg_scoped_arguments
## Variables
@@add_model_variable
@@assert_global_step
@@assert_or_get_global_step
@@create_global_step
@@get_global_step
@@get_or_create_global_step
@@get_local_variables
@@get_model_variables
@@get_unique_variable
@@get_variables_by_name
@@get_variables_by_suffix
@@get_variables_to_restore
@@get_variables
@@local_variable
@@model_variable
@@variable
@@VariableDeviceChooser
"""

from __future__ import absolute_import
Expand Down
30 changes: 7 additions & 23 deletions tensorflow/contrib/framework/python/framework/tensor_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================

"""Tensor utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -27,14 +26,16 @@
from tensorflow.python.ops import variables

__all__ = [
'assert_same_float_dtype', 'assert_scalar_int',
'convert_to_tensor_or_sparse_tensor', 'local_variable', 'reduce_sum_n',
'with_shape', 'with_same_shape',
]
'assert_same_float_dtype',
'assert_scalar_int',
'convert_to_tensor_or_sparse_tensor',
'reduce_sum_n',
'with_shape',
'with_same_shape']


def _assert_same_base_type(items, expected_type=None):
"""Asserts all items are of the same base type.
r"""Asserts all items are of the same base type.
Args:
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
Expand Down Expand Up @@ -110,23 +111,6 @@ def assert_scalar_int(tensor):
return tensor


# TODO(ptucker): Move to tf.variables?
def local_variable(initial_value, validate_shape=True, name=None):
"""Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection.
Args:
initial_value: See variables.Variable.__init__.
validate_shape: See variables.Variable.__init__.
name: See variables.Variable.__init__.
Returns:
New variable.
"""
return variables.Variable(
initial_value, trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
validate_shape=validate_shape, name=name)


def reduce_sum_n(tensors, name=None):
"""Reduce tensors to a scalar sum.
Expand Down
11 changes: 4 additions & 7 deletions tensorflow/contrib/framework/python/ops/arg_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,17 @@
@tf.contrib.add_arg_scope
def conv2d(*args, **kwargs)
@@arg_scope
@@add_arg_scope
@@has_arg_scope
@@arg_scoped_arguments
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools

__all__ = ['arg_scope', 'add_arg_scope',
'has_arg_scope', 'arg_scoped_arguments']
__all__ = ['arg_scope',
'add_arg_scope',
'has_arg_scope',
'arg_scoped_arguments']

_ARGSTACK = [{}]

Expand Down
Loading

0 comments on commit 144855b

Please sign in to comment.