Skip to content

Commit

Permalink
Simplify shape code.
Browse files Browse the repository at this point in the history
Fixes: tensorflow#37640
PiperOrigin-RevId: 307820599
Change-Id: I74e5119798420b428353aeaf3047096507601202
  • Loading branch information
MarkDaoust authored and tensorflower-gardener committed Apr 22, 2020
1 parent b644a64 commit 081c7d5
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 32 deletions.
2 changes: 2 additions & 0 deletions tensorflow/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2477,9 +2477,11 @@ tf_py_test(
main = "framework/sparse_tensor_test.py",
python_version = "PY3",
deps = [
":array_ops",
":framework",
":framework_for_generated_wrappers",
":framework_test_lib",
":math_ops",
":platform_test",
"//tensorflow/core:protos_all_py",
],
Expand Down
35 changes: 3 additions & 32 deletions tensorflow/python/framework/sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,38 +132,9 @@ def __init__(self, indices, values, dense_shape):
# is a VariableOp and updating users of SparseTensor.
values = ops.convert_to_tensor(values, name="values")

# Can't check `if context.executing_eagerly()` here because sparse
# placeholders can still be used in eager context, when building a
# functional model.
if isinstance(indices, ops.EagerTensor):
try:
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape_default = tensor_shape.TensorShape(dense_shape)
except ValueError:
raise ValueError("Unable to create eager SparseTensor. Check that "
"your shape is correctly defined. Eager "
"SparseTensors don't support unknown dimesions.\n"
"got shape:\n {}".format(dense_shape))
else:
if isinstance(dense_shape, ops.Tensor):
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
else:
dense_shape_default = []
for dim in dense_shape:
if isinstance(dim, ops.Tensor):
# There is code passing lists of constant tensors.
dim = tensor_util.constant_value(dim)
if dim == -1:
# -1 may be passed for unknown shapes.
dim = None

dense_shape_default.append(dim)

dense_shape_default = tensor_shape.TensorShape(dense_shape_default)

dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape = ops.convert_to_tensor(
dense_shape, name="dense_shape", dtype=dtypes.int64)
dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)

self._indices = indices
self._values = values
Expand Down
80 changes: 80 additions & 0 deletions tensorflow/python/framework/sparse_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest

Expand Down Expand Up @@ -124,6 +126,84 @@ def test_convert_sparse(self):
sparse_tensor_value.dense_shape, convertee.dense_shape)


class SparseTensorShapeTest(test_util.TensorFlowTestCase):

def test_simple(self):
indices = [[0, 2]]
values = [1]
dense_shape = [5, 5]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)

self.assertIsInstance(sp.shape, tensor_shape.TensorShape)
self.assertIsInstance(sp.dense_shape, ops.Tensor)
self.assertEqual(sp.shape.as_list(), [5, 5])

def test_unknown_shape(self):

@def_function.function
def my_func(dense_shape):
indices = [[0, 2]]
values = [1]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, None])
return sp

my_func.get_concrete_function(
dense_shape=tensor_spec.TensorSpec(
dtype=dtypes.int64, shape=[2,]))

def test_partial_shape(self):

@def_function.function
def my_func(x):
indices = [[0, 2]]
values = [1]
y = ops.convert_to_tensor(3, dtype=dtypes.int64)
dense_shape = [x, y]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, 3])
return sp

my_func.get_concrete_function(
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[]))

def test_neg_shape(self):
indices = [[0, 2]]
values = [1]
dense_shape = [-1, 5]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, 5])

def test_unknown_tensor_shape(self):

@def_function.function
def my_func(x):
indices = [[0, 0]]
values = [1]
dense_shape = array_ops.shape(x)
dense_shape = math_ops.cast(dense_shape, dtypes.int64)

sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.as_list(), [None, None])
return sp

my_func.get_concrete_function(
x=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None, None]))

def test_unknown_rank(self):

@def_function.function
def my_func(dense_shape):
indices = [[0, 0]]
values = [1]
sp = sparse_tensor.SparseTensor(indices, values, dense_shape)
self.assertEqual(sp.shape.rank, None)
return sp

my_func.get_concrete_function(
dense_shape=tensor_spec.TensorSpec(dtype=dtypes.int64, shape=[None]))


@test_util.run_all_in_graph_and_eager_modes
class SparseTensorSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
Expand Down

0 comments on commit 081c7d5

Please sign in to comment.