Skip to content

Commit

Permalink
modified the class weights compute method
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 20, 2019
1 parent 3b5554e commit 5055bbf
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions lanenet_model/lanenet_back_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ def _is_net_for_training(self):

return tf.equal(phase, tf.constant('train', dtype=tf.string))

@staticmethod
def _compute_class_weighted_cross_entropy_loss(onehot_labels, logits, classes_weights):
"""
:param onehot_labels:
:param logits:
:param classes_weights:
:return:
"""
loss_weights = tf.reduce_sum(tf.multiply(onehot_labels, classes_weights), axis=3)

loss = tf.losses.softmax_cross_entropy(
onehot_labels=onehot_labels,
logits=logits,
weights=loss_weights
)

return loss

def compute_loss(self, binary_seg_logits, binary_label,
instance_seg_logits, instance_label,
name, reuse):
Expand All @@ -58,32 +77,22 @@ def compute_loss(self, binary_seg_logits, binary_label,
with tf.variable_scope(name_or_scope=name, reuse=reuse):
# calculate class weighted binary seg loss
with tf.variable_scope(name_or_scope='binary_seg'):
binary_label_plain = tf.reshape(
binary_label,
shape=[binary_label.get_shape().as_list()[0] *
binary_label.get_shape().as_list()[1] *
binary_label.get_shape().as_list()[2] *
binary_label.get_shape().as_list()[3]]
)
unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
counts = tf.cast(counts, tf.float32)
inverse_weights = tf.divide(1.0,
tf.log(tf.add(tf.divide(counts, tf.reduce_sum(counts)),
tf.constant(1.02))))
binary_label_onehot = tf.one_hot(
tf.reshape(tf.cast(binary_label, tf.int32),
shape=[binary_label.get_shape().as_list()[0],
binary_label.get_shape().as_list()[1],
binary_label.get_shape().as_list()[2]]),
depth=2, axis=-1)
weights = tf.reduce_sum(binary_label_onehot * inverse_weights, axis=3)

binary_segmenatation_loss = tf.losses.softmax_cross_entropy(
tf.reshape(
tf.cast(binary_label, tf.int32),
shape=[binary_label.get_shape().as_list()[0],
binary_label.get_shape().as_list()[1],
binary_label.get_shape().as_list()[2]]),
depth=2,
axis=-1
)

classes_weights = [1.4506131276238088, 21.525424601474068]
binary_segmenatation_loss = self._compute_class_weighted_cross_entropy_loss(
onehot_labels=binary_label_onehot,
logits=binary_seg_logits,
weights=weights
classes_weights=classes_weights
)
binary_segmenatation_loss = tf.reduce_mean(binary_segmenatation_loss)

# calculate class weighted instance seg loss
with tf.variable_scope(name_or_scope='instance_seg'):
Expand Down

0 comments on commit 5055bbf

Please sign in to comment.