forked from MaybeShewill-CV/lanenet-lane-detection
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate_lanenet_on_tusimple.py
127 lines (97 loc) · 3.85 KB
/
evaluate_lanenet_on_tusimple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 19-5-16 下午6:26
# @Author : MaybeShewill-CV
# @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
# @File : evaluate_lanenet_on_tusimple.py
# @IDE: PyCharm
"""
Evaluate lanenet model on tusimple lane dataset
"""
import argparse
import glob
import os
import os.path as ops
import time
import cv2
import numpy as np
import tensorflow as tf
import tqdm
from lanenet_model import lanenet
from lanenet_model import lanenet_postprocess
from local_utils.config_utils import parse_config_utils
from local_utils.log_util import init_logger
CFG = parse_config_utils.lanenet_cfg
LOG = init_logger.get_logger(log_file_name_prefix='lanenet_eval')
def init_args():
"""
:return:
"""
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, help='The source tusimple lane test data dir')
parser.add_argument('--weights_path', type=str, help='The model weights path')
parser.add_argument('--save_dir', type=str, help='The test output save root dir')
return parser.parse_args()
def eval_lanenet(src_dir, weights_path, save_dir):
"""
:param src_dir:
:param weights_path:
:param save_dir:
:return:
"""
assert ops.exists(src_dir), '{:s} not exist'.format(src_dir)
os.makedirs(save_dir, exist_ok=True)
input_tensor = tf.compat.v1.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')
net = lanenet.LaneNet(phase='test')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='LaneNet')
postprocessor = lanenet_postprocess.LaneNetPostProcessor()
saver = tf.compat.v1.train.Saver()
# Set sess configuration
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.GPU.GPU_MEMORY_FRACTION
sess_config.gpu_options.allow_growth = CFG.GPU.TF_ALLOW_GROWTH
sess_config.gpu_options.allocator_type = 'BFC'
sess = tf.compat.v1.Session(config=sess_config)
with sess.as_default():
saver.restore(sess=sess, save_path=weights_path)
image_list = glob.glob('{:s}/**/*.jpg'.format(src_dir), recursive=True)
avg_time_cost = []
for index, image_path in tqdm.tqdm(enumerate(image_list), total=len(image_list)):
image = cv2.imread(image_path, cv2.IMREAD_COLOR)
image_vis = image
image = cv2.resize(image, (512, 256), interpolation=cv2.INTER_LINEAR)
image = image / 127.5 - 1.0
t_start = time.time()
binary_seg_image, instance_seg_image = sess.run(
[binary_seg_ret, instance_seg_ret],
feed_dict={input_tensor: [image]}
)
avg_time_cost.append(time.time() - t_start)
postprocess_result = postprocessor.postprocess(
binary_seg_result=binary_seg_image[0],
instance_seg_result=instance_seg_image[0],
source_image=image_vis
)
if index % 100 == 0:
LOG.info('Mean inference time every single image: {:.5f}s'.format(np.mean(avg_time_cost)))
avg_time_cost.clear()
input_image_dir = ops.split(image_path.split('clips')[1])[0][1:]
input_image_name = ops.split(image_path)[1]
output_image_dir = ops.join(save_dir, input_image_dir)
os.makedirs(output_image_dir, exist_ok=True)
output_image_path = ops.join(output_image_dir, input_image_name)
if ops.exists(output_image_path):
continue
cv2.imwrite(output_image_path, postprocess_result['source_image'])
return
if __name__ == '__main__':
"""
test code
"""
# init args
args = init_args()
eval_lanenet(
src_dir=args.image_dir,
weights_path=args.weights_path,
save_dir=args.save_dir
)