Skip to content

Commit

Permalink
Relax a somewhat too conservative assertion in WhileV2
Browse files Browse the repository at this point in the history
pfor adds some identities / launders resources through function calls; unless we want to trace through and verify that they're not changing the resource, it's probably better to just trust them.

PiperOrigin-RevId: 380628480
Change-Id: I57d5baeb55f8aac1125bad8f4bd2f7675a47886a
  • Loading branch information
allenlavoie authored and tensorflower-gardener committed Jun 21, 2021
1 parent 7959338 commit 3a132c4
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
30 changes: 21 additions & 9 deletions tensorflow/python/ops/control_flow_util_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def resource_input_index(tensor_name, input_names, node_defs, functions):
output_idx = int(output_idx)
node_def = node_defs[op_name]

def _extract_input_index(function_attribute_name):
  """Maps `output_idx` through a function-valued attr back to an input index.

  Looks up the FunctionDef named by the given func-typed attribute of
  `node_def` (e.g. "f" on a call op, "then_branch"/"else_branch" on a cond),
  finds the tensor that function returns at position `output_idx`, and
  recursively calls `resource_input_index` inside that FunctionDef to
  discover which of the function's inputs the resource was forwarded from.

  Args:
    function_attribute_name: Name of a func-typed attr on the enclosing
      `node_def`.

  Returns:
    Index into the referenced function's input arguments identifying the
    forwarded resource.
  """
  func_name = node_def.attr[function_attribute_name].func.name
  fdef = functions[func_name].definition
  # Resolve the name of the internal tensor produced for output `output_idx`.
  output_arg_name = fdef.signature.output_arg[output_idx].name
  output_tensor_name = fdef.ret[output_arg_name]
  # Recurse into the nested FunctionDef with its own input names and nodes.
  return resource_input_index(
      output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
      {ndef.name: ndef for ndef in fdef.node_def}, functions)

if node_def.op in ("Identity", "While"):
# Captured resources occur at the same index in the lists of inputs and
# outputs of a while or identity op. So we lookup the input of `tensor.op`
Expand All @@ -199,21 +208,24 @@ def resource_input_index(tensor_name, input_names, node_defs, functions):
# gradients. `tensor_name` is one of these outputs from a nested
# function call, so recursively find the corresponding input in the
# nested FunctionDef.
func_name = node_def.attr["f"].func.name
fdef = functions[func_name].definition
output_arg_name = fdef.signature.output_arg[output_idx].name
output_tensor_name = fdef.ret[output_arg_name]
input_index = resource_input_index(
output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
{ndef.name: ndef for ndef in fdef.node_def}, functions)
tensor_name = node_def.input[input_index]
tensor_name = node_def.input[_extract_input_index("f")]
elif node_def.op in ("If", "StatelessIf"):
input_index = _extract_input_index("then_branch")
if input_index != _extract_input_index("else_branch"):
raise AssertionError(
("Expected cond branches ({} op) to each have the same "
"input->output mapping of resources.").format(node_def.op))
tensor_name = node_def.input[
# Ignore the `cond` input; the function inputs come after.
input_index + 1]
else:
# We assume there are no other ops types that will "forward" resource
# handles like this, so all other handles must have been created by the
# op. (Note that cond_v2 wraps resource handle outputs in optionals,
# which we'll end up accumulating).
raise ValueError("Taking gradient of a while loop which creates "
"a resource in its body is not supported: %s" % op_name)
"a resource in its body is not supported: %s (%s)"
% (op_name, node_def.op))

return input_names.index(tensor_name)

Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/ops/parallel_for/control_flow_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,20 @@ def log_prob(x):
v_log_prob, (x,), delta=1e-3)
self.assertAllClose(theoretical, numerical, rtol=1e-2)

def test_scan_captured_variable(self):
  """Differentiates a pfor over a scan that reads a captured variable."""
  if not context.executing_eagerly():
    self.skipTest("Test only written for 2.x")
  var = variables.Variable(math_ops.range(10, dtype=dtypes.float32))

  def body(unused_idx):
    # Each pfor iteration scans over every index of `var`, gathering one
    # element per scan step.
    indices = math_ops.range(var.shape[0])
    return functional_ops.scan_v2(
        lambda _, i: array_ops.gather(var, i),
        elems=indices,
        initializer=0.0)

  with backprop.GradientTape() as tape:
    out = pfor_control_flow_ops.pfor(body, 2)
  # Two pfor iterations each touch every element once, so d(out)/d(var) is 2.
  gradient = tape.gradient(out, var)
  self.assertAllClose([2.] * 10, gradient)


@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):
Expand Down
21 changes: 16 additions & 5 deletions tensorflow/python/ops/while_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,19 +1239,30 @@ def _resource_capture_helper(self, tensor):
"""
assert tensor.dtype == dtypes.resource

forward_graph_input_names = [t.name for t in self._forward_graph.inputs]
forward_graph_name_to_opdef = {
op.name: op.node_def for op in self._forward_graph.get_operations()}
index = util.resource_input_index(
tensor.name, [t.name for t in self._forward_graph.inputs],
{op.name: op.node_def for op in self._forward_graph.get_operations()},
tensor.name, forward_graph_input_names,
forward_graph_name_to_opdef,
self._forward_graph._functions)

input_placeholder = self._forward_graph.inputs[index]
tensor_in_outer_graph = self._forward_graph._while.inputs[index]

assert input_placeholder.dtype == dtypes.resource
assert tensor_in_outer_graph.dtype == dtypes.resource
# This must be a loop invariant.
assert input_placeholder is self._forward_graph.outputs[index], (
"Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
# This must be a loop invariant. However, infrastructure
# (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds,
# etc. which take and return the resource tensor unmodified; this means that
# the Python objects may differ.
if index != util.resource_input_index(
self._forward_graph.outputs[index].name, forward_graph_input_names,
forward_graph_name_to_opdef,
self._forward_graph._functions):
raise AssertionError(
"Resource tensors must be loop invariants %s."
% tensor_in_outer_graph)

self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
tensor_in_outer_graph)
Expand Down

0 comments on commit 3a132c4

Please sign in to comment.