Skip to content

Commit

Permalink
Avoid extra zero padding of strings in py_func
Browse files Browse the repository at this point in the history
Change: 127223683
girving authored and tensorflower-gardener committed Jul 12, 2016
1 parent 5a3a9b1 commit 7c1ef65
Showing 2 changed files with 28 additions and 3 deletions.
6 changes: 6 additions & 0 deletions tensorflow/python/kernel_tests/py_func_test.py
Original file line number Diff line number Diff line change
@@ -104,6 +104,12 @@ def read_and_return_strings(x, y):
z, = tf.py_func(read_and_return_strings, [x, y], [tf.string])
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])

def testStringPadding(self):
  """Checks that py_func hands back unequal-length strings unchanged."""
  expected = [b"this", b"is", b"a", b"test"]
  with self.test_session():
    # The wrapped function returns a one-element list, matching the single
    # tf.string output requested from py_func.
    outputs = tf.py_func(lambda: [expected], [], [tf.string])
    (strings,) = outputs
    self.assertAllEqual(strings.eval(), expected)

def testLarge(self):
with self.test_session() as sess:
x = tf.zeros([1000000], dtype=np.float32)
25 changes: 22 additions & 3 deletions tensorflow/python/ops/script_ops.py
Original file line number Diff line number Diff line change
@@ -53,6 +53,26 @@ def remove(self, token):
"""Removes the registered function corresponding to `token`."""
self._funcs.pop(token, None)

@staticmethod
def _convert(value):
"""Converts an arg to numpy, avoiding dangerous string and unicode dtypes.
Numpy pads with zeros when using string and unicode dtypes if different
components of a tensor have different lengths. This is bad: ignoring the
padding is wrong for text data, and removing the padding is wrong for binary
data. To avoid this bug, we redo the conversion using an object dtype.
Args:
value: Value to convert to a numpy array.
Returns:
A numpy array.
"""
result = np.asarray(value, order="C")
if result.dtype.char in "SU" and result is not value:
return np.asarray(value, order="C", dtype=object)
return result

def __call__(self, token, args):
"""Calls the registered function for `token` with args."""
func = self._funcs[token]
@@ -62,10 +82,9 @@ def __call__(self, token, args):
# Ensures that we return either a single numpy array or a list of numpy
# arrays.
if isinstance(ret, (tuple, list)):
ret = [np.array(x, order="C") for x in ret]
return [self._convert(x) for x in ret]
else:
ret = np.array(ret, order="C")
return ret
return self._convert(ret)

def size(self):
"""Returns how many functions are currently registered."""

0 comments on commit 7c1ef65

Please sign in to comment.