Skip to content

Commit

Permalink
Update lanenet_back_end.py
Browse files Browse the repository at this point in the history
add focal loss.
  • Loading branch information
pauperonway authored Jul 9, 2020
1 parent c1246f4 commit b7c431c
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions lanenet_model/lanenet_back_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,20 @@ def _multi_category_focal_loss(cls, onehot_labels, logits, classes_weights, gamm
:param gamma:
:return:
"""
raise NotImplementedError('Func has not been implemented')
epsilon = 1.e-7
alpha = tf.multiply(onehot_labels, classes_weights)
alpha = tf.cast(alpha, tf.float32)
gamma = float(gamma)
y_true = tf.cast(onehot_labels, tf.float32)
y_pred = tf.nn.softmax(logits, dim=-1)
y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
ce = -tf.log(y_t)
weight = tf.pow(tf.subtract(1., y_t), gamma)
fl = tf.multiply(tf.multiply(weight, ce), alpha)
loss = tf.reduce_mean(fl)

return loss

def compute_loss(self, binary_seg_logits, binary_label,
instance_seg_logits, instance_label,
Expand Down Expand Up @@ -122,7 +135,11 @@ def compute_loss(self, binary_seg_logits, binary_label,
classes_weights=inverse_weights
)
elif self._binary_loss_type == 'focal':
raise NotImplementedError
binary_segmenatation_loss = self._multi_category_focal_loss(
onehot_labels=binary_label_onehot,
logits=binary_seg_logits,
classes_weights=inverse_weights
)
else:
raise NotImplementedError

Expand Down

0 comments on commit b7c431c

Please sign in to comment.