From a8654769c1faf6327b715edae614eb48775394a1 Mon Sep 17 00:00:00 2001
From: anj-s <32556631+anj-s@users.noreply.github.com>
Date: Tue, 24 Apr 2018 16:28:41 -0700
Subject: [PATCH] 1.8r Cherrypick request-cherrypicks_30740: Fix for dropped
 metrics in evaluate function for Keras models. (#18799)

---
 .../keras/_impl/keras/engine/training.py      | 29 ++-------
 .../_impl/keras/engine/training_eager.py      | 39 ++++--------
 .../_impl/keras/engine/training_eager_test.py | 11 ++--
 .../keras/_impl/keras/engine/training_test.py | 26 ++++++++
 .../_impl/keras/engine/training_utils.py      | 62 +++++++++++++++++++
 5 files changed, 109 insertions(+), 58 deletions(-)

diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 71de657da81b92..2b72e0e33dd909 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -276,6 +276,8 @@ def compile(self,
         self.metrics_names.append(self.output_names[i] + '_loss')
     self.nested_metrics = training_utils.collect_metrics(metrics,
                                                          self.output_names)
+    with K.name_scope('metrics'):
+      training_utils.populate_metric_names(self)
     self._feed_sample_weight_modes = []
     for i in range(len(self.outputs)):
       self._feed_sample_weight_modes.append(None)
@@ -462,7 +464,6 @@ def compile(self,
         output_weighted_metrics = nested_weighted_metrics[i]
 
         def handle_metrics(metrics, weights=None):
-          metric_name_prefix = 'weighted_' if weights is not None else ''
 
           for metric in metrics:
             if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
@@ -489,39 +490,19 @@ def handle_metrics(metrics, weights=None):
                   metric_fn = metrics_module.categorical_accuracy
                 elif metric in ('crossentropy', 'ce'):
                   metric_fn = metrics_module.categorical_crossentropy
-              if metric in ('accuracy', 'acc'):
-                suffix = 'acc'
-              elif metric in ('crossentropy', 'ce'):
-                suffix = 'ce'
               weighted_metric_fn = training_utils.weighted_masked_objective(
                   metric_fn)
-              metric_name = metric_name_prefix + suffix
             else:
              metric_fn = metrics_module.get(metric)
              weighted_metric_fn = training_utils.weighted_masked_objective(
                  metric_fn)
-              # Get metric name as string
-              if hasattr(metric_fn, 'name'):
-                metric_name = metric_fn.name
-              else:
-                metric_name = metric_fn.__name__
-              metric_name = metric_name_prefix + metric_name
-
+            metric_name = training_utils.get_base_metric_name(
+                metric, weighted=weights is not None)
             with K.name_scope(metric_name):
               metric_result = weighted_metric_fn(
                   y_true, y_pred, weights=weights, mask=masks[i])
 
-            # Append to self.metrics_names, self.metric_tensors,
-            # self.stateful_metric_names
-            if len(self.output_names) > 1:
-              metric_name = '%s_%s' % (self.output_names[i], metric_name)
-            # Dedupe name
-            j = 1
-            base_metric_name = metric_name
-            while metric_name in self.metrics_names:
-              metric_name = '%s_%d' % (base_metric_name, j)
-              j += 1
-            self.metrics_names.append(metric_name)
+            training_utils.add_metric_name(self, metric_name, i)
             self.metrics_tensors.append(metric_result)
 
             # Keep track of state updates created by
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 695669d9ee1566..ad239d6151e02a 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -100,7 +100,7 @@ def _eager_metrics_fn(model, outputs, targets):
         metric_names.append(metric_name)
         metric_results.append(backend.mean(metric_result))
 
-  return metric_names, metric_results
+  return metric_results
 
 
 def _model_loss(model, inputs, targets, sample_weights=None, training=False):
@@ -151,7 +151,12 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
       with backend.name_scope(model.output_names[i] + '_loss'):
         output_loss = weighted_masked_fn(
             targets[i], outs[i], weights, mask=mask)
-      loss_metrics.append(backend.mean(output_loss))
+      # If the number of outputs is 1 then we don't append the loss metric
+      # associated with each model output. When there are multiple outputs
+      # associated with a model, each output's loss is calculated and returned
+      # as part of the loss_metrics.
+      if len(model.outputs) > 1:
+        loss_metrics.append(backend.mean(output_loss))
 
       loss_weight = model.loss_weights_list[i]
       if total_loss is None:
@@ -274,7 +279,7 @@ def train_on_batch(model, inputs, targets, sample_weights=None):
       model, inputs, targets, sample_weights=sample_weights, training=True)
   if not isinstance(outs, list):
     outs = [outs]
-  _, metrics_results = _eager_metrics_fn(
+  metrics_results = _eager_metrics_fn(
       model, outs, targets)
   if not isinstance(loss, list):
     loss = [loss]
@@ -304,7 +309,7 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
       model, inputs, targets, sample_weights=sample_weights, training=False)
   if not isinstance(outs, list):
     outs = [outs]
-  _, metrics_results = _eager_metrics_fn(
+  metrics_results = _eager_metrics_fn(
      model, outs, targets)
   if not isinstance(loss, list):
     loss = [loss]
@@ -498,34 +503,12 @@ def fit_loop(
         for l, o in zip(out_labels, outs):
           batch_logs[l] = o
         # Required for Eager mode
-        metrics_names, metrics_results = _eager_metrics_fn(
-            model, outs, targets_batch)
+        metrics_results = _eager_metrics_fn(model, outs, targets_batch)
         batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
 
-        # TODO(anjalisridhar): Move this to compile to avoid duplicate code.
-        # In graph mode we set the metric names in compile. However in
-        # Eager mode we calculate the metrics for each batch in fit_loop.
-        # We could calculate the metric names and functions in compile.
-        # This would avoid setting the callback parameters separately.
-        # We need to do this for the first iteration alone
-        for m in metrics_names:
-          if m not in callback_metrics:
-            callback_metrics.append(m)
-
-        callbacks.set_params({
-            'batch_size': batch_size,
-            'epochs': epochs,
-            'steps': steps_per_epoch,
-            'samples': num_train_samples,
-            'verbose': verbose,
-            'do_validation': do_validation,
-            'metrics': callback_metrics or [],
-        })
-
         for k, v in zip(model.metrics_names,
                         [backend.mean(loss)] + loss_metrics + metrics_results):
           batch_logs[k] = tensor_util.constant_value(v)
-
         callbacks.on_batch_end(batch_index, batch_logs)
         if callback_model.stop_training:
           break
@@ -611,7 +594,7 @@ def test_loop(model, inputs, targets,
             targets_batch,
             sample_weights=sample_weights_batch,
             training=False)
-        _, metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
+        metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
         batch_outs = []
         for _, v in zip(model.metrics_names,
                         [backend.mean(loss)] + loss_metrics + metrics_results):
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index ed0f91ee1e2c6b..c45e07e08bcb51 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -212,7 +212,7 @@ def test_evaluate_predict_on_arrays(self):
     optimizer = RMSPropOptimizer(learning_rate=0.001)
     loss = 'mse'
     loss_weights = [1., 0.5]
-    metrics = ['mae']
+    metrics = ['acc', 'mae']
     model.compile(
         optimizer,
         loss,
@@ -231,20 +231,20 @@ def test_evaluate_predict_on_arrays(self):
         [input_a_np, input_b_np], [output_d_np, output_e_np],
         batch_size=5,
         verbose=0)
-    self.assertEqual(len(out), 5)
+    self.assertEqual(len(out), 7)
     out = model.evaluate(
         [input_a_np, input_b_np], [output_d_np, output_e_np],
         batch_size=5,
         verbose=1)
-    self.assertEqual(len(out), 5)
+    self.assertEqual(len(out), 7)
     out = model.evaluate(
         [input_a_np, input_b_np], [output_d_np, output_e_np],
         batch_size=5,
         verbose=2)
-    self.assertEqual(len(out), 5)
+    self.assertEqual(len(out), 7)
     out = model.test_on_batch([input_a_np, input_b_np],
                               [output_d_np, output_e_np])
-    self.assertEqual(len(out), 5)
+    self.assertEqual(len(out), 7)
 
     # Test evaluate with dictionary inputs
     model.evaluate(
@@ -625,7 +625,6 @@ def test_class_weight_invalid_use_case(self):
       bad_w_np = np.random.random((10, 2, 2))
       model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
 
-
 class CorrectnessTest(test.TestCase):
 
   @tf_test_util.run_in_graph_and_eager_modes()
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 08fd26dd18d5bc..47d80704cf6345 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -23,11 +23,14 @@
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.keras._impl.keras.engine.training_utils import weighted_masked_objective
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+
 
 try:
   import scipy.sparse as scipy_sparse  # pylint: disable=g-import-not-at-top
@@ -1667,6 +1670,29 @@ def test_model_custom_target_tensors(self):
     model.train_on_batch([input_a_np, input_b_np],
                          [output_a_np, output_b_np])
 
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_metric_names_are_identical_in_graph_and_eager(self):
+    a = keras.layers.Input(shape=(3,), name='input_a')
+    b = keras.layers.Input(shape=(3,), name='input_b')
+
+    dense = keras.layers.Dense(4, name='dense')
+    c = dense(a)
+    d = dense(b)
+    e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+    model = keras.models.Model([a, b], [d, e])
+
+    optimizer = RMSPropOptimizer(learning_rate=0.001)
+    loss = 'mse'
+    loss_weights = [1., 0.5]
+    metrics = ['mae', 'acc']
+    model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+    reference_metric_names = ['loss', 'dense_loss', 'dropout_loss',
+                              'dense_mean_absolute_error',
+                              'dense_acc',
+                              'dropout_mean_absolute_error',
+                              'dropout_acc']
+    self.assertEqual(reference_metric_names, model.metrics_names)
 
 if __name__ == '__main__':
   # Bazel sets these environment variables to very long paths.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index a3fc8ef2a0359c..34c0738f26fa74 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -26,6 +26,7 @@
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.keras._impl.keras import metrics as metrics_module
 from tensorflow.python.ops import math_ops
 
 
@@ -553,3 +554,64 @@ def standardize_weights(y,
 def has_symbolic_tensors(ls):
   return (any(tensor_util.is_tensor(v) for v in ls) and
           not context.executing_eagerly())
+
+
+def populate_metric_names(model):
+  for i in range(len(model.outputs)):
+    metrics = model.nested_metrics[i]
+    for metric in metrics:
+      base_metric_name = get_base_metric_name(metric)
+      add_metric_name(model, base_metric_name, i)
+
+
+def get_base_metric_name(metric, weighted=False):
+  """Returns the metric name given the metric function.
+
+  Arguments:
+    metric: Metric function name or reference.
+    weighted: Boolean indicating if the metric for which we are adding
+      names is weighted.
+
+  Returns:
+    a metric name.
+  """
+  metric_name_prefix = 'weighted_' if weighted else ''
+  if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+    if metric in ('accuracy', 'acc'):
+      suffix = 'acc'
+    elif metric in ('crossentropy', 'ce'):
+      suffix = 'ce'
+    metric_name = metric_name_prefix + suffix
+  else:
+    metric_fn = metrics_module.get(metric)
+    # Get metric name as string
+    if hasattr(metric_fn, 'name'):
+      metric_name = metric_fn.name
+    else:
+      metric_name = metric_fn.__name__
+    metric_name = metric_name_prefix + metric_name
+
+  return metric_name
+
+
+def add_metric_name(model, metric_name, index):
+  """Makes the metric name unique and adds it to the model's metric name list.
+
+  If there are multiple outputs for which the metrics are calculated, the
+  metric names have to be made unique by appending an integer.
+
+  Arguments:
+    model: Model to which we are adding metric names.
+    metric_name: Metric name that corresponds to the metric specified by the
+      user. For example: 'acc'
+    index: The index of the model output for which the metric name is being
+      added.
+  """
+  if len(model.output_names) > 1:
+    metric_name = '%s_%s' % (model.output_names[index], metric_name)
+  j = 1
+  base_metric_name = metric_name
+  while metric_name in model.metrics_names:
+    metric_name = '%s_%d' % (base_metric_name, j)
+    j += 1
+  model.metrics_names.append(metric_name)