Skip to content

Commit

Permalink
update postprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed May 28, 2019
1 parent bbebe9a commit 29a5716
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions lanenet_model/lanenet_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
LaneNet model post process
"""
import os.path as ops
import math

import cv2
Expand Down Expand Up @@ -261,6 +262,8 @@ def __init__(self, ipm_remap_file_path='./data/tusimple_ipm_remap.yml'):
:param ipm_remap_file_path: ipm generate file path
"""
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)

self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path

Expand Down Expand Up @@ -296,13 +299,16 @@ def _load_remap_matrix(self):

return ret

def postprocess(self, binary_seg_result, instance_seg_result=None, min_area_threshold=100, source_image=None):
def postprocess(self, binary_seg_result, instance_seg_result=None,
min_area_threshold=100, source_image=None,
data_source='tusimple'):
"""
:param binary_seg_result:
:param instance_seg_result:
:param min_area_threshold:
:param source_image:
:param data_source:
:return:
"""
# convert binary_seg_result
Expand Down Expand Up @@ -330,9 +336,14 @@ def postprocess(self, binary_seg_result, instance_seg_result=None, min_area_thre
fit_params = []
src_lane_pts = [] # lane pts every single lane
for lane_index, coords in enumerate(lane_coords):
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255

if data_source == 'tusimple':
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
elif data_source == 'beec_ccd':
tmp_mask = np.zeros(shape=(1350, 2448), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 1350 / 256), np.int_(coords[:, 0] * 2448 / 512)))] = 255
else:
raise ValueError('Wrong data source now only support tusimple and beec_ccd')
tmp_ipm_mask = cv2.remap(
tmp_mask,
self._remap_to_ipm_x,
Expand All @@ -345,28 +356,38 @@ def postprocess(self, binary_seg_result, instance_seg_result=None, min_area_thre
fit_param = np.polyfit(nonzero_y, nonzero_x, 2)
fit_params.append(fit_param)

plot_y = np.linspace(0, 639, 539)
[ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
# fit_x = fit_param[0] * plot_y ** 3 + fit_param[1] * plot_y ** 2 + fit_param[2] * plot_y + fit_param[3]

lane_pts = []
for index in range(0, plot_y.shape[0], 5):
src_x = self._remap_to_ipm_x[int(plot_y[index]), int(np.clip(fit_x[index], 0, 639))]
src_x = self._remap_to_ipm_x[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
if src_x <= 0:
continue
src_y = self._remap_to_ipm_y[int(plot_y[index]), int(np.clip(fit_x[index], 0, 639))]
src_y = self._remap_to_ipm_y[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
src_y = src_y if src_y > 0 else 0

lane_pts.append([src_x, src_y])

src_lane_pts.append(lane_pts)

# tusimple test data sample point along y axis every 10 pixels
source_image_width = source_image.shape[1]
for index, single_lane_pts in enumerate(src_lane_pts):
single_lane_pt_x = np.array(single_lane_pts, dtype=np.float32)[:, 0]
single_lane_pt_y = np.array(single_lane_pts, dtype=np.float32)[:, 1]
start_plot_y = 240
end_plot_y = 720
if data_source == 'tusimple':
start_plot_y = 240
end_plot_y = 720
elif data_source == 'beec_ccd':
start_plot_y = 820
end_plot_y = 1350
else:
raise ValueError('Wrong data source now only support tusimple and beec_ccd')
step = int(math.floor((end_plot_y - start_plot_y) / 10))
for plot_y in np.linspace(start_plot_y, end_plot_y, step):
diff = single_lane_pt_y - plot_y
Expand Down Expand Up @@ -394,7 +415,7 @@ def postprocess(self, binary_seg_result, instance_seg_result=None, min_area_thre
abs(last_src_pt_y - plot_y) * last_src_pt_y) / \
(abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))

if interpolation_src_pt_x > 1280 or interpolation_src_pt_x < 10:
if interpolation_src_pt_x > source_image_width or interpolation_src_pt_x < 10:
continue

lane_color = self._color_map[index].tolist()
Expand Down

0 comments on commit 29a5716

Please sign in to comment.