Skip to content

Commit

Permalink
modified test lanenet
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Oct 24, 2018
1 parent 4e9b59a commit adae011
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 202 deletions.
49 changes: 26 additions & 23 deletions tools/test_lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ def init_args():
return parser.parse_args()


def minmax_scale(input_arr):
"""
:param input_arr:
:return:
"""
min_val = np.min(input_arr)
max_val = np.max(input_arr)

output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

return output_arr


def test_lanenet(image_path, weights_path, use_gpu):
"""
Expand All @@ -69,10 +83,10 @@ def test_lanenet(image_path, weights_path, use_gpu):
log.info('图像读取完毕, 耗时: {:.5f}s'.format(time.time() - t_start))

input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')
phase_tensor = tf.constant('train', tf.string)
phase_tensor = tf.constant('test', tf.string)

net = lanenet_merge_model.LaneNet(phase=phase_tensor, net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_loss')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

cluster = lanenet_cluster.LaneNetCluster()
postprocessor = lanenet_postprocess.LaneNetPoseProcessor()
Expand Down Expand Up @@ -103,26 +117,10 @@ def test_lanenet(image_path, weights_path, use_gpu):
binary_seg_image[0] = postprocessor.postprocess(binary_seg_image[0])
mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image[0],
instance_seg_ret=instance_seg_image[0])
# mask_image = cluster.get_lane_mask_v2(instance_seg_ret=instance_seg_image[0])
# mask_image = cv2.resize(mask_image, (image_vis.shape[1], image_vis.shape[0]),
# interpolation=cv2.INTER_LINEAR)

ele_mex = np.max(instance_seg_image[0], axis=(0, 1))
for i in range(3):
if ele_mex[i] == 0:
scale = 1
else:
scale = 255 / ele_mex[i]
instance_seg_image[0][:, :, i] *= int(scale)
embedding_image = np.array(instance_seg_image[0], np.uint8)
# cv2.imwrite('embedding_mask.png', embedding_image)

# mask_image = cluster.get_lane_mask_v2(instance_seg_ret=embedding_image)
# mask_image = cv2.resize(mask_image, (image_vis.shape[1], image_vis.shape[0]),
# interpolation=cv2.INTER_LINEAR)

cv2.imwrite('binary_ret.png', binary_seg_image[0] * 255)
cv2.imwrite('instance_ret.png', embedding_image)
for i in range(4):
instance_seg_image[0][:, :, i] = minmax_scale(instance_seg_image[0][:, :, i])
embedding_image = np.array(instance_seg_image[0], np.uint8)

plt.figure('mask_image')
plt.imshow(mask_image[:, :, (2, 1, 0)])
Expand Down Expand Up @@ -157,10 +155,10 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
glob.glob('{:s}/**/*.jpeg'.format(image_dir), recursive=True)

input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 256, 512, 3], name='input_tensor')
phase_tensor = tf.constant('train', tf.string)
phase_tensor = tf.constant('test', tf.string)

net = lanenet_merge_model.LaneNet(phase=phase_tensor, net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_loss')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

cluster = lanenet_cluster.LaneNetCluster()
postprocessor = lanenet_postprocess.LaneNetPoseProcessor()
Expand Down Expand Up @@ -254,3 +252,8 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
# test hnet model on a batch of image
test_lanenet_batch(image_dir=args.image_path, weights_path=args.weights_path,
save_dir=args.save_dir, use_gpu=args.use_gpu, batch_size=args.batch_size)

# test_net_hdmap('/media/baidu/Data/高精图像质检/20180716t162624_hb/20180716T164118',
# '/home/baidu/Silly_Project/ICode/baidu/beec/lanenet-lane-detection/model/'
# 'tusimple_lanenet/tusimple_lanenet_vgg_2018-05-21-11-11-03.ckpt-94000',
# '/media/baidu/Data/高精图像质检/20180716t162624_hb/lanenet_mask_ret')
179 changes: 0 additions & 179 deletions tools/train_lanenet_hnet.py

This file was deleted.

0 comments on commit adae011

Please sign in to comment.