Skip to content

Commit

Permalink
fix code style problem
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 28, 2019
1 parent b9a6c2f commit 4ce104e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tools/train_lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,15 @@ def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):

# Set tf model save path
model_save_dir = 'model/tusimple_lanenet_{:s}'.format(net_flag)
# if not ops.exists(model_save_dir):
# os.makedirs(model_save_dir)
os.makedirs(model_save_dir, exist_ok=True)
train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
model_name = 'tusimple_lanenet_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
model_save_path = ops.join(model_save_dir, model_name)
saver = tf.train.Saver()

# Set tf summary save path
tboard_save_path = 'tboard/tusimple_lanenet_{:s}'.format(net_flag)
if not ops.exists(tboard_save_path):
os.makedirs(tboard_save_path)
os.makedirs(tboard_save_path, exist_ok=True)

# Set sess configuration
sess_config = tf.ConfigProto(allow_soft_placement=True)
Expand Down Expand Up @@ -425,9 +423,10 @@ def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):
# training part
t_start = time.time()

_, train_c, train_accuracy_figure, train_fn_figure, train_fp_figure, lr, train_summary, train_binary_loss, \
train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
train_binary_gt_labels, train_instance_gt_labels = \
_, train_c, train_accuracy_figure, train_fn_figure, train_fp_figure, \
lr, train_summary, train_binary_loss, \
train_instance_loss, train_embeddings, train_binary_seg_imgs, train_gt_imgs, \
train_binary_gt_labels, train_instance_gt_labels = \
sess.run([optimizer, train_total_loss, train_accuracy, train_fn, train_fp,
learning_rate, train_merge_summary_op, train_binary_seg_loss,
train_disc_loss, train_pix_embedding, train_prediction,
Expand Down Expand Up @@ -456,9 +455,10 @@ def train_lanenet(dataset_dir, weights_path=None, net_flag='vgg'):
train_cost_time_mean.clear()

# validation part
val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, val_summary, val_binary_loss, \
val_instance_loss, val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
val_binary_gt_labels, val_instance_gt_labels = \
val_c, val_accuracy_figure, val_fn_figure, val_fp_figure, \
val_summary, val_binary_loss, val_instance_loss, \
val_embeddings, val_binary_seg_imgs, val_gt_imgs, \
val_binary_gt_labels, val_instance_gt_labels = \
sess.run([val_total_loss, val_accuracy, val_fn, val_fp,
val_merge_summary_op, val_binary_seg_loss,
val_disc_loss, val_pix_embedding, val_prediction,
Expand Down

0 comments on commit 4ce104e

Please sign in to comment.