
Use TensorFlow 1.10's new batch normalization API
MaybeShewill-CV committed Oct 24, 2018
1 parent 97eed8f commit 6c01ede
Showing 2 changed files with 216 additions and 26 deletions.
63 changes: 37 additions & 26 deletions encoder_decoder_model/cnn_basenet.py
@@ -9,7 +9,6 @@
 The base convolution neural networks mainly implement some useful cnn functions
 """
 import tensorflow as tf
-import tensorflow.contrib.layers as tf_layer
 import numpy as np


@@ -313,31 +312,31 @@ def layerbn(inputdata, is_training, name):
         :param name:
         :return:
         """
-        def f1():
-            """
-            :return:
-            """
-            # print('batch_normalization: train phase')
-            return tf_layer.batch_norm(
-                inputdata, is_training=True,
-                center=True, updates_collections=None,
-                scope=name, reuse=False)
-
-        def f2():
-            """
-            :return:
-            """
-            # print('batch_normalization: test phase')
-            return tf_layer.batch_norm(
-                inputdata, is_training=False,
-                center=True, updates_collections=None,
-                scope=name, reuse=True)
-
-        output = tf.cond(is_training, f1, f2)
-
-        return output
+        # def f1():
+        #     """
+        #
+        #     :return:
+        #     """
+        #     # print('batch_normalization: train phase')
+        #     return tf_layer.batch_norm(
+        #         inputdata, is_training=True,
+        #         center=True, updates_collections=None,
+        #         scope=name, reuse=False)
+        #
+        # def f2():
+        #     """
+        #
+        #     :return:
+        #     """
+        #     # print('batch_normalization: test phase')
+        #     return tf_layer.batch_norm(
+        #         inputdata, is_training=False,
+        #         center=True, updates_collections=None,
+        #         scope=name, reuse=True)
+        #
+        # output = tf.cond(is_training, f1, f2)
+
+        return tf.layers.batch_normalization(inputs=inputdata, training=is_training, name=name)
 
     @staticmethod
     def squeeze(inputdata, axis=None, name=None):
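Aside: a minimal, self-contained sketch of the pattern this commit adopts (variable names here are hypothetical, not from the repository). Unlike the removed contrib wrapper, which applied its moving-average updates in place via updates_collections=None, tf.layers.batch_normalization registers those updates in tf.GraphKeys.UPDATE_OPS, so the train op must depend on them explicitly; the new tools/train_lanenet_hnet.py below does exactly this.

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 64, 128, 3], name='x')
    is_training = tf.placeholder(tf.bool, shape=None, name='is_training')

    # The layer keeps its moving mean/variance updates in UPDATE_OPS.
    bn = tf.layers.batch_normalization(inputs=x, training=is_training, name='bn')
    loss = tf.reduce_mean(tf.square(bn))  # dummy loss, just for the sketch

    # Without this control dependency the moving statistics are never updated
    # and inference-time normalization is wrong.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)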
@@ -486,3 +485,15 @@ def f2():
 
         output = tf.cond(is_training, f2, f1)
         return output
+
+    @staticmethod
+    def lrelu(inputdata, name, alpha=0.2):
+        """
+        :param inputdata:
+        :param alpha:
+        :param name:
+        :return:
+        """
+        with tf.variable_scope(name):
+            return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata)
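Note: this hand-rolled activation, relu(x) - alpha * relu(-x), evaluates to x for x >= 0 and alpha * x otherwise, i.e. the standard leaky ReLU. On TF >= 1.4 the built-in below is an equivalent drop-in (shown here only for reference, not part of the commit):

    # Equivalent built-in form of the lrelu method above.
    output = tf.nn.leaky_relu(inputdata, alpha=alpha, name=name)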
179 changes: 179 additions & 0 deletions tools/train_lanenet_hnet.py
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 18-5-21 3:04 PM
# @Author : Luo Yao
# @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
# @File : train_lanenet_hnet.py
# @IDE: PyCharm Community Edition
"""
Train the HNet part of LaneNet
"""
import argparse
import os
import os.path as ops
import glob
import time

import tensorflow as tf
import glog as log
import numpy as np
import math
import cv2
try:
    from cv2 import cv2
except ImportError:
    pass

from config import global_config
from lanenet_model import lanenet_hnet_model
from data_provider import lanenet_hnet_data_processor

CFG = global_config.cfg
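# Per-channel means in BGR order (OpenCV's channel layout); these are the
# standard Caffe/VGG preprocessing constants subtracted from inputs below.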
VGG_MEAN = [103.939, 116.779, 123.68]


def init_args():
    """

    :return:
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, help='The origin tusimple dataset dir')
    parser.add_argument('--weights_path', type=str, help='The pretrained weights path')

    return parser.parse_args()


def train_net(dataset_dir, weights_path=None):
    """

    :param dataset_dir:
    :param weights_path:
    :return:
    """
    assert ops.exists(dataset_dir), '{:s} not exist'.format(dataset_dir)

    json_file_list = glob.glob('{:s}/*.json'.format(dataset_dir))
    json_file_list = [tmp for tmp in json_file_list if 'test' not in tmp]
    if not json_file_list:
        log.error('Can not find any suitable label json file')
        return
    train_dataset = lanenet_hnet_data_processor.DataSet(json_file_list)

    input_image_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 64, 128, 3], name='input_image_tensor')
    input_label_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 3], name='input_label_tensor')

    phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

    net = lanenet_hnet_model.LaneNetHNet(phase=phase)

    # compute the HNet loss
    total_loss = net.compute_loss(input_tensor=input_image_tensor, gt_label_pts=input_label_tensor, name='hnet')

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                               5000, 0.96, staircase=True)
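    # tf.layers.batch_normalization (adopted in cnn_basenet.py above) registers its
    # moving mean/variance updates in tf.GraphKeys.UPDATE_OPS, so the optimizer is
    # wrapped in a control dependency on that collection below.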
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
            loss=total_loss,
            var_list=tf.trainable_variables(),
            global_step=global_step)
    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/tusimple_lanenet_hnet'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'tusimple_lanenet_hnet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/tusimple_lanenet_hnet/hnet'
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([train_cost_scalar, learning_rate_scalar])

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    # Set sess configuration
    sess_config = tf.ConfigProto(device_count={'GPU': 1})
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                             name='{:s}/lanenet_hnet.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        train_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            gt_imgs, gt_pts_labels = train_dataset.next_batch(1)
            gt_imgs = [cv2.resize(tmp,
                                  dsize=(128, 64),
                                  dst=tmp,
                                  interpolation=cv2.INTER_LINEAR)
                       for tmp in gt_imgs]
            gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]

            lane_index = 0
            for lane_pts in gt_pts_labels[0]:
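                # Append a constant 1 to every (x, y) ground-truth point, forming
                # homogeneous coordinates that the predicted transform can act on.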
                lane_pts = np.concatenate((lane_pts, np.ones(shape=[len(lane_pts), 1])), axis=1)

                phase_train = 'train'

                _, c, train_summary = sess.run([optimizer, total_loss, train_merge_summary_op],
                                               feed_dict={input_image_tensor: gt_imgs,
                                                          input_label_tensor: lane_pts,
                                                          phase: phase_train})
                if math.isnan(c):
                    log.info('Loss error: loss is nan')
                    return

                cost_time = time.time() - t_start
                train_cost_time_mean.append(cost_time)
                summary_writer.add_summary(summary=train_summary, global_step=epoch)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    log.info('[Epoch/Lane]: [{:d}/{:d}] total_loss= {:6f} cost_time= {:.6f}'.
                             format(epoch + 1, lane_index + 1, c, np.mean(train_cost_time_mean)))
                    train_cost_time_mean.clear()

                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
                lane_index += 1
        sess.close()

    return


if __name__ == '__main__':
    # init args
    args = init_args()

    # train hnet (pass the parsed --weights_path through, which was previously dropped)
    train_net(dataset_dir=args.dataset_dir, weights_path=args.weights_path)
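
For reference, a typical invocation from the repository root (paths are placeholders):

    python tools/train_lanenet_hnet.py --dataset_dir /path/to/tusimple_dataset --weights_path /path/to/checkpoint.ckpt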
