Skip to content

Commit

Permalink
Use convert_to_tensor to do type/shape check instead of make_tensor_p…
Browse files Browse the repository at this point in the history
…roto,

as convert_to_tensor handles a Tensor in addition to other Python values.
Change: 119078171
  • Loading branch information
keveman authored and tensorflower-gardener committed Apr 5, 2016
1 parent e2c76ad commit 84476b2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
9 changes: 7 additions & 2 deletions tensorflow/python/kernel_tests/concat_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,23 @@ def testRandom(self):
def testInvalidConcatDimTypeAndShape(self):
a = tf.Variable(tf.constant(1.0, shape=[1]))
b = tf.Variable(tf.constant(2.0, shape=[1]))
with self.assertRaises(TypeError):
with self.assertRaises(ValueError):
tf.concat(a, b)
with self.assertRaises(TypeError):
tf.concat(4.2, 1)
with self.assertRaises(TypeError):
with self.assertRaises(ValueError):
tf.concat(a, 1)
with self.assertRaises(TypeError):
tf.concat(a, [a, b])
with self.assertRaises(ValueError):
tf.concat([3], [a, b])
with self.assertRaises(ValueError):
tf.concat(0, [])
# An integer tensor for shape dim should throw no error.
tf.concat(tf.constant(0, shape=[]), 1)
# A non-scalar tensor for shape should throw ValueError.
with self.assertRaises(ValueError):
tf.concat(tf.constant(0, shape=[1]), 1)

def _testGradientsSimple(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
Expand Down
16 changes: 10 additions & 6 deletions tensorflow/python/ops/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,16 @@ def concat(concat_dim, values, name="concat"):
values = [values]
# TODO(mrry): Change to return values?
if len(values) == 1: # Degenerate case of one tensor.
# Make a throwaway call to make_tensor_proto to make sure
# that concat_dim is of the correct type.
# TODO(keveman): Extract the type and shape checks out of make_tensor_proto
# in to a standalone function.
tensor_util.make_tensor_proto(concat_dim, dtype=dtypes.int32, shape=[])
return identity(values[0], name=name)
# Make a throwaway call to convert_to_tensor to make sure
# that concat_dim is of the correct type, and make sure that
# the returned tensor is a scalar.
# TODO(keveman): Implement a standalone type and shape checker.
with ops.name_scope(name) as scope:
ops.convert_to_tensor(concat_dim,
name="concat_dim",
dtype=dtypes.int32).get_shape(
).assert_is_compatible_with(tensor_shape.scalar())
return identity(values[0], name=scope)
return gen_array_ops._concat(concat_dim=concat_dim,
values=values,
name=name)
Expand Down

0 comments on commit 84476b2

Please sign in to comment.