Skip to content

Commit

Permalink
change accuracy compute function
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 31, 2018
1 parent 62f8bab commit afed844
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions tools/train_lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,11 @@ def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
out_logits_out = tf.argmax(out_logits, axis=-1)
out = tf.argmax(out_logits, axis=-1)
out = tf.expand_dims(out, axis=-1)
accuracy = tf.add(binary_label_tensor, -1 * out)
accuracy = tf.count_nonzero(accuracy, axis=[1, 2, 3])
accuracy = tf.add(tf.constant(1, dtype=tf.float64),
-1 * tf.divide(accuracy,
CFG.TRAIN.IMG_HEIGHT * CFG.TRAIN.IMG_WIDTH))
accuracy = tf.reduce_mean(accuracy, axis=0)

idx = tf.where(tf.equal(binary_label_tensor, 1))
pix_cls_ret = tf.gather_nd(out, idx)
accuracy = tf.count_nonzero(pix_cls_ret)
accuracy = tf.divide(accuracy, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
Expand Down

0 comments on commit afed844

Please sign in to comment.