Skip to content

Commit

Permalink
add post process on binary segmenatation result during inference time
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 31, 2018
1 parent afed844 commit 6998985
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tools/test_lanenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from lanenet_model import lanenet_merge_model
from lanenet_model import lanenet_cluster
from lanenet_model import lanenet_postprocess
from config import global_config

CFG = global_config.cfg
Expand Down Expand Up @@ -74,6 +75,7 @@ def test_lanenet(image_path, weights_path, use_gpu):
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_loss')

cluster = lanenet_cluster.LaneNetCluster()
postprocessor = lanenet_postprocess.LaneNetPoseProcessor()

saver = tf.train.Saver()

Expand All @@ -98,6 +100,7 @@ def test_lanenet(image_path, weights_path, use_gpu):
t_cost = time.time() - t_start
log.info('单张图像车道线预测耗时: {:.5f}s'.format(t_cost))

binary_seg_image[0] = postprocessor.postprocess(binary_seg_image[0])
mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image[0],
instance_seg_ret=instance_seg_image[0])
# mask_image = cluster.get_lane_mask_v2(instance_seg_ret=instance_seg_image[0])
Expand All @@ -118,6 +121,9 @@ def test_lanenet(image_path, weights_path, use_gpu):
# mask_image = cv2.resize(mask_image, (image_vis.shape[1], image_vis.shape[0]),
# interpolation=cv2.INTER_LINEAR)

cv2.imwrite('binary_ret.png', binary_seg_image[0] * 255)
cv2.imwrite('instance_ret.png', embedding_image)

plt.figure('mask_image')
plt.imshow(mask_image[:, :, (2, 1, 0)])
plt.figure('src_image')
Expand Down Expand Up @@ -157,6 +163,7 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_loss')

cluster = lanenet_cluster.LaneNetCluster()
postprocessor = lanenet_postprocess.LaneNetPoseProcessor()

saver = tf.train.Saver()

Expand Down Expand Up @@ -187,7 +194,8 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
for tmp in image_list_epoch]
image_list_epoch = [tmp - VGG_MEAN for tmp in image_list_epoch]
t_cost = time.time() - t_start
log.info('[Epoch:{:d}] 图像预处理耗时: {:.5f}s'.format(epoch, t_cost))
log.info('[Epoch:{:d}] 预处理{:d}张图像, 共耗时: {:.5f}s, 平均每张耗时: {:.5f}'.format(
epoch, len(image_path_epoch), t_cost, t_cost / len(image_path_epoch)))

t_start = time.time()
binary_seg_images, instance_seg_images = sess.run(
Expand All @@ -196,9 +204,13 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
log.info('[Epoch:{:d}] 预测{:d}张图像车道线, 共耗时: {:.5f}s, 平均每张耗时: {:.5f}s'.format(
epoch, len(image_path_epoch), t_cost, t_cost / len(image_path_epoch)))

cluster_time = []
for index, binary_seg_image in enumerate(binary_seg_images):
t_start = time.time()
binary_seg_image = postprocessor.postprocess(binary_seg_image)
mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image,
instance_seg_ret=instance_seg_images[index])
cluster_time.append(time.time() - t_start)
mask_image = cv2.resize(mask_image, (image_vis_list[index].shape[1],
image_vis_list[index].shape[0]),
interpolation=cv2.INTER_LINEAR)
Expand All @@ -213,13 +225,14 @@ def test_lanenet_batch(image_dir, weights_path, batch_size, use_gpu, save_dir=No
plt.show()
plt.ioff()

mask_image = cv2.addWeighted(image_vis_list[index], 1.0, mask_image, 1.0, 0)

if save_dir is not None:
mask_image = cv2.addWeighted(image_vis_list[index], 1.0, mask_image, 1.0, 0)
image_name = ops.split(image_path_epoch[index])[1]
image_save_path = ops.join(save_dir, image_name)
cv2.imwrite(image_save_path, mask_image)
log.info('[Epoch:{:d}] Detection image {:s} complete'.format(epoch, image_name))
# log.info('[Epoch:{:d}] Detection image {:s} complete'.format(epoch, image_name))
log.info('[Epoch:{:d}] 进行{:d}张图像车道线聚类, 共耗时: {:.5f}s, 平均每张耗时: {:.5f}'.format(
epoch, len(image_path_epoch), np.sum(cluster_time), np.mean(cluster_time)))

sess.close()

Expand Down

0 comments on commit 6998985

Please sign in to comment.