
Fix merge_dot tests
* Fix merge_dot tests

* Make batch_dot unique

batch_dot is not tensordot! It accepts only one reduce dimension at a
time; other reduce dimensions should be done afterwards with K.sum.
This means that K.batch_dot now has the same behavior in both
tensorflow and theano. It also means fewer parentheses and fewer
nested lists.

New usage:

merge_mode = 'dot', dot_axes=[axis1, axis2]

Before:

merge_mode = 'dot', dot_axes=[[axis1], [axis2]]
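For reference, a minimal NumPy sketch of the new single-axis convention
(not part of the commit; shapes are illustrative). With two 3D batches
and dot_axes=[2, 2], one axis is contracted per call; any further
reduction is a separate K.sum on the result, per the note above:

import numpy as np

# x has shape (batch, m, k); y has shape (batch, n, k).
x = np.random.rand(4, 3, 5)
y = np.random.rand(4, 2, 5)

# New style: a single reduce dimension, dot_axes=[2, 2].
# Per batch item, axis 2 of x is contracted with axis 2 of y:
out = np.einsum('bmk,bnk->bmn', x, y)

# Reference: an explicit loop over the batch dimension.
ref = np.stack([x[i].dot(y[i].T) for i in range(4)])
assert np.allclose(out, ref)
print(out.shape)  # (4, 3, 2)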

* Backport sign by @the-moliver

* Fix docstrings
EderSantana authored and fchollet committed Apr 3, 2016
1 parent a6fe2ae commit 8b3543f
Showing 6 changed files with 78 additions and 20 deletions.
2 changes: 1 addition & 1 deletion examples/babi_memnn.py
@@ -167,7 +167,7 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
 match = Sequential()
 match.add(Merge([input_encoder_m, question_encoder],
                 mode='dot',
-                dot_axes=[(2,), (2,)]))
+                dot_axes=[2, 2]))
 # output: (samples, story_maxlen, query_maxlen)
 # embed the input into a single vector with size = story_maxlen:
 input_encoder_c = Sequential()
37 changes: 33 additions & 4 deletions keras/backend/tensorflow_backend.py
@@ -125,13 +125,38 @@ def dot(x, y):
 
 
 def batch_dot(x, y, axes=None):
-    if axes:
-        adj_x = None if axes[0][0] == ndim(x) - 1 else True
-        adj_y = True if axes[1][0] == ndim(y) - 1 else None
+    '''Batchwise dot product.
+
+    batch_dot results in a tensor with fewer dimensions than the input.
+    If the number of dimensions is reduced to 1, we use `expand_dims` to
+    make sure that ndim is at least 2.
+
+    # Example
+        Assume x = [[1, 2]  and y = [[5, 6]
+                    [3, 4]]          [7, 8]]
+        batch_dot(x, y, axes=1) = [[17, 53]], which is the main diagonal
+        of x.dot(y.T), although we never have to calculate the off-diagonal
+        elements.
+
+    # Arguments
+        x, y: tensors with ndim >= 2
+        axes: int, or list of two ints, with the target dimensions
+
+    # Returns
+        Tensor with ndim >= 2
+    '''
+    if type(axes) == int:
+        axes = (axes, axes)
+    if axes is not None:
+        adj_x = None if axes[0] == ndim(x) - 1 else True
+        adj_y = True if axes[1] == ndim(y) - 1 else None
     else:
         adj_x = None
         adj_y = None
-    return tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
+    out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
+    if ndim(out) == 1:
+        out = expand_dims(out, 1)
+    return out
 
 
 def transpose(x):
@@ -256,6 +281,10 @@ def round(x):
     return tf.round(x)
 
 
+def sign(x):
+    return tf.sign(x)
+
+
 def pow(x, a):
     return tf.pow(x, a)

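As a quick sanity check of the docstring's example above (a sketch, not
part of the commit), the same numbers in plain NumPy:

import numpy as np

x = np.array([[1, 2], [3, 4]])
y = np.array([[5, 6], [7, 8]])

# batch_dot(x, y, axes=1): one dot product per batch row.
out = np.einsum('bk,bk->b', x, y)             # [17, 53]
assert np.allclose(out, np.diag(x.dot(y.T)))  # main diagonal of x.dot(y.T)

# The backend then restores ndim >= 2, as expand_dims(out, 1) would:
print(out[:, None].shape)  # (2, 1)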
33 changes: 31 additions & 2 deletions keras/backend/theano_backend.py
@@ -127,10 +127,35 @@ def dot(x, y):
 
 
 def batch_dot(x, y, axes=None):
+    '''Batchwise dot product.
+
+    batch_dot results in a tensor with fewer dimensions than the input.
+    If the number of dimensions is reduced to 1, we use `expand_dims` to
+    make sure that ndim is at least 2.
+
+    # Example
+        Assume x = [[1, 2]  and y = [[5, 6]
+                    [3, 4]]          [7, 8]]
+        batch_dot(x, y, axes=1) = [[17, 53]], which is the main diagonal
+        of x.dot(y.T), although we never have to calculate the off-diagonal
+        elements.
+
+    # Arguments
+        x, y: tensors with ndim >= 2
+        axes: int, or list of two ints, with the target dimensions
+
+    # Returns
+        Tensor with ndim >= 2
+    '''
+    if type(axes) == int:
+        axes = (axes, axes)
     if axes is None:
         # behaves like tf.batch_matmul as default
-        axes = [(x.ndim - 1,), (y.ndim - 2,)]
-    return T.batched_tensordot(x, y, axes=axes)
+        axes = [x.ndim - 1, y.ndim - 2]
+    out = T.batched_tensordot(x, y, axes=axes)
+    if ndim(out) == 1:
+        out = expand_dims(out, 1)
+    return out
 
 
 def transpose(x):
@@ -219,6 +244,10 @@ def round(x):
     return T.round(x)
 
 
+def sign(x):
+    return T.sgn(x)
+
+
 def pow(x, a):
     return T.pow(x, a)

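The "behaves like tf.batch_matmul as default" comment can be checked
with a small NumPy sketch (assumed shapes, not part of the commit): the
default axes [x.ndim - 1, y.ndim - 2] give a plain batched matrix
multiply.

import numpy as np

b, m, k, n = 4, 3, 5, 2
x = np.random.rand(b, m, k)
y = np.random.rand(b, k, n)

# Contract x's last axis with y's second-to-last axis, per batch item:
out = np.einsum('bmk,bkn->bmn', x, y)
ref = np.stack([x[i].dot(y[i]) for i in range(b)])
assert np.allclose(out, ref)  # same as a batched matmul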
23 changes: 11 additions & 12 deletions keras/engine/topology.py
@@ -1038,6 +1038,8 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
         self.mode = mode
         self.concat_axis = concat_axis
         self.dot_axes = dot_axes
+        if type(self.dot_axes) == int:
+            self.dot_axes = [self.dot_axes, ] * 2
         self._output_shape = output_shape
         self.node_indices = node_indices

@@ -1113,20 +1115,19 @@ def _arguments_validation(self, layers, mode, concat_axis, dot_axes,
         if mode == 'dot':
             if type(dot_axes) == int:
                 if dot_axes < 0:
-                    dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
+                    dot_axes = [dot_axes % n1, dot_axes % n2]
                 else:
-                    dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
+                    dot_axes = [n1 - dot_axes, n2 - dot_axes]
             if type(dot_axes) not in [list, tuple]:
                 raise Exception('Invalid type for dot_axes - should be a list.')
             if len(dot_axes) != 2:
                 raise Exception('Invalid format for dot_axes - should contain two elements.')
-            if type(dot_axes[0]) not in [list, tuple, range] or type(dot_axes[1]) not in [list, tuple, range]:
-                raise Exception('Invalid format for dot_axes - list elements should have type "list" or "tuple".')
-            for i in range(len(dot_axes[0])):
-                if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
-                    raise Exception('Dimension incompatibility using dot mode: ' +
-                                    '%s != %s. ' % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
-                                    'Layer shapes: %s, %s' % (shape1, shape2))
+            if type(dot_axes[0]) is not int or type(dot_axes[1]) is not int:
+                raise Exception('Invalid format for dot_axes - list elements should be "int".')
+            if shape1[dot_axes[0]] != shape2[dot_axes[1]]:
+                raise Exception('Dimension incompatibility using dot mode: ' +
+                                '%s != %s. ' % (shape1[dot_axes[0]], shape2[dot_axes[1]]) +
+                                'Layer shapes: %s, %s' % (shape1, shape2))
         elif mode == 'concat':
             reduced_inputs_shapes = [list(shape) for shape in input_shapes]
             shape_set = set()
@@ -1242,9 +1243,7 @@ def get_output_shape_for(self, input_shape):
         elif self.mode == 'dot':
             shape1 = list(input_shapes[0])
             shape2 = list(input_shapes[1])
-            dot_axes = []
-            for axes in self.dot_axes:
-                dot_axes.append([index - 1 for index in axes])
+            dot_axes = [a - 1 for a in self.dot_axes]
             tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
                                             np.zeros(tuple(shape2[1:])),
                                             axes=dot_axes)
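A standalone sketch of the axis handling above (normalize_dot_axes is a
hypothetical helper, not in the commit; n1 and n2 stand for the ndims
of the two inputs), together with the zeros-plus-tensordot trick that
get_output_shape_for uses for shape inference:

import numpy as np

def normalize_dot_axes(dot_axes, n1, n2):
    # Mirrors the int branch of _arguments_validation above.
    if isinstance(dot_axes, int):
        if dot_axes < 0:
            return [dot_axes % n1, dot_axes % n2]  # e.g. -1 -> last axes
        return [n1 - dot_axes, n2 - dot_axes]
    return list(dot_axes)

print(normalize_dot_axes(-1, 3, 3))  # [2, 2]
print(normalize_dot_axes(1, 3, 3))   # [2, 2]

# Output-shape inference: run np.tensordot on zero tensors of the
# per-sample shapes (batch axis stripped, hence the axis - 1 above):
print(np.tensordot(np.zeros((3, 5)), np.zeros((2, 5)), axes=[1, 1]).shape)  # (3, 2)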
1 change: 1 addition & 0 deletions tests/keras/backend/test_backends.py
@@ -149,6 +149,7 @@ def test_elementwise_operations(self):
         check_single_tensor_operation('exp', (4, 2))
         check_single_tensor_operation('log', (4, 2))
         check_single_tensor_operation('round', (4, 2))
+        check_single_tensor_operation('sign', (4, 2))
         check_single_tensor_operation('pow', (4, 2), a=3)
         check_single_tensor_operation('clip', (4, 2), min_value=0.4,
                                       max_value=0.6)
2 changes: 1 addition & 1 deletion tests/keras/test_sequential_model.py
@@ -281,7 +281,7 @@ def test_merge_dot():
     right.add(Activation('relu'))
 
     model = Sequential()
-    model.add(Merge([left, right], mode='dot', dot_axes=([1], [1])))
+    model.add(Merge([left, right], mode='dot', dot_axes=[1, 1]))
     model.add(Dense(nb_class))
     model.add(Activation('softmax'))

