Skip to content

Commit

Permalink
Improve the support for device functions when using `tf.import_graph_def()`.
Browse files Browse the repository at this point in the history

Previously the device function would run before the inputs to an op
had been added. This yielded incorrect results for some device
functions that depend on the inputs to an op. The fix is to split the
creation of the op from applying its device functions (as we do for
shape functions as well).
Change: 114219160
  • Loading branch information
mrry authored and Vijay Vasudevan committed Feb 9, 2016
1 parent fca72d0 commit f4b0df8
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
13 changes: 9 additions & 4 deletions tensorflow/python/framework/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
node.attr[key].CopyFrom(attr_def.default_value)

output_types = _OutputTypes(node, op_dict)
with _MaybeDevice(node.device):
name_to_op[node.name] = g.create_op(
node.op, [], output_types, name=node.name, attrs=node.attr,
compute_shapes=False)
name_to_op[node.name] = g.create_op(
node.op, [], output_types, name=node.name, attrs=node.attr,
compute_shapes=False, compute_device=False)

# 2. Add inputs to the operations.
for node in graph_def.node:
Expand Down Expand Up @@ -313,6 +312,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
# may not be available for this op's inputs.
ops.set_shapes_for_outputs(op)

# Apply device functions for this op.
# NOTE(mrry): We do this after configuring the inputs, because
# the result of the device functions may depend on the inputs.
with _MaybeDevice(node.device):
g._apply_device_functions(op) # pylint: disable=protected-access

# Treat unused input mappings as an error, because they are likely to be
# due to a typo.
unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/python/framework/importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,29 @@ def testWithDevice(self):
self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
self.assertEqual(c.device + '/device:GPU:0', c5.device)

def testWithDeviceFunctionDependingOnInputs(self):
  # Build a source graph: a variable pinned to a parameter server, plus
  # three ops (two assigns and an add) that consume it through
  # reference-typed inputs. Serialize it for re-import below.
  with tf.Graph().as_default() as g:
    with tf.device("/job:ps"):
      v = tf.Variable(1.0)
    unused_assign_op = v.assign(2.0)
    unused_assign_2_op = v.assign(3.0)
    unused_add_t = v + v
  gdef = g.as_graph_def()

  # Device function that records every op carrying at least one
  # ref-dtype input. This only produces a meaningful answer if the
  # importer connects an op's inputs BEFORE running device functions.
  ops_with_two_inputs = []

  def input_counter(op):
    has_ref_input = any(in_t.dtype.is_ref_dtype for in_t in op.inputs)
    if has_ref_input:
      ops_with_two_inputs.append(op)
    return ""

  # Re-import the serialized graph under the recording device function.
  with tf.Graph().as_default() as g:
    with tf.device(input_counter):
      tf.import_graph_def(gdef)

  # We expect to see the initializer, two assign operations, and the add op.
  self.assertEqual(4, len(ops_with_two_inputs))

def testGradient(self):
with tf.Graph().as_default() as g:
inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
Expand Down
25 changes: 16 additions & 9 deletions tensorflow/python/framework/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,7 +1961,7 @@ def _add_function(self, function_def):
# Helper functions to create operations.
def create_op(self, op_type, inputs, dtypes,
input_types=None, name=None, attrs=None, op_def=None,
compute_shapes=True):
compute_shapes=True, compute_device=True):
"""Creates an `Operation` in this graph.
This is a low-level interface for creating an `Operation`. Most
Expand Down Expand Up @@ -1989,6 +1989,8 @@ def create_op(self, op_type, inputs, dtypes,
the operation will have.
compute_shapes: (Optional.) If True, shape inference will be performed
to compute the shapes of the outputs.
compute_device: (Optional.) If True, device functions will be executed
to compute the device property of the Operation.
Raises:
TypeError: if any of the inputs is not a `Tensor`.
Expand Down Expand Up @@ -2037,14 +2039,8 @@ def create_op(self, op_type, inputs, dtypes,
set_shapes_for_outputs(ret)
self._add_op(ret)
self._record_op_seen_by_control_dependencies(ret)
# Apply any device functions in reverse order, so that the most recently
# pushed function has the first chance to apply a device to the op.
# We apply here because the result can depend on the Operation's
# signature, which is computed in the Operation constructor.
for device_function in reversed(self._device_function_stack):
if device_function is None:
break
ret._set_device(device_function(ret))
if compute_device:
self._apply_device_functions(ret)
return ret

def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
Expand Down Expand Up @@ -2551,6 +2547,17 @@ def matmul_on_gpu(n):
finally:
self._device_function_stack.pop()

def _apply_device_functions(self, op):
  """Applies the current device function stack to the given operation."""
  # Walk the stack from the most recently pushed function backwards, so
  # the innermost `device(...)` context gets the first chance to place
  # the op. A `None` entry acts as a barrier (from `device(None)`) that
  # hides every function pushed before it. We apply here, after
  # construction, because the chosen device may depend on the
  # Operation's signature, which is computed in its constructor.
  stack = self._device_function_stack
  index = len(stack) - 1
  while index >= 0:
    device_function = stack[index]
    if device_function is None:
      break
    op._set_device(device_function(op))
    index -= 1

class _ControlDependenciesController(object):
"""Context manager for `control_dependencies()`."""

Expand Down

0 comments on commit f4b0df8

Please sign in to comment.