Skip to content

Commit

Permalink
Raise a better error message when a list element is not convertible t…
Browse files Browse the repository at this point in the history
…o Tensor.

Previously, we hit an erroneous assertion when converting a list
argument to a list of tensors. This change makes it clearer what
caused the error when one or more of the arguments is an object that
is not convertible to `tf.Tensor`.

Fixes tensorflow#2385.
Change: 122600354
  • Loading branch information
mrry authored and tensorflower-gardener committed May 18, 2016
1 parent 34f1d68 commit dd3b812
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
35 changes: 23 additions & 12 deletions tensorflow/python/ops/op_def_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,25 +408,36 @@ def apply_op(self, op_type_name, name=None, **keywords):
values = ops.convert_n_to_tensor(
values, name=input_arg.name, dtype=dtype if dtype else None,
as_ref=input_arg.is_ref)
if input_arg.number_attr and len(
set(v.dtype.base_dtype for v in values)) > 1:
raise TypeError() # All types should match.
except (TypeError, ValueError):
assert dtype is not None, "Should not fail if dtype is None"
assert input_arg.number_attr, "Should be number_attr case"
# What types does the conversion function think values have?
values = ops.convert_n_to_tensor(values, as_ref=input_arg.is_ref)
observed = ", ".join(v.dtype.base_dtype.name for v in values)
observed_types = []
for value in values:
try:
converted_value = ops.convert_to_tensor(
value, as_ref=input_arg.is_ref)
observed_types.append(converted_value.dtype.base_dtype.name)
except (TypeError, ValueError):
observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
observed = ", ".join(observed_types)

prefix = (
"Tensors in list passed to '%s' of '%s' Op have types [%s]" %
(input_name, op_type_name, observed))
if input_arg.type != types_pb2.DT_INVALID:
raise TypeError("%s that do not match expected type %s." %
(prefix, dtype.name))
elif input_arg.type_attr in attrs:
raise TypeError("%s that do not match type %s inferred from "
"earlier arguments." %
(prefix, dtype.name))
if input_arg.number_attr:
if input_arg.type != types_pb2.DT_INVALID:
raise TypeError("%s that do not match expected type %s." %
(prefix, dtype.name))
elif input_arg.type_attr in attrs:
raise TypeError("%s that do not match type %s inferred from "
"earlier arguments." %
(prefix, dtype.name))
else:
raise TypeError("%s that don't all match." % prefix)
else:
raise TypeError("%s that don't all match." % prefix)
raise TypeError("%s that are invalid." % prefix)

types = [x.dtype for x in values]
inputs.extend(values)
Expand Down
27 changes: 25 additions & 2 deletions tensorflow/python/ops/op_def_library_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ def testTypeList(self):
"Expected list for 'a' "
"argument to 'TypeList' Op, not ")

with self.assertRaises(TypeError) as cm:
self._lib.apply_op("TypeList", a=[self.Tensor(dtypes.int32), None])
self.assertStartsWith(str(cm.exception),
"Tensors in list passed to 'a' of 'TypeList' Op "
"have types [int32, <NOT CONVERTIBLE TO TENSOR>]")

def testTypeListTwice(self):
self._add_op("name: 'TypeListTwice' "
"input_arg { name: 'a' type_list_attr: 'T' } "
Expand Down Expand Up @@ -957,6 +963,16 @@ def testNPolymorphicIn(self):
attr { key: 'N' value { i: 2 } }
""", op.node_def)

op = self._lib.apply_op("NPolymorphicIn",
a=[self.Tensor(dtypes.float32, name="y"),
self.Tensor(dtypes.float32_ref, name="z")],
name="r")
self.assertProtoEquals("""
name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'N' value { i: 2 } }
""", op.node_def)

with self.assertRaises(ValueError) as cm:
self._lib.apply_op("NPolymorphicIn", a=[99])
self.assertEqual(str(cm.exception),
Expand All @@ -966,8 +982,8 @@ def testNPolymorphicIn(self):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
self.assertEqual(str(cm.exception),
"All tensors passed to 'a' of 'NPolymorphicIn' "
"Op must have the same type.")
"Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
"have types [int32, string] that don't all match.")

with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn",
Expand All @@ -976,6 +992,13 @@ def testNPolymorphicIn(self):
"Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
"have types [int32, string] that don't all match.")

with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn", a=[38, None])
self.assertEqual(str(cm.exception),
"Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
"have types [int32, <NOT CONVERTIBLE TO TENSOR>] that "
"don't all match.")

with self.assertRaises(TypeError) as cm:
self._lib.apply_op("NPolymorphicIn",
a=["abcd", self.Tensor(dtypes.int32)])
Expand Down

0 comments on commit dd3b812

Please sign in to comment.