Skip to content

Commit 52ee238

Browse files
ramananbalakrishnanfchollet
authored andcommitted
Add top-k classification accuracy metrics (keras-team#3987)
* add categorical accuracy metric which tracks over top-k predictions * remove top_k_categorical_accuracy from being tested together with other all_metrics * fix in_top_k to work with batches. correct metrics.py and test_metrics.py appropriately * style fixes for documentation on in_top_k function * default to k=5 for top_k_categorical_accuracy metric
1 parent 530eff6 commit 52ee238

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

keras/backend/tensorflow_backend.py

+14
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,20 @@ def l2_normalize(x, axis):
14851485
axis = axis % len(x.get_shape())
14861486
return tf.nn.l2_normalize(x, dim=axis)
14871487

1488+
def in_top_k(predictions, targets, k):
1489+
'''Says whether the `targets` are in the top `k` `predictions`
1490+
1491+
# Arguments
1492+
predictions: A tensor of shape batch_size x classess and type float32.
1493+
targets: A tensor of shape batch_size and type int32 or int64.
1494+
k: An int, number of top elements to consider.
1495+
1496+
# Returns
1497+
A tensor of shape batch_size and type bool. output_i is True if
1498+
targets_i is within top-k values of predictions_i
1499+
'''
1500+
return tf.nn.in_top_k(predictions, targets, k)
1501+
14881502

14891503
# CONVOLUTIONS
14901504

keras/backend/theano_backend.py

+17
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,23 @@ def l2_normalize(x, axis):
10431043
return x / norm
10441044

10451045

1046+
def in_top_k(predictions, targets, k):
1047+
'''Says whether the `targets` are in the top `k` `predictions`
1048+
1049+
# Arguments
1050+
predictions: A tensor of shape batch_size x classess and type float32.
1051+
targets: A tensor of shape batch_size and type int32 or int64.
1052+
k: An int, number of top elements to consider.
1053+
1054+
# Returns
1055+
A tensor of shape batch_size and type int. output_i is 1 if
1056+
targets_i is within top-k values of predictions_i
1057+
'''
1058+
predictions_top_k = T.argsort(predictions)[:, -k:]
1059+
result, _ = theano.map(lambda prediction, target: any(equal(prediction, target)), sequences=[predictions_top_k, targets])
1060+
return result
1061+
1062+
10461063
# CONVOLUTIONS
10471064

10481065
def _preprocess_conv2d_input(x, dim_ordering):

keras/metrics.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def sparse_categorical_accuracy(y_true, y_pred):
1717
K.cast(K.argmax(y_pred, axis=-1), K.floatx())))
1818

1919

20+
def top_k_categorical_accuracy(y_true, y_pred, k=5):
21+
return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k))
22+
23+
2024
def mean_squared_error(y_true, y_pred):
2125
return K.mean(K.square(y_pred - y_true))
2226

tests/keras/test_metrics.py

+14
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,19 @@ def test_sparse_metrics():
6565
assert K.eval(metric(y_a, y_b)).shape == ()
6666

6767

68+
def test_top_k_categorical_accuracy():
69+
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
70+
y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
71+
success_result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred,
72+
k=3))
73+
assert success_result == 1
74+
partial_result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred,
75+
k=2))
76+
assert partial_result == 0.5
77+
failure_result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred,
78+
k=1))
79+
assert failure_result == 0
80+
81+
6882
if __name__ == "__main__":
6983
pytest.main([__file__])

0 commit comments

Comments
 (0)