Skip to content

Commit

Permalink
Merge: Serialize output mask; Enable user arguments for callable mode (
Browse files Browse the repository at this point in the history
keras-team#4445)

* Update topology.py

* Update topology.py

* Update topology.py

* white space fix

* indentation fix

* add tests

* fix all tests

* add arguments arg to merge

* space after period

* add test with arguments

* add test with arguments for lambda layer too

* pep8 fixes

* fix tf test

* try fixing tf test; again

* bug fix

* finally
  • Loading branch information
farizrahman4u authored and fchollet committed Nov 23, 2016
1 parent 7bd5c86 commit 509d6d8
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 12 deletions.
42 changes: 35 additions & 7 deletions keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
import copy
import os
import inspect
from six.moves import zip

from .. import backend as K
Expand Down Expand Up @@ -1129,14 +1130,16 @@ class Merge(Layer):
'''
def __init__(self, layers=None, mode='sum', concat_axis=-1,
dot_axes=-1, output_shape=None, output_mask=None,
node_indices=None, tensor_indices=None, name=None):
arguments={}, node_indices=None, tensor_indices=None,
name=None):
self.layers = layers
self.mode = mode
self.concat_axis = concat_axis
self.dot_axes = dot_axes
self._output_shape = output_shape
self.node_indices = node_indices
self._output_mask = output_mask
self.arguments = arguments

# Layer parameters.
self.inbound_nodes = []
Expand Down Expand Up @@ -1239,9 +1242,10 @@ def call(self, inputs, mask=None):
'(at least 2). Got: ' + str(inputs))
# Case: "mode" is a lambda or function.
if hasattr(self.mode, '__call__'):
# TODO: consider making it possible to
# pass custom arguments to lambda.
arguments = {}
arguments = self.arguments
arg_spec = inspect.getargspec(self.mode)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.mode(inputs, **arguments)

if self.mode == 'sum' or self.mode == 'ave':
Expand Down Expand Up @@ -1338,7 +1342,7 @@ def get_output_shape_for(self, input_shape):
raise Exception('The Merge layer ' + self.name +
' has a callable `mode` argument, ' +
'and we cannot infer its output shape because ' +
'no `output_shape` argument was provided.' +
'no `output_shape` argument was provided. ' +
'Make sure to pass a shape tuple (or a callable) ' +
'`output_shape` to Merge.')
# Pre-defined merge modes.
Expand Down Expand Up @@ -1421,13 +1425,26 @@ def get_config(self):
output_shape = self._output_shape
output_shape_type = 'raw'

if isinstance(self._output_mask, python_types.LambdaType):
output_mask = func_dump(self._output_mask)
output_mask_type = 'lambda'
elif callable(self._output_mask):
output_mask = self._output_mask.__name__
output_mask_type = 'function'
else:
output_mask = self._output_mask
output_mask_type = 'raw'

return {'name': self.name,
'mode': mode,
'mode_type': mode_type,
'concat_axis': self.concat_axis,
'dot_axes': self.dot_axes,
'output_shape': output_shape,
'output_shape_type': output_shape_type}
'output_shape_type': output_shape_type,
'output_mask': output_mask,
'output_mask_type': output_mask_type,
'arguments': self.arguments}

@classmethod
def from_config(cls, config):
Expand All @@ -1447,13 +1464,22 @@ def from_config(cls, config):
else:
output_shape = config['output_shape']

output_mask_type = config.pop('output_mask_type')
if output_mask_type == 'function':
output_mask = globals()[config['output_mask']]
elif output_mask_type == 'lambda':
output_mask = func_load(config['output_mask'], globs=globals())
else:
output_mask = config['output_mask']

config['mode'] = mode
config['output_shape'] = output_shape
config['output_mask'] = output_mask
return super(Merge, cls).from_config(config)


def merge(inputs, mode='sum', concat_axis=-1,
dot_axes=-1, output_shape=None, output_mask=None, name=None):
dot_axes=-1, output_shape=None, output_mask=None, arguments={}, name=None):
'''Functional merge, to apply to Keras tensors (NOT layers).
Returns a Keras tensor.
Expand Down Expand Up @@ -1504,6 +1530,7 @@ def merge(inputs, mode='sum', concat_axis=-1,
dot_axes=dot_axes,
output_shape=output_shape,
output_mask=output_mask,
arguments=arguments,
node_indices=node_indices,
tensor_indices=tensor_indices,
name=name)
Expand All @@ -1514,6 +1541,7 @@ def merge(inputs, mode='sum', concat_axis=-1,
dot_axes=dot_axes,
output_shape=output_shape,
output_mask=output_mask,
arguments=arguments,
name=name)
return merge_layer(inputs)

Expand Down
85 changes: 80 additions & 5 deletions tests/keras/layers/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_masking():

@keras_test
def test_merge():
from keras.layers import Input, merge, Merge
from keras.layers import Input, merge, Merge, Masking
from keras.models import Model

# test modes: 'sum', 'mul', 'concat', 'ave', 'cos', 'dot'.
Expand Down Expand Up @@ -53,7 +53,8 @@ def test_merge():
input_b = Input(shape=input_shapes[1][1:])
merged = merge([input_a, input_b],
mode=lambda tup: K.concatenate([tup[0], tup[1]]),
output_shape=lambda tup: (tup[0][:-1],) + (tup[0][-1] + tup[1][-1],))
output_shape=lambda tup: tup[0][:-1] + (tup[0][-1] + tup[1][-1],))
model = Model([input_a, input_b], merged)
expected_output_shape = model.get_output_shape_for(input_shapes)
actual_output_shape = model.predict(inputs).shape
assert expected_output_shape == actual_output_shape
Expand All @@ -65,17 +66,18 @@ def test_merge():
# test function with output_shape function
def fn_mode(tup):
x, y = tup
return K.concatenate([x, y])
return K.concatenate([x, y], axis=1)

def fn_output_shape(tup):
s1, s2 = tup
return (s1[:-1],) + (s1[-1] + s2[-1],)
return (s1[0], s1[1] + s2[1]) + s1[2:]

input_a = Input(shape=input_shapes[0][1:])
input_b = Input(shape=input_shapes[1][1:])
merged = merge([input_a, input_b],
mode=fn_mode,
output_shape=fn_output_shape)
model = Model([input_a, input_b], merged)
expected_output_shape = model.get_output_shape_for(input_shapes)
actual_output_shape = model.predict(inputs).shape
assert expected_output_shape == actual_output_shape
Expand All @@ -84,6 +86,74 @@ def fn_output_shape(tup):
model = Model.from_config(config)
model.compile('rmsprop', 'mse')

# test function with output_mask function
# time dimension is required for masking
input_shapes = [(4, 3, 2), (4, 3, 2)]
inputs = [np.random.random(shape) for shape in input_shapes]

def fn_output_mask(tup):
x_mask, y_mask = tup
return K.concatenate([x_mask, y_mask])

input_a = Input(shape=input_shapes[0][1:])
input_b = Input(shape=input_shapes[1][1:])
a = Masking()(input_a)
b = Masking()(input_b)
merged = merge([a, b], mode=fn_mode, output_shape=fn_output_shape, output_mask=fn_output_mask)
model = Model([input_a, input_b], merged)
expected_output_shape = model.get_output_shape_for(input_shapes)
actual_output_shape = model.predict(inputs).shape
assert expected_output_shape == actual_output_shape

config = model.get_config()
model = Model.from_config(config)
model.compile('rmsprop', 'mse')

mask_inputs = (np.zeros(input_shapes[0][:-1]), np.ones(input_shapes[1][:-1]))
expected_mask_output = np.concatenate(mask_inputs, axis=-1)
mask_input_placeholders = [K.placeholder(shape=input_shape[:-1]) for input_shape in input_shapes]
mask_output = model.layers[-1]._output_mask(mask_input_placeholders)
assert np.all(K.function(mask_input_placeholders, [mask_output])(mask_inputs)[0] == expected_mask_output)

# test lambda with output_mask lambda
input_a = Input(shape=input_shapes[0][1:])
input_b = Input(shape=input_shapes[1][1:])
a = Masking()(input_a)
b = Masking()(input_b)
merged = merge([a, b], mode=lambda tup: K.concatenate([tup[0], tup[1]], axis=1),
output_shape=lambda tup: (tup[0][0], tup[0][1] + tup[1][1]) + tup[0][2:],
output_mask=lambda tup: K.concatenate([tup[0], tup[1]]))
model = Model([input_a, input_b], merged)
expected_output_shape = model.get_output_shape_for(input_shapes)
actual_output_shape = model.predict(inputs).shape
assert expected_output_shape == actual_output_shape

config = model.get_config()
model = Model.from_config(config)
model.compile('rmsprop', 'mse')

mask_output = model.layers[-1]._output_mask(mask_input_placeholders)
assert np.all(K.function(mask_input_placeholders, [mask_output])(mask_inputs)[0] == expected_mask_output)

# test with arguments
input_shapes = [(3, 2), (3, 2)]
inputs = [np.random.random(shape) for shape in input_shapes]

def fn_mode(tup, a, b):
x, y = tup
return x * a + y * b

input_a = Input(shape=input_shapes[0][1:])
input_b = Input(shape=input_shapes[1][1:])
merged = merge([input_a, input_b], mode=fn_mode, output_shape=lambda s: s[0], arguments={'a': 0.7, 'b': 0.3})
model = Model([input_a, input_b], merged)
output = model.predict(inputs)

config = model.get_config()
model = Model.from_config(config)

assert np.all(model.predict(inputs) == output)


@keras_test
def test_merge_mask_2d():
Expand Down Expand Up @@ -156,7 +226,7 @@ def test_dropout():
layer_test(core.SpatialDropout1D,
kwargs={'p': 0.5},
input_shape=(2, 3, 4))

layer_test(core.SpatialDropout2D,
kwargs={'p': 0.5},
input_shape=(2, 3, 4, 5))
Expand Down Expand Up @@ -216,6 +286,11 @@ def test_lambda():
kwargs={'function': lambda x: x + 1},
input_shape=(3, 2))

layer_test(Lambda,
kwargs={'function': lambda x, a, b: x * a + b,
'arguments': {'a': 0.6, 'b': 0.4}},
input_shape=(3, 2))

# test serialization with function
def f(x):
return x + 1
Expand Down

0 comments on commit 509d6d8

Please sign in to comment.