[DGL-LifeSci] Handle Same Labels for ROC AUC (dmlc#1581)
* Handle corner case for ROC AUC

* Update

* Update doct
mufeili authored Jun 2, 2020
1 parent a304df5 commit b8dffcd
Showing 2 changed files with 40 additions and 15 deletions.
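For context on the corner case being handled: scikit-learn's roc_auc_score cannot score a task whose labels all belong to one class. A minimal illustration of the failure this commit guards against (the arrays below are illustrative, not taken from the repository):

from sklearn.metrics import roc_auc_score

y_true = [1, 1, 1]        # every label for this task is the same class
y_pred = [0.2, 0.7, 0.9]  # arbitrary predicted probabilities

try:
    roc_auc_score(y_true, y_pred)
except ValueError as err:
    # sklearn refuses to compute the score, e.g.
    # "Only one class present in y_true. ROC AUC score is not defined in that case."
    print(err)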
14 changes: 12 additions & 2 deletions apps/life_sci/python/dgllife/utils/eval.py
@@ -173,7 +173,9 @@ def multilabel_score(self, reduction='none'):
            task_w = mask[:, task]
            task_y_true = y_true[:, task][task_w != 0]
            task_y_pred = y_pred[:, task][task_w != 0]
-            scores.append(score_func(task_y_true, task_y_pred))
+            task_score = score_func(task_y_true, task_y_pred)
+            if task_score is not None:
+                scores.append(task_score)
        return self._reduce_scores(scores, reduction)

    def pearson_r2(self, reduction='none'):
@@ -236,6 +238,9 @@ def score(y_true, y_pred):
    def roc_auc_score(self, reduction='none'):
        """Compute roc-auc score for binary classification.

+        ROC-AUC scores are not well-defined in cases where labels for a task have one single
+        class only. In this case we will simply ignore this task and print a warning message.
+
        Parameters
        ----------
        reduction : 'none' or 'mean' or 'sum'
@@ -253,7 +258,12 @@ def roc_auc_score(self, reduction='none'):
        assert (self.mean is None) and (self.std is None), \
            'Label normalization should not be performed for binary classification.'
        def score(y_true, y_pred):
-            return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
+            if len(y_true.unique()) == 1:
+                print('Warning: Only one class {} present in y_true for a task. '
+                      'ROC AUC score is not defined in that case.'.format(y_true[0]))
+                return None
+            else:
+                return roc_auc_score(y_true.long().numpy(), torch.sigmoid(y_pred).numpy())
        return self.multilabel_score(score, reduction)

    def compute_metric(self, metric_name, reduction='none'):
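Taken together, the per-task logic after this change amounts to roughly the following standalone sketch (not the exact dgllife source; the helper name per_task_roc_auc and the tensor shapes are assumptions for illustration):

import torch
from sklearn.metrics import roc_auc_score

def per_task_roc_auc(y_pred, y_true, mask):
    # y_pred, y_true, mask: float tensors of shape (num_samples, num_tasks)
    scores = []
    for task in range(y_true.shape[1]):
        task_w = mask[:, task]
        task_y_true = y_true[:, task][task_w != 0]
        task_y_pred = y_pred[:, task][task_w != 0]
        if len(task_y_true.unique()) == 1:
            # Undefined case: warn and skip the task instead of crashing.
            print('Warning: Only one class {} present in y_true for a task. '
                  'ROC AUC score is not defined in that case.'.format(task_y_true[0]))
            continue
        scores.append(roc_auc_score(task_y_true.long().numpy(),
                                    torch.sigmoid(task_y_pred).numpy()))
    return scores

Because undefined tasks are simply dropped, reductions such as 'mean' and 'sum' are taken over the remaining tasks only, which is what the updated test below exercises.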
41 changes: 28 additions & 13 deletions apps/life_sci/tests/utils/test_eval.py
@@ -17,8 +17,8 @@ def test_Meter():

    # pearson r2
    meter = Meter(label_mean, label_std)
-    meter.update(label, pred)
-    true_scores = [0.7499999999999999, 0.7499999999999999]
+    meter.update(pred, label)
+    true_scores = [0.7500000774286983, 0.7500000516191412]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
    assert meter.pearson_r2('sum') == np.sum(true_scores)
@@ -27,7 +27,7 @@ def test_Meter():
    assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

    meter = Meter(label_mean, label_std)
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
@@ -38,7 +38,7 @@ def test_Meter():

    # mae
    meter = Meter()
-    meter.update(label, pred)
+    meter.update(pred, label)
    true_scores = [0.1666666716337204, 0.1666666716337204]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
@@ -48,7 +48,7 @@ def test_Meter():
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

    meter = Meter()
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
    true_scores = [0.25, 0.0]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
@@ -57,23 +57,23 @@ def test_Meter():
    assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

-    # rmse
+    # rmsef
    meter = Meter(label_mean, label_std)
-    meter.update(label, pred)
-    true_scores = [0.22125875529784111, 0.5937311018897714]
+    meter.update(pred, label)
+    true_scores = [0.41068359261794546, 0.4106836107598449]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

    meter = Meter(label_mean, label_std)
-    meter.update(label, pred, mask)
-    true_scores = [0.1337071188699867, 0.5019903799993205]
+    meter.update(pred, label, mask)
+    true_scores = [0.44433766459035057, 0.5019903799993205]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

    # roc auc score
    meter = Meter()
-    meter.update(label, pred)
-    true_scores = [1.0, 0.75]
+    meter.update(pred, label)
+    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)
@@ -82,7 +82,7 @@ def test_Meter():
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

    meter = Meter()
-    meter.update(label, pred, mask)
+    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
@@ -91,5 +91,20 @@ def test_Meter():
    assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

+def test_cases_with_undefined_scores():
+    label = torch.tensor([[0., 1.],
+                          [0., 1.],
+                          [1., 1.]])
+    pred = torch.tensor([[0.5, 0.5],
+                         [0., 1.],
+                         [1., 0.]])
+    meter = Meter()
+    meter.update(pred, label)
+    true_scores = [1.0]
+    assert meter.roc_auc_score() == true_scores
+    assert meter.roc_auc_score('mean') == np.mean(true_scores)
+    assert meter.roc_auc_score('sum') == np.sum(true_scores)
+
if __name__ == '__main__':
    test_Meter()
+    test_cases_with_undefined_scores()
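For reference, the new corner case looks like this from user code (assuming the usual dgllife.utils import path for Meter; the tensors mirror the new test above):

import torch
from dgllife.utils import Meter

label = torch.tensor([[0., 1.],
                      [0., 1.],
                      [1., 1.]])   # second task has only one class
pred = torch.tensor([[0.5, 0.5],
                     [0., 1.],
                     [1., 0.]])

meter = Meter()
meter.update(pred, label)
# The single-class task is skipped with a printed warning, so only the
# first task contributes a score.
print(meter.roc_auc_score())        # [1.0]
print(meter.roc_auc_score('mean'))  # 1.0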
