Skip to content

Commit

Permalink
add a test for focal loss with label smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ely-S committed Jun 20, 2020
1 parent 7ab026b commit e1fb1ed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
6 changes: 3 additions & 3 deletions efficientdet/det_model_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def learning_rate_schedule(params, global_step):
raise ValueError('unknown lr_decay_method: {}'.format(lr_decay_method))


def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _):
"""A legacy focal loss that does not support label smooth."""
def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0):
"""A legacy focal loss that does not support label smoothing."""
with tf.name_scope('focal_loss'):
positive_label_mask = tf.equal(targets, 1.0)
cross_entropy = (
Expand All @@ -159,7 +159,7 @@ def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _):
return weighted_loss


def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing):
def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0):
"""Compute the focal loss between `logits` and the golden `target` values.
Focal loss = -(1-pt)^gamma * log(pt)
Expand Down
22 changes: 18 additions & 4 deletions efficientdet/det_model_fn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,32 @@
import det_model_fn


class DetModelFnTest(tf.test.TestCase):
class FocalLossTest(tf.test.TestCase):

def test_focal_loss(self):
tf.random.set_seed(111)
y_pred = tf.random.uniform([4, 32, 32, 90])
y_true = tf.ones([4, 32, 32, 90])
alpha, gamma, n = 0.25, 1.5, 100
legacy_output = det_model_fn.legacy_focal_loss(y_pred, y_true, alpha, gamma,
n, 0)
new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0)
n)
new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n)
self.assertAllClose(legacy_output, new_output)


def test_focal_loss_with_label_smoothing(self):
y_pred = tf.random.uniform([4, 32, 32, 2])

# A binary classification target [0.0, 1.0] becomes [.1, .9]
# with smoothing .2
y_true = tf.ones([4, 32, 32, 2]) * [0.0, 1.0]
y_true_presmoothed = tf.ones([4, 32, 32, 2]) * [0.1, 0.9]

alpha, gamma, n = 0.25, 1.5, 100

presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha, gamma, n, 0)
unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0.2)

self.assertAllClose(presmoothed, unsmoothed)

if __name__ == '__main__':
tf.test.main()

0 comments on commit e1fb1ed

Please sign in to comment.