Skip to content

Commit

Permalink
update lanenet eval tools
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Jun 15, 2020
1 parent f1d12fe commit de1ce80
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions tools/evaluate_lanenet_on_tusimple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
import time

import cv2
import glog as log
import numpy as np
import tensorflow as tf
import tqdm

from config import global_config
from lanenet_model import lanenet
from lanenet_model import lanenet_postprocess
from local_utils.config_utils import parse_config_utils
from local_utils.log_util import init_logger

CFG = global_config.cfg
CFG = parse_config_utils.lanenet_cfg
LOG = init_logger.get_logger(log_file_name_prefix='lanenet_eval')


def init_args():
Expand All @@ -40,7 +41,7 @@ def init_args():
return parser.parse_args()


def test_lanenet_batch(src_dir, weights_path, save_dir):
def eval_lanenet(src_dir, weights_path, save_dir):
"""
:param src_dir:
Expand All @@ -54,17 +55,17 @@ def test_lanenet_batch(src_dir, weights_path, save_dir):

input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

net = lanenet.LaneNet(phase='test', net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')
net = lanenet.LaneNet(phase='test')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='LaneNet')

postprocessor = lanenet_postprocess.LaneNetPostProcessor()

saver = tf.train.Saver()

# Set sess configuration
sess_config = tf.ConfigProto()
sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.GPU.GPU_MEMORY_FRACTION
sess_config.gpu_options.allow_growth = CFG.GPU.TF_ALLOW_GROWTH
sess_config.gpu_options.allocator_type = 'BFC'

sess = tf.Session(config=sess_config)
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_lanenet_batch(src_dir, weights_path, save_dir):
)

if index % 100 == 0:
log.info('Mean inference time every single image: {:.5f}s'.format(np.mean(avg_time_cost)))
LOG.info('Mean inference time every single image: {:.5f}s'.format(np.mean(avg_time_cost)))
avg_time_cost.clear()

input_image_dir = ops.split(image_path.split('clips')[1])[0][1:]
Expand All @@ -119,7 +120,7 @@ def test_lanenet_batch(src_dir, weights_path, save_dir):
# init args
args = init_args()

test_lanenet_batch(
eval_lanenet(
src_dir=args.image_dir,
weights_path=args.weights_path,
save_dir=args.save_dir
Expand Down

0 comments on commit de1ce80

Please sign in to comment.