Skip to content

Commit

Permalink
String equality op.
Browse files Browse the repository at this point in the history
Change: 113410311
  • Loading branch information
ebrevdo authored and Vijay Vasudevan committed Jan 30, 2016
1 parent b672b36 commit b2a64c3
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_equal_to.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
REGISTER8(BinaryOp, CPU, "Equal", functor::equal_to, float, double, uint8, int8,
int16, int32, int64, complex64);
REGISTER9(BinaryOp, CPU, "Equal", functor::equal_to, float, double, uint8, int8,
int16, int32, int64, complex64, string);
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Equal", functor::equal_to, float, double, uint8, int8,
int16, int64);
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/cwise_op_not_equal_to.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
REGISTER8(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, double,
uint8, int8, int16, int32, int64, complex64);
REGISTER9(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, double,
uint8, int8, int16, int32, int64, complex64, string);
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, double,
uint8, int8, int16, int64);
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/cwise_ops_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ struct SelectFunctor<CPUDevice, T> {
REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
REGISTER(OP, D, N, F, T0)
#else // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
Expand All @@ -406,6 +408,9 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
REGISTER4(OP, D, N, F, T5, T6, T7, T8)
#endif // defined(__ANDROID_TYPES_SLIM__)

} // end namespace tensorflow
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/ops/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ Returns the truth value of (x >= y) element-wise.
#define COMPARISON() \
Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
"T: {float, double, uint8, int8, int16, int32, int64, complex64, " \
"quint8, qint8, qint32}")
"quint8, qint8, qint32, string}")

REGISTER_OP("Equal")
.COMPARISON()
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/ops/ops.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -2568,6 +2568,7 @@ op {
type: DT_QUINT8
type: DT_QINT8
type: DT_QINT32
type: DT_STRING
}
}
}
Expand Down Expand Up @@ -4938,6 +4939,7 @@ op {
type: DT_QUINT8
type: DT_QINT8
type: DT_QINT32
type: DT_STRING
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/python/kernel_tests/cwise_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,16 @@ def testComplex64Basic(self):
self._compareCpu(x, y, np.multiply, _MUL)
self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV)

def testStringComparison(self):
x = np.array([["abc", "bh"], ["c", ""]])
y = np.array([["abc", "bh"], ["def", "hi"]])
with self.test_session(use_gpu=False) as sess:
cmp_eq = tf.equal(x, y)
cmp_not_eq = tf.not_equal(x, y)
values = sess.run([cmp_eq, cmp_not_eq])
self.assertAllEqual([[True, True], [False, False]], values[0])
self.assertAllEqual([[False, False], [True, True]], values[1])

def testString(self):
x = np.array([["x_0_0", "x_0_1", "x_0_2"],
["x_1_0", "x_1_1", "x_1_2"],
Expand Down

0 comments on commit b2a64c3

Please sign in to comment.