Skip to content

Commit

Permalink
Override "/" and "*" for "sparse <op> dense" in Python.
Browse files Browse the repository at this point in the history
Usage in Python:

    > result_sparse_t = sparse_t / dense_t  # cwise div
    > result_sparse_t = sparse_t * dense_t  # cwise mul

These are counterparts of "tf.(true_)div()" and "tf.mul()", respectively.
Change: 121860748
  • Loading branch information
concretevitamin authored and tensorflower-gardener committed May 9, 2016
1 parent 374e673 commit 32ecd83
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 30 deletions.
49 changes: 29 additions & 20 deletions tensorflow/python/framework/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@
from tensorflow.python.util import compat


def _override_helper(clazz_object, operator, func):
"""Overrides (string) operator on Tensors to call func.
Args:
clazz_object: the class to override for; either Tensor or SparseTensor.
operator: the string name of the operator to override.
func: the function that replaces the overriden operator.
Raises:
ValueError: If operator has already been overwritten,
or if operator is not allowed to be overwritten.
"""
existing = getattr(clazz_object, operator, None)
if existing is not None:
# Check to see if this is a default method-wrapper or slot wrapper which
# will be true for the comparison operators.
if not isinstance(existing, type(object.__lt__)):
raise ValueError("operator %s cannot be overwritten again on class %s." %
(operator, clazz_object))
if operator not in Tensor.OVERLOADABLE_OPERATORS:
raise ValueError("Overriding %s is disallowed" % operator)
setattr(clazz_object, operator, func)


def _convert_stack(stack):
"""Converts a stack extracted using _extract_stack() to a traceback stack.
Expand Down Expand Up @@ -408,25 +432,7 @@ def __eq__(self, other):

@staticmethod
def _override_operator(operator, func):
"""Overrides (string) operator on Tensors to call func.
Args:
operator: the string name of the operator to override.
func: the function that replaces the overriden operator.
Raises:
ValueError: If operator has already been overwritten,
or if operator is not allowed to be overwritten.
"""
existing = getattr(Tensor, operator, None)
if existing is not None:
# Check to see if this is a default method-wrapper or slot wrapper which
# will be true for the comparison operators.
if not isinstance(existing, type(object.__lt__)):
raise ValueError("operator %s cannot be overwritten again." % operator)
if operator not in Tensor.OVERLOADABLE_OPERATORS:
raise ValueError("Overriding %s is disallowed" % operator)
setattr(Tensor, operator, func)
_override_helper(Tensor, operator, func)

def __iter__(self):
"""Dummy method to prevent iteration. Do not call.
Expand Down Expand Up @@ -982,12 +988,15 @@ def eval(self, feed_dict=None, session=None):
Returns:
A `SparseTensorValue` object.
"""
indices, values, shape = _eval_using_default_session(
[self.indices, self.values, self.shape], feed_dict, self.graph, session)
return SparseTensorValue(indices, values, shape)

@staticmethod
def _override_operator(operator, func):
_override_helper(SparseTensor, operator, func)


SparseTensorValue = collections.namedtuple("SparseTensorValue",
["indices", "values", "shape"])
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/python/kernel_tests/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,5 +388,40 @@ def testGradient(self):
self.assertLess(err, 1e-3)


class SparseMathOpsTest(test_util.TensorFlowTestCase):

def _check(self, result_tensor, result_np, input_sp_t):
self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
self.assertAllEqual(input_sp_t.shape.eval(), result_tensor.shape.eval())

res_densified = sparse_ops.sparse_to_dense(result_tensor.indices,
result_tensor.shape,
result_tensor.values).eval()
self.assertAllEqual(res_densified, result_np)

def testCwiseDivAndMul(self):
np.random.seed(1618)
sp_shapes = [(10, 10, 10), (5, 5), (1618,)]
dense_shapes = [(10, 10, 1), (5, 5), (1,)]

with self.test_session(use_gpu=False):
for dtype in [np.float32, np.float64, np.int32, np.int64]:
for sp_shape, dense_shape in zip(sp_shapes, dense_shapes):
sp_vals_np = np.random.rand(*sp_shape).astype(dtype) + 1
dense_vals_np = np.random.rand(*dense_shape).astype(dtype) + 1
sp_t, unused_nnz = _sparsify(sp_vals_np)
sp_t_densified = sparse_ops.sparse_tensor_to_dense(sp_t).eval()
dense_t = tf.constant(dense_vals_np)

self._check(sp_t / dense_t, sp_t_densified / dense_vals_np, sp_t)
# Check commutative.
self._check(sp_t * dense_t, sp_t_densified * dense_vals_np, sp_t)
self._check(dense_t * sp_t, sp_t_densified * dense_vals_np, sp_t)

if dtype in [np.int32, np.int64]:
res = sp_t / dense_t # should invoke "__truediv__"
self.assertEqual(res.values.eval().dtype, np.float64)


if __name__ == "__main__":
googletest.main()
4 changes: 4 additions & 0 deletions tensorflow/python/ops/math_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,7 @@ def _CrossGrad(op, grad):
u = op.inputs[0]
v = op.inputs[1]
return (math_ops.cross(v, grad), math_ops.cross(grad, u))


ops.NoGradient("SparseDenseCwiseMul")
ops.NoGradient("SparseDenseCwiseDiv")
87 changes: 77 additions & 10 deletions tensorflow/python/ops/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import state_ops
# go/tf-wildcard-import
Expand Down Expand Up @@ -505,31 +506,44 @@ def to_bfloat16(x, name="ToBFloat16"):
ops.Tensor._override_operator("__invert__", gen_math_ops.logical_not)


def _OverrideBinaryOperatorHelper(func, op_name):
def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
"""Register operators with different tensor and scalar versions.
If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices,
sp_values, sp_shape, dense)` and outputs `(new_sp_values)`.
Args:
func: the operator
op_name: name of the operator being overridden
clazz_object: class to override for. Either `Tensor` or `SparseTensor`.
"""

def binary_op_wrapper(x, y):
with ops.op_scope([x, y], None, op_name) as name:
assert isinstance(x, ops.Tensor)
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
if not isinstance(y, ops.SparseTensor):
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
return func(x, y, name=name)

ops.Tensor._override_operator("__%s__" % op_name, binary_op_wrapper)
del binary_op_wrapper
def binary_op_wrapper_sparse(sp_x, y):
with ops.op_scope([sp_x, y], None, op_name) as name:
y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
return ops.SparseTensor(sp_x.indices, func(sp_x.indices, sp_x.values,
sp_x.shape, y, name=name),
sp_x.shape)

def r_binary_op_wrapper(y, x):
with ops.op_scope([x, y], None, op_name) as name:
assert isinstance(y, ops.Tensor)
x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
return func(x, y, name=name)

ops.Tensor._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
del r_binary_op_wrapper
if clazz_object is ops.Tensor:
clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper)
del binary_op_wrapper
clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
del r_binary_op_wrapper
else:
clazz_object._override_operator("__%s__" % op_name,
binary_op_wrapper_sparse)
del binary_op_wrapper_sparse


# Conversion table for __truediv__. None entries mean no conversion required.
Expand All @@ -546,6 +560,31 @@ def r_binary_op_wrapper(y, x):
}


# NOTE: the support of "sparse (true)div dense" is currently not baked in into
# "tf.(true_)div()". Until such an API decision is made, the supported usage is
# to explicitly use the "/" operator to invoke either truediv or div.
def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
"""Internal helper function for 'sp_t / dense_t'."""
with ops.op_scope([sp_indices, sp_values, sp_shape, y],
name, "truediv") as name:
sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
y = ops.convert_to_tensor(y, name="y")
x_dtype = sp_values.dtype.base_dtype
y_dtype = y.dtype.base_dtype
if x_dtype != y_dtype:
raise TypeError("x and y must have the same dtype, got %r != %r" %
(x_dtype, y_dtype))
try:
dtype = _TRUEDIV_TABLE[x_dtype]
except KeyError:
raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
if dtype is not None:
sp_values = cast(sp_values, dtype)
y = cast(y, dtype)
return gen_sparse_ops.sparse_dense_cwise_div(sp_indices, sp_values,
sp_shape, y, name=name)


def truediv(x, y, name=None):
"""Divides x / y elementwise, always producing floating point results.
Expand Down Expand Up @@ -626,9 +665,29 @@ def floordiv(x, y, name=None):
return gen_math_ops.div(x, y, name=name)


def _mul_dispatch(x, y, name=None):
"""Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
is_tensor_y = isinstance(y, ops.Tensor)
if is_tensor_y:
return gen_math_ops.mul(x, y, name=name)
else:
assert isinstance(y, ops.SparseTensor) # Case: Dense * Sparse.
new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
y.shape, x, name)
return ops.SparseTensor(y.indices, new_vals, y.shape)


_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
ops.SparseTensor)
_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
ops.SparseTensor)
_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
ops.SparseTensor)


_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
_OverrideBinaryOperatorHelper(gen_math_ops.mul, "mul")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
_OverrideBinaryOperatorHelper(gen_math_ops.div, "div")
_OverrideBinaryOperatorHelper(truediv, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
Expand Down Expand Up @@ -1368,6 +1427,14 @@ def _BroadcastShape(op):
return [tensor_shape.TensorShape(return_dims)]


@ops.RegisterShape("SparseDenseCwiseMul")
@ops.RegisterShape("SparseDenseCwiseDiv")
def _SparseDenseBinaryOpShape(op): # pylint: disable=invalid-name
"""Common shape for 'sparse <binary cwise op> dense -> sparse' operators."""
nnz = op.inputs[1].get_shape()[0]
return [tensor_shape.TensorShape(nnz)]


@ops.RegisterShape("AddN")
def _AddNShape(op):
merged_shape = tensor_shape.unknown_shape()
Expand Down

0 comments on commit 32ecd83

Please sign in to comment.