From 839d3dd0b598ecf6f2e93b9f9068c2e21fa1c97b Mon Sep 17 00:00:00 2001
From: Fariz Rahman
Date: Fri, 9 Mar 2018 06:52:33 +0530
Subject: [PATCH] bug fix - run_internal_graph() (#9599)

* bug fix

* cleanup

* readability++

* add test for case lambda multi out no mask

* pep8

* pep8
---
 keras/engine/topology.py        | 23 +++++++++++++++++------
 tests/keras/layers/core_test.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/keras/engine/topology.py b/keras/engine/topology.py
index 028538546c3..707ac02011d 100644
--- a/keras/engine/topology.py
+++ b/keras/engine/topology.py
@@ -2210,7 +2210,6 @@ def run_internal_graph(self, inputs, masks=None):
             for node in nodes:
                 # This is always a single layer, never a list.
                 layer = node.outbound_layer
-
                 reference_input_tensors = node.input_tensors
                 reference_output_tensors = node.output_tensors
 
@@ -2234,8 +2233,12 @@
                                 if 'mask' not in kwargs:
                                     kwargs['mask'] = computed_mask
                             output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
-                            output_masks = _to_list(layer.compute_mask(computed_tensor,
-                                                                       computed_mask))
+                            output_masks = layer.compute_mask(computed_tensor,
+                                                              computed_mask)
+                            if output_masks is None:
+                                output_masks = [None for _ in output_tensors]
+                            else:
+                                output_masks = _to_list(output_masks)
                             computed_tensors = [computed_tensor]
                             computed_masks = [computed_mask]
                         else:
@@ -2245,14 +2248,22 @@
                                 if 'mask' not in kwargs:
                                     kwargs['mask'] = computed_masks
                             output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
-                            output_masks = _to_list(layer.compute_mask(computed_tensors,
-                                                                       computed_masks))
-
+                            output_masks = layer.compute_mask(computed_tensors,
+                                                              computed_masks)
+                            if output_masks is None:
+                                output_masks = [None for _ in output_tensors]
+                            else:
+                                output_masks = _to_list(output_masks)
                     # Apply activity regularizer if any:
                     if hasattr(layer, 'activity_regularizer') and layer.activity_regularizer is not None:
                         regularization_losses = [layer.activity_regularizer(x) for x in output_tensors]
                         layer.add_loss(regularization_losses, computed_tensors)
 
+                    if len(output_masks) != len(output_tensors):
+                        raise Exception('Layers should have equal number of output tensors '
+                                        'and output masks. Layer ' + str(layer.name) + ' has'
+                                        ' ' + str(len(output_tensors)) + ' output tensors and'
+                                        ' ' + str(len(output_masks)) + ' output masks.')
                     # Update model updates and losses:
                     # Keep track of updates that depend on the inputs
                     # (e.g. BN updates).
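The crux of the topology.py change: when a layer defines no explicit mask, Layer.compute_mask() returns None, and the old code wrapped that None with _to_list(), yielding a one-element mask list that no longer matched the output tensors of a multi-output layer. Below is a minimal standalone sketch (not part of the patch) of that mismatch and of the normalisation the patch applies; _to_list is re-declared locally purely for illustration, mirroring the private Keras helper of the same name.

def _to_list(x):
    # Local stand-in for the private Keras helper: wrap anything that is
    # not already a list into a one-element list.
    if isinstance(x, list):
        return x
    return [x]


# A layer such as the Lambda in the new test returns two output tensors
# but defines no mask, so compute_mask() falls back to returning None.
output_tensors = ['out_a', 'out_b']  # stand-ins for the two output tensors
returned_mask = None                 # what compute_mask() returns here

# Old behaviour: None is wrapped into a one-element list, which no longer
# lines up with the two output tensors downstream.
old_masks = _to_list(returned_mask)
assert old_masks == [None]
assert len(old_masks) != len(output_tensors)

# Patched behaviour: a None mask is expanded to one None per output tensor;
# any other result of compute_mask() is normalised to a list as before.
if returned_mask is None:
    new_masks = [None for _ in output_tensors]
else:
    new_masks = _to_list(returned_mask)
assert new_masks == [None, None]
assert len(new_masks) == len(output_tensors)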
diff --git a/tests/keras/layers/core_test.py b/tests/keras/layers/core_test.py
index 647e5c0c57e..907a54f5d00 100644
--- a/tests/keras/layers/core_test.py
+++ b/tests/keras/layers/core_test.py
@@ -169,6 +169,36 @@ def mask(inputs, mask=None):
 
     test_multiple_outputs()
 
+    # test layer with multiple outputs and no
+    # explicit mask
+    def test_multiple_outputs_no_mask():
+        def func(x):
+            return [x * 0.2, x * 0.3]
+
+        def output_shape(input_shape):
+            return [input_shape, input_shape]
+
+        i = layers.Input(shape=(64, 64, 3))
+        o = layers.Lambda(function=func,
+                          output_shape=output_shape)(i)
+
+        assert o[0]._keras_shape == (None, 64, 64, 3)
+        assert o[1]._keras_shape == (None, 64, 64, 3)
+
+        o = layers.add(o)
+        model = Model(i, o)
+
+        i2 = layers.Input(shape=(64, 64, 3))
+        o2 = model(i2)
+        model2 = Model(i2, o2)
+
+        x = np.random.random((4, 64, 64, 3))
+        out = model2.predict(x)
+        assert out.shape == (4, 64, 64, 3)
+        assert_allclose(out, x * 0.2 + x * 0.3, atol=1e-4)
+
+    test_multiple_outputs_no_mask()
+
     # test serialization with function
     def f(x):
         return x + 1
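The new test reaches the fixed code path by wrapping the multi-output Lambda model inside a second model, because run_internal_graph() only runs when an already-built Model is called on fresh tensors. The same user-level scenario as a standalone sketch, assuming the Keras 2.1.x-era API (keras.layers, keras.models.Model):

import numpy as np
from keras import layers
from keras.models import Model

# A Lambda layer with two outputs and no mask function.
inp = layers.Input(shape=(64, 64, 3))
two_outputs = layers.Lambda(lambda x: [x * 0.2, x * 0.3],
                            output_shape=lambda s: [s, s])(inp)
inner = Model(inp, layers.add(two_outputs))

# Re-using the built model as a layer goes through run_internal_graph(),
# where compute_mask() previously produced a single None mask for the
# Lambda's two output tensors.
outer_in = layers.Input(shape=(64, 64, 3))
outer = Model(outer_in, inner(outer_in))

x = np.random.random((4, 64, 64, 3))
out = outer.predict(x)
assert out.shape == (4, 64, 64, 3)
np.testing.assert_allclose(out, x * 0.2 + x * 0.3, atol=1e-4)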