Skip to content

Commit

Permalink
Add API conversion interface for MaxPooling1D layer (keras-team#5667)
Browse files Browse the repository at this point in the history
  • Loading branch information
jihobak authored and fchollet committed Mar 9, 2017
1 parent 6419d52 commit 8f6d12f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions keras/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..engine import Layer
from ..engine import InputSpec
from ..utils import conv_utils
from ..legacy import interfaces


class _Pooling1D(Layer):
Expand Down Expand Up @@ -66,6 +67,7 @@ class MaxPooling1D(_Pooling1D):
3D tensor with shape: `(batch_size, downsampled_steps, features)`.
"""

@interfaces.legacy_maxpooling1d_support
def __init__(self, pool_size=2, strides=None,
padding='valid', **kwargs):
super(MaxPooling1D, self).__init__(pool_size, strides,
Expand Down
44 changes: 44 additions & 0 deletions keras/legacy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,47 @@ def wrapper(*args, **kwargs):

return func(*args, **kwargs)
return wrapper


def legacy_maxpooling1d_support(func):
"""Function wrapper to convert the `MaxPooling1D` constructor from Keras 1 to 2.
# Arguments
func: `__init__` method of `MaxPooling1D`.
# Returns
A constructor conversion wrapper.
"""
@six.wraps(func)
def wrapper(*args, **kwargs):
if len(args) > 2:
# The first entry in `args` is `self`.
raise TypeError('The `MaxPooling1D` layer can have at most '
'one positional argument (the `pool_size` argument).')

# make sure that only keyword argument 'pool_size'(or pool_length' in the legacy interface)
# can be also used as positional argument, which is keyword argument originally.
if 'pool_length' in kwargs:
if len(args) > 1:
raise TypeError('Got both a positional argument '
'and keyword argument for argument '
'`pool_size` '
'(`pool_length` in the legacy interface).')

elif 'pool_size' in kwargs:
if len(args) > 1:
raise TypeError('Got both a positional argument '
'and keyword argument for argument '
'`pool_size`. ')

# Remaining kwargs.
conversions = [
('pool_length', 'pool_size'),
('border_mode', 'padding'),
]
kwargs = convert_legacy_kwargs('MaxPooling1D',
args[1:],
kwargs,
conversions)
return func(*args, **kwargs)
return wrapper
12 changes: 12 additions & 0 deletions tests/keras/legacy/interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,17 @@ def test_dropout_legacy_interface():
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer_1.get_config())
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer_2.get_config())


@keras_test
def test_maxpooling1d_legacy_interface():
old_layer = keras.layers.MaxPool1D(pool_length=2, border_mode='valid', name='maxpool1d')
new_layer = keras.layers.MaxPool1D(pool_size=2, padding='valid', name='maxpool1d')
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())

old_layer = keras.layers.MaxPool1D(2, padding='valid', name='maxpool1d')
new_layer = keras.layers.MaxPool1D(pool_size=2, padding='valid', name='maxpool1d')
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 8f6d12f

Please sign in to comment.