diff --git a/efficientdet/det_model_fn.py b/efficientdet/det_model_fn.py index 5b6bf84b6..14c474b03 100644 --- a/efficientdet/det_model_fn.py +++ b/efficientdet/det_model_fn.py @@ -142,24 +142,7 @@ def learning_rate_schedule(params, global_step): raise ValueError('unknown lr_decay_method: {}'.format(lr_decay_method)) -def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0): - """A legacy focal loss that does not support label smoothing.""" - with tf.name_scope('focal_loss'): - positive_label_mask = tf.equal(targets, 1.0) - cross_entropy = ( - tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)) - - neg_logits = -1.0 * logits - modulator = tf.exp(gamma * targets * neg_logits - - gamma * tf.log1p(tf.exp(neg_logits))) - loss = modulator * cross_entropy - weighted_loss = tf.where(positive_label_mask, alpha * loss, - (1.0 - alpha) * loss) - weighted_loss /= normalizer - return weighted_loss - - -def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0): +def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0.0): """Compute the focal loss between `logits` and the golden `target` values. Focal loss = -(1-pt)^gamma * log(pt) @@ -183,17 +166,17 @@ def focal_loss(y_pred, y_true, alpha, gamma, normalizer, label_smoothing=0): alpha = tf.convert_to_tensor(alpha, dtype=y_pred.dtype) gamma = tf.convert_to_tensor(gamma, dtype=y_pred.dtype) - # apply label smoothing. - y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing - - # get cross_entropy for each entry. - ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred) - + # compute focal loss multipliers before label smoothing, such that it will + # not blow up the loss. pred_prob = tf.sigmoid(y_pred) p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob)) alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha) modulating_factor = (1.0 - p_t) ** gamma + # apply label smoothing for cross_entropy for each entry. + y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred) + # compute the final loss and return return alpha_factor * modulating_factor * ce / normalizer diff --git a/efficientdet/det_model_fn_test.py b/efficientdet/det_model_fn_test.py index 924a47f89..160468426 100644 --- a/efficientdet/det_model_fn_test.py +++ b/efficientdet/det_model_fn_test.py @@ -15,38 +15,55 @@ # ============================================================================== """Tests for det_model_fn.""" import tensorflow as tf - import det_model_fn -class FocalLossTest(tf.test.TestCase): +def legacy_focal_loss(logits, targets, alpha, gamma, normalizer, _=0): + """A legacy focal loss that does not support label smoothing.""" + with tf.name_scope('focal_loss'): + positive_label_mask = tf.equal(targets, 1.0) + cross_entropy = ( + tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)) + + neg_logits = -1.0 * logits + modulator = tf.exp(gamma * targets * neg_logits - + gamma * tf.math.log1p(tf.exp(neg_logits))) + loss = modulator * cross_entropy + weighted_loss = tf.where(positive_label_mask, alpha * loss, + (1.0 - alpha) * loss) + weighted_loss /= normalizer + return weighted_loss - def test_focal_loss(self): - y_pred = tf.random.uniform([4, 32, 32, 90]) - y_true = tf.ones([4, 32, 32, 90]) - alpha, gamma, n = 0.25, 1.5, 100 - legacy_output = det_model_fn.legacy_focal_loss(y_pred, y_true, alpha, - gamma, n) - new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n) - self.assertAllClose(legacy_output, new_output) - def test_focal_loss_with_label_smoothing(self): - y_pred = tf.random.uniform([4, 32, 32, 2]) +class FocalLossTest(tf.test.TestCase): + + def test_focal_loss(self): + tf.random.set_seed(1111) + y_pred = tf.random.uniform([4, 32, 32, 90]) + y_true = tf.ones([4, 32, 32, 90]) + alpha, gamma, n = 0.25, 1.5, 100 + legacy_output = legacy_focal_loss(y_pred, y_true, alpha, gamma, n) + new_output = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n) + self.assertAllClose(legacy_output, new_output) - # A binary classification target [0.0, 1.0] becomes [.1, .9] - # with smoothing .2 - y_true = tf.ones([4, 32, 32, 2]) * [0.0, 1.0] - y_true_presmoothed = tf.ones([4, 32, 32, 2]) * [0.1, 0.9] + def test_focal_loss_with_label_smoothing(self): + tf.random.set_seed(1111) + shape = [2, 2, 2, 2] + y_pred = tf.random.uniform(shape) - alpha, gamma, n = 0.25, 1.5, 100 + # A binary classification target [0.0, 1.0] becomes [.1, .9] + # with smoothing .2 + y_true = tf.ones(shape) * [0.0, 1.0] + y_true_presmoothed = tf.ones(shape) * [0.1, 0.9] - presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha, - gamma, n, 0) - unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, - 0.2) + alpha, gamma, n = 1, 0, 100 + presmoothed = det_model_fn.focal_loss(y_pred, y_true_presmoothed, alpha, + gamma, n, 0) + alpha, gamma, n = 0.9, 0, 100 + unsmoothed = det_model_fn.focal_loss(y_pred, y_true, alpha, gamma, n, 0.2) - self.assertAllClose(presmoothed, unsmoothed) + self.assertAllClose(presmoothed, unsmoothed) if __name__ == '__main__': - tf.test.main() + tf.test.main()