From 81ac899714841c335291aa88f474a4d8164b1e94 Mon Sep 17 00:00:00 2001
From: Arnaud Joly
Date: Fri, 21 Nov 2014 13:52:21 +0100
Subject: [PATCH] ENH add coverage multilabel ranking metric

---
 doc/modules/classes.rst               |  1 +
 doc/modules/model_evaluation.rst      | 32 ++++++++++-
 sklearn/metrics/__init__.py           |  2 +
 sklearn/metrics/ranking.py            | 46 +++++++++++++++
 sklearn/metrics/tests/test_common.py  |  8 +++
 sklearn/metrics/tests/test_ranking.py | 82 ++++++++++++++++++++++++++-
 6 files changed, 169 insertions(+), 2 deletions(-)

diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 215cae643550a..fc225f9ed22fa 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -787,6 +787,7 @@ details.
    :toctree: generated/
    :template: function.rst
 
+   metrics.coverage_error
    metrics.label_ranking_average_precision_score
 
 
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 9f6a7431cf253..d42bc6ab0a3e3 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -972,6 +972,36 @@ In multilabel learning, each sample can have any number of ground truth labels
 associated with it. The goal is to give high scores and better rank to the
 ground truth labels.
 
+Coverage
+--------
+
+The :func:`coverage_error` function computes the average number of labels that
+have to be included in the final prediction such that all true labels are
+predicted. This is useful if you want to know how many of the top-scored labels
+you have to predict on average without missing any true one. The best and
+minimal value of this metric is thus the average number of true labels.
+
+Formally, given a binary indicator matrix of the ground truth labels
+:math:`y \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}` and the
+score associated with each label
+:math:`\hat{f} \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}`,
+the coverage is defined as
+
+.. math::
+  coverage(y, \hat{f}) = \frac{1}{n_{\text{samples}}}
+    \sum_{i=0}^{n_{\text{samples}} - 1} \max_{j:y_{ij} = 1} \text{rank}_{ij}
+
+with :math:`\text{rank}_{ij} = \left|\left\{k: \hat{f}_{ik} \geq \hat{f}_{ij} \right\}\right|`.
+
+Here is a small example of usage of this function::
+
+    >>> import numpy as np
+    >>> from sklearn.metrics import coverage_error
+    >>> y_true = np.array([[1, 0, 0], [0, 0, 1]])
+    >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
+    >>> coverage_error(y_true, y_score)
+    2.5
+
 Label ranking average precision
 -------------------------------
 
@@ -986,7 +1016,7 @@ score. This metric will yield better scores if you are able to give better rank
 to the labels associated with each sample. The obtained score is always strictly
 greater than 0, and the best value is 1. If there is exactly one relevant
 label per sample, label ranking average precision is equivalent to the `mean
-reciprocal rank <http://en.wikipedia.org/wiki/Mean_reciprocal_rank>`.
+reciprocal rank <http://en.wikipedia.org/wiki/Mean_reciprocal_rank>`_.
 
 Formally, given a binary indicator matrix of the ground truth labels
 :math:`y \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}` and the
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index 16ebd318d6a37..e419fec7a1c91 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -5,6 +5,7 @@
 
 from .ranking import auc
 from .ranking import average_precision_score
+from .ranking import coverage_error
 from .ranking import label_ranking_average_precision_score
 from .ranking import precision_recall_curve
 from .ranking import roc_auc_score
@@ -67,6 +68,7 @@
     'completeness_score',
     'confusion_matrix',
     'consensus_score',
+    'coverage_error',
     'euclidean_distances',
     'explained_variance_score',
     'f1_score',
diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py
index 47efce7520137..c2affafd09550 100644
--- a/sklearn/metrics/ranking.py
+++ b/sklearn/metrics/ranking.py
@@ -619,3 +619,49 @@ def label_ranking_average_precision_score(y_true, y_score):
         out += np.divide(L, rank, dtype=float).mean()
 
     return out / n_samples
+
+
+def coverage_error(y_true, y_score, sample_weight=None):
+    """Coverage error measure
+
+    Compute how far we need to go through the ranked scores to cover all
+    true labels. The best value is equal to the average number of labels
+    in ``y_true`` per sample.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples, n_labels]
+        True binary labels in binary indicator format.
+
+    y_score : array, shape = [n_samples, n_labels]
+        Target scores, can either be probability estimates of the positive
+        class, confidence values, or binary decisions.
+
+    sample_weight : array-like of shape = [n_samples], optional
+        Sample weights.
+
+    Returns
+    -------
+    coverage : float
+
+    """
+    y_true = check_array(y_true, ensure_2d=False)
+    y_score = check_array(y_score, ensure_2d=False)
+    check_consistent_length(y_true, y_score)
+
+    y_type = type_of_target(y_true)
+    if y_type != "multilabel-indicator":
+        raise ValueError("{0} format is not supported".format(y_type))
+
+    if y_true.shape != y_score.shape:
+        raise ValueError("y_true and y_score have different shape")
+
+    # Mask out the scores of irrelevant labels: for each sample, the lowest
+    # score among the relevant labels determines how many of the top-scored
+    # labels are needed to cover all true labels.
+    y_score_mask = np.ma.masked_array(y_score, mask=np.logical_not(y_true))
+    y_min_relevant = y_score_mask.min(axis=1).reshape((-1, 1))
+    coverage = (y_score >= y_min_relevant).sum(axis=1)
+    coverage = coverage.filled(0)
+
+    return np.average(coverage, weights=sample_weight)
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 9f233de490c1c..09e3bf999adc9 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -26,6 +26,7 @@
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import average_precision_score
 from sklearn.metrics import confusion_matrix
+from sklearn.metrics import coverage_error
 from sklearn.metrics import explained_variance_score
 from sklearn.metrics import f1_score
 from sklearn.metrics import fbeta_score
@@ -139,6 +140,8 @@
 }
 
 THRESHOLDED_METRICS = {
+    "coverage_error": coverage_error,
+
     "log_loss": log_loss,
     "unnormalized_log_loss": partial(log_loss, normalize=False),
 
@@ -191,6 +194,8 @@
 
     "roc_auc_score", "micro_roc_auc", "weighted_roc_auc",
     "macro_roc_auc", "samples_roc_auc",
+
+    "coverage_error",
 ]
 
 # Metrics with an "average" argument
@@ -255,6 +260,8 @@
     "average_precision_score", "weighted_average_precision_score",
     "samples_average_precision_score", "micro_average_precision_score",
     "macro_average_precision_score",
+
+    "coverage_error",
 ]
 
 # Classification metrics with "multilabel-indicator" and
"multilabel-indicator" and @@ -326,6 +333,7 @@ "hamming_loss", "matthews_corrcoef_score", "median_absolute_error", + "coverage_error", ] diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 0cd6519ac8880..47ba620902ba6 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -25,10 +25,11 @@ from sklearn.metrics import auc from sklearn.metrics import auc_score from sklearn.metrics import average_precision_score +from sklearn.metrics import coverage_error from sklearn.metrics import label_ranking_average_precision_score -from sklearn.metrics import roc_curve from sklearn.metrics import precision_recall_curve from sklearn.metrics import roc_auc_score +from sklearn.metrics import roc_curve from sklearn.metrics.base import UndefinedMetricWarning @@ -844,3 +845,82 @@ def test_label_ranking_avp(): yield (check_alternative_lrap_implementation, label_ranking_average_precision_score, n_classes, n_samples, random_state) + + +def test_coverage_error(): + # Toy case + assert_almost_equal(coverage_error([[0, 1]], [[0.25, 0.75]]), 1) + assert_almost_equal(coverage_error([[0, 1]], [[0.75, 0.25]]), 2) + assert_almost_equal(coverage_error([[1, 1]], [[0.75, 0.25]]), 2) + assert_almost_equal(coverage_error([[0, 0]], [[0.75, 0.25]]), 0) + + assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.75]]), 0) + assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1) + assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 2) + assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 2) + assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 3) + assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 3) + assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 3) + assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.75]]), 3) + + assert_almost_equal(coverage_error([[0, 0, 0]], [[0.75, 0.5, 0.25]]), 0) + assert_almost_equal(coverage_error([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 3) + assert_almost_equal(coverage_error([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 2) + assert_almost_equal(coverage_error([[0, 1, 1]], [[0.75, 0.5, 0.25]]), 3) + assert_almost_equal(coverage_error([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1) + assert_almost_equal(coverage_error([[1, 0, 1]], [[0.75, 0.5, 0.25]]), 3) + assert_almost_equal(coverage_error([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 2) + assert_almost_equal(coverage_error([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 3) + + assert_almost_equal(coverage_error([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0) + assert_almost_equal(coverage_error([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 3) + assert_almost_equal(coverage_error([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1) + assert_almost_equal(coverage_error([[0, 1, 1]], [[0.5, 0.75, 0.25]]), 3) + assert_almost_equal(coverage_error([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 2) + assert_almost_equal(coverage_error([[1, 0, 1]], [[0.5, 0.75, 0.25]]), 3) + assert_almost_equal(coverage_error([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 2) + assert_almost_equal(coverage_error([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 3) + + # Tie handling + assert_almost_equal(coverage_error([[0, 0]], [[0.5, 0.5]]), 0) + assert_almost_equal(coverage_error([[1, 0]], [[0.5, 0.5]]), 2) + assert_almost_equal(coverage_error([[0, 1]], [[0.5, 0.5]]), 2) + assert_almost_equal(coverage_error([[1, 1]], [[0.5, 0.5]]), 2) + + assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0) + assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 
+    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 2)
+    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 2)
+    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 3)
+    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 3)
+    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 3)
+    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 3)
+
+    # Non-trivial case
+    assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0]],
+                                       [[0.1, 10., -3], [0, 1, 3]]),
+                        (1 + 3) / 2.)
+
+    assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0], [0, 1, 1]],
+                                       [[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]),
+                        (1 + 3 + 3) / 3.)
+
+    assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0], [0, 1, 1]],
+                                       [[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]),
+                        (1 + 3 + 3) / 3.)
+
+    # Raise ValueError if the input is not in multilabel indicator format
+    assert_raises(ValueError, coverage_error, [0, 1, 0], [0.25, 0.3, 0.2])
+    assert_raises(ValueError, coverage_error, [0, 1, 2],
+                  [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])
+    assert_raises(ValueError, coverage_error, [(0), (1), (2)],
+                  [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])
+
+    # Check that y_true.shape != y_score.shape raises the proper exception
+    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [0, 1])
+    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0, 1]])
+    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0], [1]])
+
+    assert_raises(ValueError, coverage_error, [[0, 1]], [[0, 1], [0, 1]])
+    assert_raises(ValueError, coverage_error, [[0], [1]], [[0, 1], [0, 1]])
+    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0], [1]])
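
Reviewer-style aside, not part of the patch: the coverage formula in the documentation hunk can be cross-checked against the vectorised masked-array implementation with a naive per-sample computation. The sketch below assumes only NumPy and the ``coverage_error`` function added by this patch; the helper name ``naive_coverage`` is made up for illustration::

    import numpy as np
    from sklearn.metrics import coverage_error

    def naive_coverage(y_true, y_score):
        # Direct translation of the documented formula:
        # coverage = mean over samples of max_{j: y_ij = 1} rank_ij,
        # where rank_ij = |{k: f_ik >= f_ij}|.
        y_true = np.asarray(y_true)
        y_score = np.asarray(y_score)
        per_sample = []
        for truth, scores in zip(y_true, y_score):
            relevant = np.flatnonzero(truth)
            if relevant.size == 0:
                # No true labels: nothing has to be covered
                # (same convention as coverage.filled(0) in the patch).
                per_sample.append(0)
                continue
            ranks = [(scores >= scores[j]).sum() for j in relevant]
            per_sample.append(max(ranks))
        return np.mean(per_sample)

    y_true = np.array([[1, 0, 0], [0, 0, 1]])
    y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
    print(naive_coverage(y_true, y_score))   # 2.5
    print(coverage_error(y_true, y_score))   # 2.5

Both versions handle ties pessimistically: every label scoring at least as high as the lowest relevant score is counted, which is why ``coverage_error([[1, 0]], [[0.5, 0.5]])`` equals 2 in the tests above.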