Skip to content

Commit

Permalink
Shape inference for Theano backend (keras-team#5618)
Browse files Browse the repository at this point in the history
* Added shape inference for flatten

* Added shape inference for repeat_elements

* Added shape inference + minute optimisation for batch_flatten

* Added shape inference for squeeze

* Added shape inference for permute_dimensions

* Added shape inference for repeat

* Added shape inference for expand_dims

* Added shape inference for spatial_2d_padding

* Removind #TODO tag + whitespace fix

* Added shape inference for transpose

* Added shape inference for batch_dot

* Minor fix in batch_dot shape inference

* PEP8 Fix

* Shape inference tests to check_single_tensor_operation and check_double_tensor_operation

* PEP8 Fix + Added test for tile

* Fixed flatten shape inference

* Fixed squeeze shape inference

* Added batch_flatten to test_shape_operations
abhaikollara authored and fchollet committed Mar 7, 2017
1 parent 3b66014 commit e99eac2
Showing 2 changed files with 74 additions and 28 deletions.
80 changes: 57 additions & 23 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
@@ -352,7 +352,6 @@ def batch_dot(x, y, axes=None):
output_shape = (100, 30)
"""
# TODO: `keras_shape` inference.
if isinstance(axes, int):
axes = (axes, axes)
if axes is None:
@@ -361,12 +360,26 @@ def batch_dot(x, y, axes=None):
out = T.batched_tensordot(x, y, axes=axes)
if ndim(out) == 1:
out = expand_dims(out, 1)

if hasattr(x, '_keras_shape') and hasattr(y, '_keras_shape'):
shape = []
for axis in range(len(x._keras_shape)):
if axis != axes[0]:
shape.append(x._keras_shape[axis])
for axis in range(1, len(y._keras_shape)):
if axis != axes[1]:
shape.append(y._keras_shape[axis])
if len(shape) == 1:
shape.append(1) # Expand dims if ndim == 1
out._keras_shape = tuple(shape)
return out


def transpose(x):
# TODO: `keras_shape` inference.
return T.transpose(x)
y = T.transpose(x)
if hasattr(x, '_keras_shape'):
y._keras_shape = tuple(reversed(x._keras_shape))
return y


def gather(reference, indices):
@@ -673,9 +686,11 @@ def permute_dimensions(x, pattern):
pattern should be a tuple or list of
dimension indices, e.g. [0, 2, 1].
"""
# TODO: `keras_shape` inference.
pattern = tuple(pattern)
return x.dimshuffle(pattern)
y = x.dimshuffle(pattern)
if hasattr(x, '_keras_shape'):
y._keras_shape = tuple(np.asarray(x._keras_shape)[list(pattern)])
return y


def repeat_elements(x, rep, axis):
@@ -684,8 +699,12 @@ def repeat_elements(x, rep, axis):
If x has shape (s1, s2, s3) and axis=1, the output
will have shape (s1, s2 * rep, s3).
"""
# TODO: `keras_shape` inference.
return T.repeat(x, rep, axis=axis)
y = T.repeat(x, rep, axis=axis)
if hasattr(x, '_keras_shape'):
y._keras_shape = list(x._keras_shape)
y._keras_shape[axis] = x._keras_shape[axis] * rep
y._keras_shape = tuple(y._keras_shape)
return y


def resize_images(X, height_factor, width_factor, data_format):
@@ -695,7 +714,6 @@ def resize_images(X, height_factor, width_factor, data_format):
by a factor of (height_factor, width_factor). Both factors should be
positive integers.
"""
# TODO: `keras_shape` inference.
if data_format == 'channels_first':
output = repeat_elements(X, height_factor, axis=2)
output = repeat_elements(output, width_factor, axis=3)
@@ -715,7 +733,6 @@ def resize_volumes(X, depth_factor, height_factor, width_factor, data_format):
by a factor of (depth_factor, height_factor, width_factor).
Both factors should be positive integers.
"""
# TODO: `keras_shape` inference.
if data_format == 'channels_first':
output = repeat_elements(X, depth_factor, axis=2)
output = repeat_elements(output, height_factor, axis=3)
@@ -736,10 +753,15 @@ def repeat(x, n):
If x has shape (samples, dim) and n=2,
the output will have shape (samples, 2, dim).
"""
# TODO: `keras_shape` inference.
assert x.ndim == 2
x = x.dimshuffle((0, 'x', 1))
return T.extra_ops.repeat(x, n, axis=1)
y = x.dimshuffle((0, 'x', 1))
y = T.extra_ops.repeat(y, n, axis=1)
if hasattr(x, '_keras_shape'):
shape = list(x._keras_shape)
shape.insert(1, n)
y._keras_shape = tuple(shape)

return y


def arange(start, stop=None, step=1, dtype='int32'):
@@ -761,40 +783,51 @@ def tile(x, n):


def flatten(x):
# TODO: `keras_shape` inference.
return T.flatten(x)
y = T.flatten(x)
if hasattr(x, '_keras_shape'):
y._keras_shape = (np.prod(x._keras_shape), )
return y


def batch_flatten(x):
"""Turn a n-D tensor into a 2D tensor where
the first dimension is conserved.
"""
# TODO: `keras_shape` inference.
x = T.reshape(x, (x.shape[0], T.prod(x.shape) // x.shape[0]))
return x
y = T.reshape(x, (x.shape[0], T.prod(x.shape[1:])))
if hasattr(x, '_keras_shape'):
y._keras_shape = (x._keras_shape[0], np.prod(x._keras_shape[1:]))
return y


def expand_dims(x, axis=-1):
"""Add a 1-sized dimension at index "dim".
"""
# TODO: `keras_shape` inference.
pattern = [i for i in range(x.type.ndim)]
if axis < 0:
if x.type.ndim == 0:
axis = 0
else:
axis = axis % x.type.ndim + 1
pattern.insert(axis, 'x')
return x.dimshuffle(pattern)
y = x.dimshuffle(pattern)
if hasattr(x, '_keras_shape'):
shape = list(x._keras_shape)
shape.insert(axis, 1)
y._keras_shape = tuple(shape)
return y


def squeeze(x, axis):
"""Remove a 1-dimension from the tensor at index "axis".
"""
# TODO: `keras_shape` inference.
shape = list(x.shape)
shape.pop(axis)
return T.reshape(x, tuple(shape))
y = T.reshape(x, tuple(shape))
if hasattr(x, '_keras_shape'):
kshape = list(x._keras_shape)
kshape.pop(axis)
y._keras_shape = tuple(kshape)
return y


def temporal_padding(x, padding=(1, 1)):
@@ -822,7 +855,6 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
"""Pad the 2nd and 3rd dimensions of a 4D tensor
with "padding[0]" and "padding[1]" (resp.) zeros left and right.
"""
# TODO: `keras_shape` inference.
assert len(padding) == 2
assert len(padding[0]) == 2
assert len(padding[1]) == 2
@@ -857,7 +889,9 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
slice(None))
else:
raise ValueError('Invalid data_format:', data_format)
return T.set_subtensor(output[indices], x)
y = T.set_subtensor(output[indices], x)
y._keras_shape = output_shape
return y


def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
22 changes: 17 additions & 5 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
@@ -21,11 +21,14 @@ def check_single_tensor_operation(function_name, input_shape, **kwargs):
xth = KTH.variable(val)
xtf = KTF.variable(val)

zth = KTH.eval(getattr(KTH, function_name)(xth, **kwargs))
_zth = getattr(KTH, function_name)(xth, **kwargs)
zth = KTH.eval(_zth)
ztf = KTF.eval(getattr(KTF, function_name)(xtf, **kwargs))

assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)
if hasattr(_zth, '_keras_shape'):
assert _zth._keras_shape == zth.shape


def check_two_tensor_operation(function_name, x_input_shape,
@@ -40,11 +43,14 @@ def check_two_tensor_operation(function_name, x_input_shape,
yth = KTH.variable(yval)
ytf = KTF.variable(yval)

zth = KTH.eval(getattr(KTH, function_name)(xth, yth, **kwargs))
_zth = getattr(KTH, function_name)(xth, yth, **kwargs)
zth = KTH.eval(_zth)
ztf = KTF.eval(getattr(KTF, function_name)(xtf, ytf, **kwargs))

assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-05)
if hasattr(_zth, '_keras_shape'):
assert _zth._keras_shape == zth.shape


def check_composed_tensor_operations(first_function_name, first_function_args,
@@ -116,6 +122,7 @@ def test_shape_operations(self):
pattern=(2, 0, 1))
check_single_tensor_operation('repeat', (4, 1), n=3)
check_single_tensor_operation('flatten', (4, 1))
check_single_tensor_operation('batch_flatten', (20, 2, 5))
check_single_tensor_operation('expand_dims', (4, 3), axis=-1)
check_single_tensor_operation('expand_dims', (4, 3, 2), axis=1)
check_single_tensor_operation('squeeze', (4, 3, 1), axis=2)
@@ -134,15 +141,17 @@ def test_repeat_elements(self):

for rep_axis in range(ndims):
np_rep = np.repeat(arr, reps, axis=rep_axis)
th_rep = KTH.eval(
KTH.repeat_elements(arr_th, reps, axis=rep_axis))
th_z = KTH.repeat_elements(arr_th, reps, axis=rep_axis)
th_rep = KTH.eval(th_z)
tf_rep = KTF.eval(
KTF.repeat_elements(arr_tf, reps, axis=rep_axis))

assert th_rep.shape == np_rep.shape
assert tf_rep.shape == np_rep.shape
assert_allclose(np_rep, th_rep, atol=1e-05)
assert_allclose(np_rep, tf_rep, atol=1e-05)
if hasattr(th_z, '_keras_shape'):
assert th_z._keras_shape == th_rep.shape

def test_tile(self):
shape = (3, 4)
@@ -151,9 +160,12 @@ def test_tile(self):
arr_tf = KTF.variable(arr)

n = (2, 1)
th_rep = KTH.eval(KTH.tile(arr_th, n))
th_z = KTH.tile(arr_th, n)
th_rep = KTH.eval(th_z)
tf_rep = KTF.eval(KTF.tile(arr_tf, n))
assert_allclose(tf_rep, th_rep, atol=1e-05)
if hasattr(th_z, '_keras_shape'):
assert th_z._keras_shape == th_rep.shape

def test_value_manipulation(self):
val = np.random.random((4, 2))

0 comments on commit e99eac2

Please sign in to comment.