From 8b3543fca9d811c638bb72d78601c8564f5465fd Mon Sep 17 00:00:00 2001 From: Eder Santana Date: Sun, 3 Apr 2016 13:03:09 -0400 Subject: [PATCH] Fix merge_dot tests * Fix merge_dot tests * Make batch_dot unique batch_dot is not tensordot! It only accepts one reduce dimension at a time. Other reduce dimensions should be dome afterwards with K.sum This means that K.batch_dot will have the same behavior in both tensorflow and theano. This also means that we have less parenthesis and less nested lists. New usage: merge_mode = 'dot', dot_axes=[axis1, axis2] Before: merge_mode = 'dot', dot_axes=[[axis1], [axis2]] * Backport sign by @the-moliver * Fix docstrings --- examples/babi_memnn.py | 2 +- keras/backend/tensorflow_backend.py | 37 +++++++++++++++++++++++++--- keras/backend/theano_backend.py | 33 +++++++++++++++++++++++-- keras/engine/topology.py | 23 +++++++++-------- tests/keras/backend/test_backends.py | 1 + tests/keras/test_sequential_model.py | 2 +- 6 files changed, 78 insertions(+), 20 deletions(-) diff --git a/examples/babi_memnn.py b/examples/babi_memnn.py index 11f1ab634fc..6eb50c1e6da 100644 --- a/examples/babi_memnn.py +++ b/examples/babi_memnn.py @@ -167,7 +167,7 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen): match = Sequential() match.add(Merge([input_encoder_m, question_encoder], mode='dot', - dot_axes=[(2,), (2,)])) + dot_axes=[2, 2])) # output: (samples, story_maxlen, query_maxlen) # embed the input into a single vector with size = story_maxlen: input_encoder_c = Sequential() diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index a716fb93dd7..588f37541e2 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -125,13 +125,38 @@ def dot(x, y): def batch_dot(x, y, axes=None): - if axes: - adj_x = None if axes[0][0] == ndim(x) - 1 else True - adj_y = True if axes[1][0] == ndim(y) - 1 else None + '''batchwise dot product + batch_dot results in a tensor with 
less dimensions than the input. + If the number of dimensions is reduced to 1, we use `expand_dims` to + make sure that ndim is at least 2. + + # Example + Assume x = [[1, 2] and y = [[5, 6] + [3, 4]] [7, 8]] + batch_dot(x, y, axes=1) = [[17, 53]] which is the main diagonal + of x.dot(y.T), although we never have to calculate the off-diagonal + elements. + + + # Arguments + x, y: tensors with ndim >= 2 + axes: list (or single) int with target dimensions + + # Returns + Tensor with ndim >= 2 + ''' + if type(axes) == int: + axes = (axes, axes) + if axes is not None: + adj_x = None if axes[0] == ndim(x) - 1 else True + adj_y = True if axes[1] == ndim(y) - 1 else None else: adj_x = None adj_y = None - return tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y) + out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y) + if ndim(out) == 1: + out = expand_dims(out, 1) + return out def transpose(x): @@ -256,6 +281,10 @@ def round(x): return tf.round(x) +def sign(x): + return tf.sign(x) + + def pow(x, a): return tf.pow(x, a) diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index cee45088338..bdb901812f5 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -127,10 +127,35 @@ def dot(x, y): def batch_dot(x, y, axes=None): + '''batchwise dot product + batch_dot results in a tensor with less dimensions than the input. + If the number of dimensions is reduced to 1, we use `expand_dims` to + make sure that ndim is at least 2. + + # Example + Assume x = [[1, 2] and y = [[5, 6] + [3, 4]] [7, 8]] + batch_dot(x, y, axes=1) = [[17, 53]] which is the main diagonal + of x.dot(y.T), although we never have to calculate the off-diagonal + elements. 
+
+
+    # Arguments
+        x, y: tensors with ndim >= 2
+        axes: list (or single) int with target dimensions
+
+    # Returns
+        Tensor with ndim >= 2
+    '''
+    if type(axes) == int:
+        axes = (axes, axes)
     if axes is None:
         # behaves like tf.batch_matmul as default
-        axes = [(x.ndim - 1,), (y.ndim - 2,)]
-    return T.batched_tensordot(x, y, axes=axes)
+        axes = [x.ndim - 1, y.ndim - 2]
+    out = T.batched_tensordot(x, y, axes=axes)
+    if ndim(out) == 1:
+        out = expand_dims(out, 1)
+    return out
 
 
 def transpose(x):
@@ -219,6 +244,10 @@ def round(x):
     return T.round(x)
 
 
+def sign(x):
+    return T.sgn(x)
+
+
 def pow(x, a):
     return T.pow(x, a)
 
diff --git a/keras/engine/topology.py b/keras/engine/topology.py
index 46f9898cb8a..9b95382593f 100644
--- a/keras/engine/topology.py
+++ b/keras/engine/topology.py
@@ -1038,6 +1038,8 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
         self.mode = mode
         self.concat_axis = concat_axis
         self.dot_axes = dot_axes
+        if type(self.dot_axes) == int:
+            self.dot_axes = [self.dot_axes, ] * 2
         self._output_shape = output_shape
         self.node_indices = node_indices
 
@@ -1113,20 +1115,19 @@ def _arguments_validation(self, layers, mode, concat_axis, dot_axes,
         if mode == 'dot':
             if type(dot_axes) == int:
                 if dot_axes < 0:
-                    dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
+                    dot_axes = [dot_axes % n1, dot_axes % n2]
                 else:
-                    dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
+                    dot_axes = [n1 - dot_axes, n2-dot_axes]
             if type(dot_axes) not in [list, tuple]:
                 raise Exception('Invalid type for dot_axes - should be a list.')
             if len(dot_axes) != 2:
                 raise Exception('Invalid format for dot_axes - should contain two elements.')
-            if type(dot_axes[0]) not in [list, tuple, range] or type(dot_axes[1]) not in [list, tuple, range]:
-                raise Exception('Invalid format for dot_axes - list elements should have type "list" or "tuple".')
-            for i in range(len(dot_axes[0])):
-                if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
-                    raise Exception('Dimension incompatibility using dot mode: ' +
-                                    '%s != %s. ' % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
-                                    'Layer shapes: %s, %s' % (shape1, shape2))
+            if type(dot_axes[0]) is not int or type(dot_axes[1]) is not int:
+                raise Exception('Invalid format for dot_axes - list elements should be "int".')
+            if shape1[dot_axes[0]] != shape2[dot_axes[1]]:
+                raise Exception('Dimension incompatibility using dot mode: ' +
+                                '%s != %s. ' % (shape1[dot_axes[0]], shape2[dot_axes[1]]) +
+                                'Layer shapes: %s, %s' % (shape1, shape2))
         elif mode == 'concat':
             reduced_inputs_shapes = [list(shape) for shape in input_shapes]
             shape_set = set()
@@ -1242,9 +1243,7 @@ def get_output_shape_for(self, input_shape):
         elif self.mode == 'dot':
             shape1 = list(input_shapes[0])
             shape2 = list(input_shapes[1])
-            dot_axes = []
-            for axes in self.dot_axes:
-                dot_axes.append([index-1 for index in axes])
+            dot_axes = [a-1 for a in self.dot_axes]
             tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
                                             np.zeros(tuple(shape2[1:])),
                                             axes=dot_axes)
diff --git a/tests/keras/backend/test_backends.py b/tests/keras/backend/test_backends.py
index cce82c3ff0d..ab7b188004c 100644
--- a/tests/keras/backend/test_backends.py
+++ b/tests/keras/backend/test_backends.py
@@ -149,6 +149,7 @@ def test_elementwise_operations(self):
         check_single_tensor_operation('exp', (4, 2))
         check_single_tensor_operation('log', (4, 2))
         check_single_tensor_operation('round', (4, 2))
+        check_single_tensor_operation('sign', (4, 2))
         check_single_tensor_operation('pow', (4, 2), a=3)
         check_single_tensor_operation('clip', (4, 2), min_value=0.4,
                                       max_value=0.6)
diff --git a/tests/keras/test_sequential_model.py b/tests/keras/test_sequential_model.py
index f337fe70287..9a8349e551b 100644
--- a/tests/keras/test_sequential_model.py
+++ b/tests/keras/test_sequential_model.py
@@ -281,7 +281,7 @@ def test_merge_dot():
     right.add(Activation('relu'))
 
     model = Sequential()
-    model.add(Merge([left, right], mode='dot', dot_axes=([1], [1])))
+    model.add(Merge([left, right], mode='dot', dot_axes=[1, 1]))
     model.add(Dense(nb_class))
     model.add(Activation('softmax'))