Skip to content

Commit

Permalink
Fixed update of centers for center loss
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsandberg committed Mar 31, 2018
1 parent 2ac16d5 commit efc2f94
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/facenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def center_loss(features, label, alfa, nrof_classes):
centers_batch = tf.gather(centers, label)
diff = (1 - alfa) * (centers_batch - features)
centers = tf.scatter_sub(centers, label, diff)
loss = tf.reduce_mean(tf.square(features - centers_batch))
with tf.control_dependencies([centers]):
loss = tf.reduce_mean(tf.square(features - centers_batch))
return loss, centers

def get_image_paths_and_labels(dataset):
Expand Down

0 comments on commit efc2f94

Please sign in to comment.