Skip to content
This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

Commit

Permalink
center loss delete
Browse files Browse the repository at this point in the history
  • Loading branch information
rockyzhengwu committed Oct 27, 2019
1 parent 82ee67f commit 870467d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions single_word_ocr/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(self, batch_size, num_classes, mode='train', center_loss_alpha=0.95

if self.is_training:
print("feautres =====> ", self.features)
self.center_loss, _ = self.center_loss(self.features, self.labels, self.center_loss_alpha, self.num_classes)
self.softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.labels))
self.loss = self.softmax_loss + self.center_loss
#self.center_loss, _ = self.center_loss(self.features, self.labels, self.center_loss_alpha, self.num_classes)
#self.softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.labels))
self.loss = self.softmax_loss

self.predict_prob = tf.nn.softmax(self.logits, name='prob' )
print("predict_prob====>", self.predict_prob)
Expand Down
4 changes: 2 additions & 2 deletions single_word_ocr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def train():
#optimizer=tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9).minimize(loss=model.loss)
saver = tf.train.Saver()
tf.summary.scalar(name='loss', tensor=model.loss)
tf.summary.scalar(name='softmax_loss', tensor=model.softmax_loss)
tf.summary.scalar(name='center_loss', tensor=model.center_loss)
#tf.summary.scalar(name='softmax_loss', tensor=model.softmax_loss)
#tf.summary.scalar(name='center_loss', tensor=model.center_loss)
tf.summary.scalar(name='accuracy', tensor=model.accuracy)
merge_summary_op = tf.summary.merge_all()
sess_config = tf.ConfigProto(allow_soft_placement=True,)
Expand Down

0 comments on commit 870467d

Please sign in to comment.