Skip to content

Commit

Permalink
Allow metrics 'crossentropy', 'ce' (keras-team#8864)
Browse files Browse the repository at this point in the history
* Allow metrics 'crossentropy', 'ce'

* append correct suffix
  • Loading branch information
ozabluda authored and fchollet committed Jan 12, 2018
1 parent 616a9b0 commit 76c5b61
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,22 +871,35 @@ def handle_metrics(metrics, weights=None):
metric_name_prefix = 'weighted_' if weights is not None else ''

for metric in metrics:
if metric == 'accuracy' or metric == 'acc':
# custom handling of accuracy
if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
# custom handling of accuracy/crossentropy
# (because of class mode duality)
output_shape = self._internal_output_shapes[i]
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy
acc_fn = metrics_module.binary_accuracy
# case: binary accuracy/crossentropy
if metric in ('accuracy', 'acc'):
acc_fn = metrics_module.binary_accuracy
elif metric in ('crossentropy', 'ce'):
acc_fn = metrics_module.binary_crossentropy
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy with sparse targets
acc_fn = metrics_module.sparse_categorical_accuracy
# case: categorical accuracy/crossentropy with sparse targets
if metric in ('accuracy', 'acc'):
acc_fn = metrics_module.sparse_categorical_accuracy
elif metric in ('crossentropy', 'ce'):
acc_fn = metrics_module.sparse_categorical_crossentropy
else:
acc_fn = metrics_module.categorical_accuracy

# case: categorical accuracy/crossentropy
if metric in ('accuracy', 'acc'):
acc_fn = metrics_module.categorical_accuracy
elif metric in ('crossentropy', 'ce'):
acc_fn = metrics_module.categorical_crossentropy
if metric in ('accuracy', 'acc'):
suffix = 'acc'
elif metric in ('crossentropy', 'ce'):
suffix = 'ce'
weighted_metric_fn = _weighted_masked_objective(acc_fn)
metric_name = metric_name_prefix + 'acc'
metric_name = metric_name_prefix + suffix
else:
metric_fn = metrics_module.get(metric)
weighted_metric_fn = _weighted_masked_objective(metric_fn)
Expand Down

0 comments on commit 76c5b61

Please sign in to comment.