Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Dec 4, 2019
1 parent bb5e9c8 commit be4d3c9
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 12 deletions.
79 changes: 70 additions & 9 deletions data_provider/tf_io_pipline_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,32 @@ def write_example_tfrecords(gt_images_paths, gt_binary_images_paths, gt_instance
# prepare gt image
_gt_image = cv2.imread(_gt_image_path, cv2.IMREAD_UNCHANGED)
if _gt_image.shape != (RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT, 3):
_gt_image = cv2.resize(_gt_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_LINEAR)
_gt_image = cv2.resize(
_gt_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_LINEAR
)
_gt_image_raw = _gt_image.tostring()

# prepare gt binary image
_gt_binary_image = cv2.imread(gt_binary_images_paths[_index], cv2.IMREAD_UNCHANGED)
if _gt_binary_image.shape != (RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT):
_gt_binary_image = cv2.resize(_gt_binary_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_NEAREST)
_gt_binary_image = cv2.resize(
_gt_binary_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_NEAREST
)
_gt_binary_image = np.array(_gt_binary_image / 255.0, dtype=np.uint8)
_gt_binary_image_raw = _gt_binary_image.tostring()

# prepare gt instance image
_gt_instance_image = cv2.imread(gt_instance_images_paths[_index], cv2.IMREAD_UNCHANGED)
if _gt_instance_image.shape != (RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT):
_gt_instance_image = cv2.resize(_gt_instance_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_NEAREST)
_gt_instance_image = cv2.resize(
_gt_instance_image,
dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
interpolation=cv2.INTER_NEAREST
)
_gt_instance_image_raw = _gt_instance_image.tostring()

_example = tf.train.Example(
Expand Down Expand Up @@ -132,6 +138,25 @@ def decode(serialized_example):
return gt_image, gt_binary_image, gt_instance_image


def central_crop(image, crop_height, crop_width):
    """
    Cut a centered window of the requested size out of an image tensor
    :param image: rank-3 tensor; axis 0 is read as height and axis 1 as width
    :param crop_height: number of rows kept in the output
    :param crop_width: number of columns kept in the output
    :return: the sliced tensor; every element of the last axis is kept
    """
    image_shape = tf.shape(input=image)

    # split the surplus evenly on both sides; floor division puts any
    # odd leftover row / column on the bottom / right
    offset_top = (image_shape[0] - crop_height) // 2
    offset_left = (image_shape[1] - crop_width) // 2

    window_begin = [offset_top, offset_left, 0]
    # -1 asks tf.slice for everything remaining along the last axis
    window_size = [crop_height, crop_width, -1]

    return tf.slice(image, window_begin, window_size)


def augment_for_train(gt_image, gt_binary_image, gt_instance_image):
"""
Expand All @@ -145,6 +170,11 @@ def augment_for_train(gt_image, gt_binary_image, gt_instance_image):
gt_binary_image = tf.cast(gt_binary_image, tf.float32)
gt_instance_image = tf.cast(gt_instance_image, tf.float32)

# apply random color augmentation
gt_image, gt_binary_image, gt_instance_image = random_color_augmentation(
gt_image, gt_binary_image, gt_instance_image
)

# apply random flip augmentation
gt_image, gt_binary_image, gt_instance_image = random_horizon_flip_batch_images(
gt_image, gt_binary_image, gt_instance_image
Expand All @@ -167,6 +197,17 @@ def augment_for_test(gt_image, gt_binary_image, gt_instance_image):
:param gt_instance_image:
:return:
"""
# apply central crop
gt_image = central_crop(
image=gt_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
)
gt_binary_image = central_crop(
image=gt_binary_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
)
gt_instance_image = central_crop(
image=gt_instance_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
)

return gt_image, gt_binary_image, gt_instance_image


Expand Down Expand Up @@ -263,3 +304,23 @@ def random_horizon_flip_batch_images(gt_image, gt_binary_image, gt_instance_imag
)

return flipped_gt_image, flipped_gt_binary_image, flipped_gt_instance_image


def random_color_augmentation(gt_image, gt_binary_image, gt_instance_image):
    """
    Random color augmentation: jitter the saturation, brightness and contrast
    of gt_image; the other two inputs are handed back untouched
    :param gt_image: image tensor that receives the color jitter
    :param gt_binary_image: returned unchanged
    :param gt_instance_image: returned unchanged
    :return: tuple of (augmented gt_image, gt_binary_image, gt_instance_image)
    """
    # the jitters are chained in a fixed order: saturation -> brightness -> contrast
    jittered = tf.image.random_saturation(gt_image, 0.8, 1.2)
    # NOTE(review): a max_delta of 0.05 looks sized for a [0, 1] image, while
    # the clip below works on [0, 255] - confirm the intended scale
    jittered = tf.image.random_brightness(jittered, 0.05)
    jittered = tf.image.random_contrast(jittered, 0.7, 1.3)

    # the jitters can push values out of range, so clamp back into [0, 255]
    augmented_gt_image = tf.clip_by_value(jittered, 0.0, 255.0)

    return augmented_gt_image, gt_binary_image, gt_instance_image
6 changes: 3 additions & 3 deletions lanenet_model/lanenet_discriminative_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def discriminative_loss_single(
mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1)))
mu_expand = tf.gather(mu, unique_id)

distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1)
distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1, ord=1)
distance = tf.subtract(distance, delta_v)
distance = tf.clip_by_value(distance, 0., distance)
distance = tf.square(distance)
Expand All @@ -76,14 +76,14 @@ def discriminative_loss_single(
bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
mu_diff_bool = tf.boolean_mask(mu_diff, bool_mask)

mu_norm = tf.norm(mu_diff_bool, axis=1)
mu_norm = tf.norm(mu_diff_bool, axis=1, ord=1)
mu_norm = tf.subtract(2. * delta_d, mu_norm)
mu_norm = tf.clip_by_value(mu_norm, 0., mu_norm)
mu_norm = tf.square(mu_norm)

l_dist = tf.reduce_mean(mu_norm)

l_reg = tf.reduce_mean(tf.norm(mu, axis=1))
l_reg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

param_scale = 1.
l_var = param_var * l_var
Expand Down

0 comments on commit be4d3c9

Please sign in to comment.