Skip to content

Commit

Permalink
Compute focal loss multipliers before label smoothing.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxingtan committed Jun 22, 2020
1 parent d1f1f07 commit d32ccd0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 47 deletions.
31 changes: 7 additions & 24 deletions efficientdet/det_model_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,7 @@ def learning_rate_schedule(params, global_step):
raise ValueError('unknown lr_decay_method: {}'.format(lr_decay_method))


def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0):
  """A legacy focal loss that does not support label smoothing.

  Args:
    logits: tensor of raw (pre-sigmoid) predictions.
    targets: tensor of labels; entries exactly equal to 1.0 are treated as
      positives, everything else as negatives.
    alpha: weight applied to the loss of positive entries; negative entries
      are weighted by `1.0 - alpha`.
    gamma: exponent of the focal modulating factor.
    normalizer: value the weighted loss is divided by.
    _: ignored. Presumably kept so the positional signature matches
      `focal_loss`, whose sixth argument is `label_smoothing`.

  Returns:
    The element-wise alpha-weighted focal loss divided by `normalizer`.
  """
  with tf.name_scope('focal_loss'):
    positive_label_mask = tf.equal(targets, 1.0)
    cross_entropy = (
        tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))

    neg_logits = -1.0 * logits
    # Modulating factor computed in log space:
    #   exp(-gamma * targets * logits - gamma * log(1 + exp(-logits))).
    # `tf.math.log1p` is used instead of the TF1-only alias `tf.log1p`, which
    # is not available in the TensorFlow 2 top-level namespace.
    modulator = tf.exp(gamma * targets * neg_logits -
                       gamma * tf.math.log1p(tf.exp(neg_logits)))
    loss = modulator * cross_entropy
    # Positives are scaled by alpha, negatives by (1 - alpha).
    weighted_loss = tf.where(positive_label_mask, alpha * loss,
                             (1.0 - alpha) * loss)
    weighted_loss /= normalizer
  return weighted_loss


def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0):
def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0.0):
"""Compute the focal loss between `logits` and the golden `target` values.
Focal loss = -(1-pt)^gamma * log(pt)
Expand All @@ -183,17 +166,17 @@ def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0):
alpha = tf.convert_to_tensor(alpha, dtype=y_pred.dtype)
gamma = tf.convert_to_tensor(gamma, dtype=y_pred.dtype)

# apply label smoothing.
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

# get cross_entropy for each entry.
ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)

# compute focal loss multipliers before label smoothing, such that it will
# not blow up the loss.
pred_prob = tf.sigmoid(y_pred)
p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
modulating_factor = (1.0 - p_t) ** gamma

# apply label smoothing for cross_entropy for each entry.
y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)

# compute the final loss and return
return alpha_factor * modulating_factor * ce / normalizer

Expand Down
63 changes: 40 additions & 23 deletions efficientdet/det_model_fn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,55 @@
# ==============================================================================
"""Tests for det_model_fn."""
import tensorflow as tf

import det_model_fn


class FocalLossTest(tf.test.TestCase):
def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0):
  """Reference focal loss without label smoothing.

  Entries of `targets` equal to 1.0 are weighted by `alpha`, all other entries
  by `1.0 - alpha`; the result is divided by `normalizer`. The last positional
  argument is accepted and ignored.
  """
  with tf.name_scope('focal_loss'):
    is_positive = tf.equal(targets, 1.0)
    ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)

    negated = -1.0 * logits
    # The two terms of the log-space exponent of the modulating factor.
    target_term = gamma * targets * negated
    softplus_term = gamma * tf.math.log1p(tf.exp(negated))
    focal_weight = tf.exp(target_term - softplus_term)

    unweighted = focal_weight * ce
    positive_branch = alpha * unweighted
    negative_branch = (1.0 - alpha) * unweighted
    scaled = tf.where(is_positive, positive_branch, negative_branch)
    scaled = scaled / normalizer
  return scaled

def test_focal_loss(self):
y_pred = tf.random.uniform([4, 32, 32, 90])
y_true = tf.ones([4, 32, 32, 90])
alpha, gamma, n = 0.25, 1.5, 100
legacy_output = det_model_fn.legacy_focal_loss(y_pred, y_true, alpha,
gamma, n)
new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n)
self.assertAllClose(legacy_output, new_output)

def test_focal_loss_with_label_smoothing(self):
y_pred = tf.random.uniform([4, 32, 32, 2])
class FocalLossTest(tf.test.TestCase):

def test_focal_loss(self):
tf.random.set_seed(1111)
y_pred = tf.random.uniform([4, 32, 32, 90])
y_true = tf.ones([4, 32, 32, 90])
alpha, gamma, n = 0.25, 1.5, 100
legacy_output = legacy_focal_loss(y_pred, y_true, alpha, gamma, n)
new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n)
self.assertAllClose(legacy_output, new_output)

# A binary classification target [0.0, 1.0] becomes [.1, .9]
# with smoothing .2
y_true = tf.ones([4, 32, 32, 2]) * [0.0, 1.0]
y_true_presmoothed = tf.ones([4, 32, 32, 2]) * [0.1, 0.9]
def test_focal_loss_with_label_smoothing(self):
tf.random.set_seed(1111)
shape = [2, 2, 2, 2]
y_pred = tf.random.uniform(shape)

alpha, gamma, n = 0.25, 1.5, 100
# A binary classification target [0.0, 1.0] becomes [.1, .9]
# with smoothing .2
y_true = tf.ones(shape) * [0.0, 1.0]
y_true_presmoothed = tf.ones(shape) * [0.1, 0.9]

presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha,
gamma, n, 0)
unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n,
0.2)
alpha, gamma, n = 1, 0, 100
presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha,
gamma, n, 0)
alpha, gamma, n = 0.9, 0, 100
unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0.2)

self.assertAllClose(presmoothed, unsmoothed)
self.assertAllClose(presmoothed, unsmoothed)


if __name__ == '__main__':
tf.test.main()
tf.test.main()

0 comments on commit d32ccd0

Please sign in to comment.