From 1b76ca710deffb70b9811bd464c027d132baaadd Mon Sep 17 00:00:00 2001 From: MORITA Kazutaka Date: Fri, 25 May 2018 01:05:17 +0900 Subject: [PATCH] [FRONTEND][Keras] fix reshape (#493) --- python/nnvm/frontend/keras.py | 11 ++++++----- tests/python/frontend/keras/test_forward.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/nnvm/frontend/keras.py b/python/nnvm/frontend/keras.py index 0d51487c3..a61fee7e3 100644 --- a/python/nnvm/frontend/keras.py +++ b/python/nnvm/frontend/keras.py @@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _): def _convert_reshape(insym, keras_layer, _): - shape = keras_layer.shape if hasattr(keras_layer, 'shape') \ - else keras_layer.target_shape if hasattr(keras_layer, 'target_shape') \ - else None - if shape is None: - raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer)) + _check_data_format(keras_layer) + ch = keras_layer.input_shape[-1] + assert ch == keras_layer.target_shape[-1], \ + "Only supports last dimension in target shape being equal to " \ + "the channel number of input tensor." + shape = (-1, ch) + keras_layer.target_shape[:-1] return _sym.reshape(insym, shape=shape) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 0147a3e2c..c751b6443 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -134,6 +134,14 @@ def test_forward_relu6(): verify_keras_frontend(keras_model) +def test_forward_reshape(): + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Reshape(target_shape=(32,32,3))(data) + x = keras.layers.GlobalAveragePooling2D()(x) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + def test_forward_vgg16(): keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None, input_shape=(224,224,3), classes=1000) @@ -162,6 +170,7 @@ def test_forward_resnet50(): test_forward_separable_conv() test_forward_upsample() test_forward_relu6() + test_forward_reshape() test_forward_vgg16() test_forward_xception()