From 46a3f9443d4c6741177e0ff970c6b60488c9398c Mon Sep 17 00:00:00 2001 From: Abhai Kollara Dilip Date: Thu, 9 Mar 2017 09:28:57 +0530 Subject: [PATCH] Shape inference for tile and gather (#5635) * Added tile shape inference * Added shape inference for gather * Added test for gather * Fixed test_gather * PEP8 fix * Fixed gather test --- keras/backend/theano_backend.py | 21 +++++++++++++++++---- tests/keras/backend/backend_test.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index d0794b9b18f3..5aab86f08db3 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -388,8 +388,11 @@ def gather(reference, indices): Return: a tensor of same type as reference. """ - # TODO: `keras_shape` inference. - return reference[indices] + y = reference[indices] + if hasattr(reference, '_keras_shape') and hasattr(indices, '_keras_shape'): + l = indices._keras_shape[0] + y._keras_shape = (l,) + reference._keras_shape[1:] + return y # ELEMENT-WISE OPERATIONS @@ -778,8 +781,18 @@ def arange(start, stop=None, step=1, dtype='int32'): def tile(x, n): - # TODO: `keras_shape` inference. - return T.tile(x, n) + y = T.tile(x, n) + if hasattr(x, '_keras_shape'): + xshape = np.asarray(x._keras_shape) + n = np.asarray(n) + diff = len(xshape) - len(n) + if diff > 0: + n = np.append([1] * diff, n) + else: + xshape = np.append([1] * -diff, xshape) + y._keras_shape = tuple(xshape * n) + + return y def flatten(x): diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index aab0c9b7ab9e..ecfb96265911 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -167,6 +167,24 @@ def test_tile(self): if hasattr(th_z, '_keras_shape'): assert th_z._keras_shape == th_rep.shape + def test_gather(self): + shape = (10, 2, 3) + ref = np.arange(np.prod(shape)).reshape(shape) + ref_th = KTH.variable(ref) + ref_tf = KTF.variable(ref) + + inds = [1, 3, 7, 9] + inds_th = KTH.variable(inds, dtype='int32') + inds_tf = KTF.variable(inds, dtype='int32') + th_z = KTH.gather(ref_th, inds_th) + th_result = KTH.eval(th_z) + tf_result = KTF.eval(KTF.gather(ref_tf, inds_tf)) + + assert_allclose(tf_result, th_result, atol=1e-05) + + if hasattr(th_z, '_keras_shape'): + assert th_z._keras_shape == th_result.shape + def test_value_manipulation(self): val = np.random.random((4, 2)) xth = KTH.variable(val)