diff --git a/keras/layers/core.py b/keras/layers/core.py index d3074a41eb4..7ab8e9953b0 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -1376,7 +1376,7 @@ def output_shape(self): else: output_shape_func = marshal.loads(self._output_shape) output_shape_func = types.FunctionType(output_shape_func, globals()) - shape = output_shape_func(self.previous.output_shape) + shape = output_shape_func(self.input_shape) if type(shape) not in {list, tuple}: raise Exception('output_shape function must return a tuple') return tuple(shape)