
ENH add coverage multilabel ranking metric
arjoly committed Nov 27, 2014
1 parent 84f1134 commit 81ac899
Showing 6 changed files with 169 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -787,6 +787,7 @@ details.
:toctree: generated/
:template: function.rst

metrics.coverage_error
metrics.label_ranking_average_precision_score


32 changes: 31 additions & 1 deletion doc/modules/model_evaluation.rst
@@ -972,6 +972,36 @@ In multilabel learning, each sample can have any number of ground truth labels
associated with it. The goal is to give high scores and better rank to
the ground truth labels.

Coverage
--------

The :func:`coverage_error` function computes the average number of labels that
have to be included in the final prediction such that all true labels
are predicted. This is useful if you want to know how many top-scored labels
you have to predict on average without missing any true one. The best value
is thus the average number of true labels per sample.

Formally, given a binary indicator matrix of the ground truth labels
:math:`y \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}` and the
score associated with each label
:math:`\hat{f} \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}`,
the coverage is defined as

.. math::
    coverage(y, \hat{f}) = \frac{1}{n_{\text{samples}}}
      \sum_{i=0}^{n_{\text{samples}} - 1} \max_{j:y_{ij} = 1} \text{rank}_{ij}

with :math:`\text{rank}_{ij} = \left|\left\{k: \hat{f}_{ik} \geq \hat{f}_{ij} \right\}\right|`.

Here is a small example of usage of this function::

>>> import numpy as np
>>> from sklearn.metrics import coverage_error
>>> y_true = np.array([[1, 0, 0], [0, 0, 1]])
>>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
>>> coverage_error(y_true, y_score)
2.5
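
In this example, the true label of the first sample has the second highest
score, and the true label of the second sample has the lowest of the three
scores, so the coverage is :math:`(2 + 3) / 2 = 2.5`.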

Label ranking average precision
-------------------------------

@@ -986,7 +1016,7 @@ score. This metric will yield better scores if you are able to give better rank
to the labels associated with each sample. The obtained score is always strictly
greater than 0, and the best value is 1. If there is exactly one relevant
label per sample, label ranking average precision is equivalent to the `mean
reciprocal rank <http://en.wikipedia.org/wiki/Mean_reciprocal_rank>`_.
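
Here is a small example of usage of this function, reusing the arrays from the
coverage example above (the score is shown approximately)::

    >>> import numpy as np
    >>> from sklearn.metrics import label_ranking_average_precision_score
    >>> y_true = np.array([[1, 0, 0], [0, 0, 1]])
    >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
    >>> label_ranking_average_precision_score(y_true, y_score) # doctest: +ELLIPSIS
    0.416...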

Formally, given a binary indicator matrix of the ground truth labels
:math:`y \in \mathcal{R}^{n_\text{samples} \times n_\text{labels}}` and the
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -5,6 +5,7 @@

from .ranking import auc
from .ranking import average_precision_score
from .ranking import coverage_error
from .ranking import label_ranking_average_precision_score
from .ranking import precision_recall_curve
from .ranking import roc_auc_score
@@ -67,6 +68,7 @@
'completeness_score',
'confusion_matrix',
'consensus_score',
'coverage_error',
'euclidean_distances',
'explained_variance_score',
'f1_score',
46 changes: 46 additions & 0 deletions sklearn/metrics/ranking.py
@@ -619,3 +619,49 @@ def label_ranking_average_precision_score(y_true, y_score):
out += np.divide(L, rank, dtype=float).mean()

return out / n_samples


def coverage_error(y_true, y_score, sample_weight=None):
    """Coverage error measure

    Compute how far we need to go through the ranked scores to cover all
    true labels. The best value is equal to the average number of labels
    in y_true per sample.

    Parameters
    ----------
    y_true : array, shape = [n_samples, n_labels]
        True binary labels in binary indicator format.

    y_score : array, shape = [n_samples, n_labels]
        Target scores, can either be probability estimates of the positive
        class, confidence values, or binary decisions.

    sample_weight : array-like of shape = [n_samples], optional
        Sample weights.

    Returns
    -------
    coverage : float
    """
y_true = check_array(y_true, ensure_2d=False)
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score)

y_type = type_of_target(y_true)
if y_type != "multilabel-indicator":
raise ValueError("{0} format is not supported".format(y_type))

if y_true.shape != y_score.shape:
raise ValueError("y_true and y_score have different shape")

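    # Mask out the scores of irrelevant labels; for each sample, the lowest
    # remaining (i.e. relevant) score is the threshold the ranking must reach.
    # Counting how many scores are >= that threshold gives the number of
    # top-ranked labels needed to cover all true labels (0 if none are true).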
y_score_mask = np.ma.masked_array(y_score, mask=np.logical_not(y_true))
y_min_relevant = y_score_mask.min(axis=1).reshape((-1, 1))
coverage = (y_score >= y_min_relevant).sum(axis=1)
coverage = coverage.filled(0)

    return np.average(coverage, weights=sample_weight)
8 changes: 8 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -26,6 +26,7 @@
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import coverage_error
from sklearn.metrics import explained_variance_score
from sklearn.metrics import f1_score
from sklearn.metrics import fbeta_score
@@ -139,6 +140,8 @@
}

THRESHOLDED_METRICS = {
"coverage_error": coverage_error,

"log_loss": log_loss,
"unnormalized_log_loss": partial(log_loss, normalize=False),

@@ -191,6 +194,8 @@

"roc_auc_score", "micro_roc_auc", "weighted_roc_auc",
"macro_roc_auc", "samples_roc_auc",

"coverage_error",
]

# Metrics with an "average" argument
@@ -255,6 +260,8 @@
"average_precision_score", "weighted_average_precision_score",
"samples_average_precision_score", "micro_average_precision_score",
"macro_average_precision_score",

"coverage_error",
]

# Classification metrics with "multilabel-indicator" and
@@ -326,6 +333,7 @@
"hamming_loss",
"matthews_corrcoef_score",
"median_absolute_error",
"coverage_error",
]


82 changes: 81 additions & 1 deletion sklearn/metrics/tests/test_ranking.py
@@ -25,10 +25,11 @@
from sklearn.metrics import auc
from sklearn.metrics import auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import coverage_error
from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

from sklearn.metrics.base import UndefinedMetricWarning

@@ -844,3 +845,82 @@ def test_label_ranking_avp():
yield (check_alternative_lrap_implementation,
label_ranking_average_precision_score,
n_classes, n_samples, random_state)


def test_coverage_error():
# Toy case
assert_almost_equal(coverage_error([[0, 1]], [[0.25, 0.75]]), 1)
assert_almost_equal(coverage_error([[0, 1]], [[0.75, 0.25]]), 2)
assert_almost_equal(coverage_error([[1, 1]], [[0.75, 0.25]]), 2)
assert_almost_equal(coverage_error([[0, 0]], [[0.75, 0.25]]), 0)

assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.75]]), 0)
assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1)
assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 2)
assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 2)
assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 3)
assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 3)
assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 3)
assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.75]]), 3)

assert_almost_equal(coverage_error([[0, 0, 0]], [[0.75, 0.5, 0.25]]), 0)
assert_almost_equal(coverage_error([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
assert_almost_equal(coverage_error([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
assert_almost_equal(coverage_error([[0, 1, 1]], [[0.75, 0.5, 0.25]]), 3)
assert_almost_equal(coverage_error([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1)
assert_almost_equal(coverage_error([[1, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
assert_almost_equal(coverage_error([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
assert_almost_equal(coverage_error([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 3)

assert_almost_equal(coverage_error([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0)
assert_almost_equal(coverage_error([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
assert_almost_equal(coverage_error([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
assert_almost_equal(coverage_error([[0, 1, 1]], [[0.5, 0.75, 0.25]]), 3)
assert_almost_equal(coverage_error([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 2)
assert_almost_equal(coverage_error([[1, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
assert_almost_equal(coverage_error([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 2)
assert_almost_equal(coverage_error([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 3)

# Tie handling
assert_almost_equal(coverage_error([[0, 0]], [[0.5, 0.5]]), 0)
assert_almost_equal(coverage_error([[1, 0]], [[0.5, 0.5]]), 2)
assert_almost_equal(coverage_error([[0, 1]], [[0.5, 0.5]]), 2)
assert_almost_equal(coverage_error([[1, 1]], [[0.5, 0.5]]), 2)

assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0)
assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 2)
assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 2)
assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 2)
assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 3)
assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 3)
assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 3)
assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 3)

    # Non-trivial case
assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0]],
[[0.1, 10., -3], [0, 1, 3]]),
(1 + 3) / 2.)

assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0], [0, 1, 1]],
[[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]),
(1 + 3 + 3) / 3.)

assert_almost_equal(coverage_error([[0, 1, 0], [1, 1, 0], [0, 1, 1]],
[[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]),
(1 + 3 + 3) / 3.)

    # Raise a ValueError if the input format is not appropriate
assert_raises(ValueError, coverage_error, [0, 1, 0], [0.25, 0.3, 0.2])
assert_raises(ValueError, coverage_error, [0, 1, 2],
[[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])
assert_raises(ValueError, coverage_error, [(0), (1), (2)],
[[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])

    # Check that y_true.shape != y_score.shape raises the proper exception
    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [0, 1])
    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0, 1]])
    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0], [1]])

    assert_raises(ValueError, coverage_error, [[0, 1]], [[0, 1], [0, 1]])
    assert_raises(ValueError, coverage_error, [[0], [1]], [[0, 1], [0, 1]])
    assert_raises(ValueError, coverage_error, [[0, 1], [0, 1]], [[0], [1]])
