Skip to content

Commit

Permalink
[FRONTEND][Keras] fix reshape (dmlc#493)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazum authored and tqchen committed May 24, 2018
1 parent 29d4f14 commit 1b76ca7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
11 changes: 6 additions & 5 deletions python/nnvm/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _):


def _convert_reshape(insym, keras_layer, _):
shape = keras_layer.shape if hasattr(keras_layer, 'shape') \
else keras_layer.target_shape if hasattr(keras_layer, 'target_shape') \
else None
if shape is None:
raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer))
_check_data_format(keras_layer)
ch = keras_layer.input_shape[-1]
assert ch == keras_layer.target_shape[-1], \
"Only supports last dimension in target shape being equal to " \
"the channel number of input tensor."
shape = (-1, ch) + keras_layer.target_shape[:-1]
return _sym.reshape(insym, shape=shape)


Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def test_forward_relu6():
verify_keras_frontend(keras_model)


def test_forward_reshape():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Reshape(target_shape=(32,32,3))(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)


def test_forward_vgg16():
keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None,
input_shape=(224,224,3), classes=1000)
Expand Down Expand Up @@ -162,6 +170,7 @@ def test_forward_resnet50():
test_forward_separable_conv()
test_forward_upsample()
test_forward_relu6()
test_forward_reshape()

test_forward_vgg16()
test_forward_xception()
Expand Down

0 comments on commit 1b76ca7

Please sign in to comment.