Commit

instance loss
stesha.chen committed Apr 8, 2019
1 parent de6b089 commit 0b4b54e
Showing 12 changed files with 74 additions and 48 deletions.
10 changes: 9 additions & 1 deletion .gitignore
@@ -1 +1,9 @@
.PyCharm2017.3/
.PyCharm2017.3/
data_provider/*.pyc
encoder_decoder_model/*.pyc
lanenet_model/*.pyc
config/*.pyc
data/tusimple_data/
*.png
tboard/*
model/
1 change: 1 addition & 0 deletions README.md
@@ -156,6 +156,7 @@ script on your own.

## TODO
- [ ] Add Enet backbone for encoder and decoder.
- [ ] Add Enet binary and instance loss.
- [x] Training the model on different dataset
- ~~[ ] Adjust the lanenet hnet model and merge the hnet model to the main lanenet model~~
- [ ] Change the normalization function from BN to GN
Empty file added config/__init__.py
Empty file.
Empty file added data_provider/__init__.py
Empty file.
1 change: 0 additions & 1 deletion data_provider/lanenet_data_processor.py
@@ -107,7 +107,6 @@ def next_batch(self, batch_size):

for gt_img_path in gt_img_list:
gt_imgs.append(cv2.imread(gt_img_path, cv2.IMREAD_COLOR))

for gt_label_path in gt_label_binary_list:
label_img = cv2.imread(gt_label_path, cv2.IMREAD_COLOR)
label_binary = np.zeros([label_img.shape[0], label_img.shape[1]], dtype=np.uint8)
2 changes: 1 addition & 1 deletion encoder_decoder_model/dense_encoder.py
@@ -37,7 +37,7 @@ def __init__(self, l, n, growthrate, phase, with_bc=False,
self._growthrate = growthrate
self._with_bc = with_bc
self._phase = phase
self._train_phase = tf.constant('train', dtype=tf.string)
self._train_phase = tf.constant(True, dtype=tf.bool)
self._test_phase = tf.constant('test', dtype=tf.string)
self._is_training = self._init_phase()
self._bc_theta = bc_theta
2 changes: 1 addition & 1 deletion encoder_decoder_model/fcn_decoder.py
@@ -24,7 +24,7 @@ def __init__(self, phase):
"""
super(FCNDecoder, self).__init__()
self._train_phase = tf.constant('train', dtype=tf.string)
self._train_phase = tf.constant(True, dtype=tf.bool)
self._phase = phase
self._is_training = self._init_phase()

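
Both files above switch the train-phase constant from the string 'train' to a boolean, matching the tf.bool net_phase placeholder introduced in train_lanenet.py further down. As a rough illustration, here is a minimal TensorFlow 1.x sketch of how such a flag can resolve to an is_training tensor, assuming the comparison is done with tf.equal (the _init_phase stand-in below is hypothetical, not the repository's implementation):

import tensorflow as tf  # assumes TensorFlow 1.x APIs (tf.placeholder, tf.Session)

phase = tf.placeholder(dtype=tf.bool, shape=None, name='net_phase')
train_phase = tf.constant(True, dtype=tf.bool)

# Hypothetical stand-in for _init_phase(): compare the fed phase with the train constant.
is_training = tf.equal(phase, train_phase)

with tf.Session() as sess:
    print(sess.run(is_training, feed_dict={phase: True}))   # True: enable train-only ops
    print(sess.run(is_training, feed_dict={phase: False}))  # False: inference behaviour

The resulting boolean tensor is what typically gates dropout or batch normalization between training and inference.
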
6 changes: 6 additions & 0 deletions lanenet_model/lanenet_discriminative_loss.py
@@ -43,14 +43,20 @@ def discriminative_loss_single(
label_shape[1] * label_shape[0], feature_dim])

# Count the number of instances
# unique_labels holds the distinct values found in correct_label; unique_id gives, for each element of correct_label, its index into unique_labels
# counts records how many times each value in unique_labels occurs in correct_label
unique_labels, unique_id, counts = tf.unique_with_counts(correct_label)
counts = tf.cast(counts, tf.float32)
num_instances = tf.size(unique_labels)

# Compute the mean embedding vector of each instance
# segmented_sum adds up the rows of reshaped_pred that share the same unique_id
# e.g. unique_id [0, 0, 1, 1, 0] and reshaped_pred [1, 2, 3, 4, 5] give [1+2+5, 3+4]; the channel dimension is not summed over
segmented_sum = tf.unsorted_segment_sum(
reshaped_pred, unique_id, num_instances)
# Divide by the number of pixels each instance occupies in the ground truth
mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1)))
# Then map the means back to the original per-pixel layout
mu_expand = tf.gather(mu, unique_id)

# Compute the variance term (l_var) of the loss formula
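
The comments above describe how the per-instance mean embeddings (mu) are computed. A small NumPy sketch of the same arithmetic, reusing the toy numbers from the comment (variable names mirror the TensorFlow code; this is an illustration, not repository code):

import numpy as np

# Toy data from the comment: 5 pixels, 2 instances, 1-dimensional embedding.
unique_id = np.array([0, 0, 1, 1, 0])                     # instance index of each pixel
reshaped_pred = np.array([[1.], [2.], [3.], [4.], [5.]])  # predicted pixel embeddings
counts = np.bincount(unique_id).astype(np.float64)        # pixels per instance: [3, 2]

# Equivalent of tf.unsorted_segment_sum: per-instance, per-channel sums.
segmented_sum = np.zeros((counts.size, reshaped_pred.shape[1]))
np.add.at(segmented_sum, unique_id, reshaped_pred)        # [[1+2+5], [3+4]] = [[8.], [7.]]

mu = segmented_sum / counts[:, None]                      # per-instance mean embedding
mu_expand = mu[unique_id]                                 # equivalent of tf.gather(mu, unique_id)
print(mu)         # [[2.6667], [3.5]]
print(mu_expand)  # each pixel mapped to its instance mean, shape (5, 1)
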
60 changes: 36 additions & 24 deletions lanenet_model/lanenet_merge_model.py
@@ -75,21 +75,17 @@ def _build_model(self, input_tensor, name):
segStage4 = enet_stage.ENet_stage4(segStage3, pooling_indices_2, inputs_shape_2, stage1, isTraining=self._phase)
segStage5 = enet_stage.ENet_stage5(segStage4, pooling_indices_1, inputs_shape_1, initial, isTraining=self._phase)
segLogits = tf.layers.conv2d_transpose(segStage5, 2, [2, 2], strides=2, padding='same', name='fullconv')
segProbabilities = tf.nn.softmax(segLogits, name='logits_to_softmax')

# Embedding branch
with tf.variable_scope('LaneNetEm'):
emStage3 = enet_stage.ENet_stage3(stage2, isTraining=self._phase)
emStage4 = enet_stage.ENet_stage4(emStage3, pooling_indices_2, inputs_shape_2, stage1, isTraining=self._phase)
emStage5 = enet_stage.ENet_stage5(emStage4, pooling_indices_1, inputs_shape_1, initial, isTraining=self._phase)
emLogits = tf.layers.conv2d_transpose(emStage5, 4, [2, 2], strides=2, padding='same', name='fullconv')
emProbabilities = tf.nn.softmax(emLogits, name='logits_to_softmax')

ret = {
'seglogits': segLogits,
'segProbabilities': segProbabilities,
'emLogits': emLogits,
'emProbabilities': emProbabilities
'logits': segLogits,
'deconv': emLogits
}

return ret
@@ -126,26 +122,41 @@ def compute_loss(self, input_tensor, binary_label, instance_label, name):
with tf.variable_scope(name):
# Forward pass to obtain the logits
inference_ret = self._build_model(input_tensor=input_tensor, name='inference')
# Compute the binary segmentation loss
decode_logits = inference_ret['logits']
binary_label_plain = tf.reshape(
binary_label,
shape=[binary_label.get_shape().as_list()[0] *
binary_label.get_shape().as_list()[1] *
binary_label.get_shape().as_list()[2]])
# Apply class weights
unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
counts = tf.cast(counts, tf.float32)
inverse_weights = tf.divide(1.0,
tf.log(tf.add(tf.divide(tf.constant(1.0), counts),
tf.constant(1.02))))
inverse_weights = tf.gather(inverse_weights, binary_label)
binary_segmenatation_loss = tf.losses.sparse_softmax_cross_entropy(
labels=binary_label, logits=decode_logits, weights=inverse_weights)
binary_segmenatation_loss = tf.reduce_mean(binary_segmenatation_loss)
decode_deconv = inference_ret['deconv']

if self._net_flag.lower() == 'enet':
# Apply bounded inverse class weights
inverse_class_weights = tf.divide(1.0,
tf.log(tf.add(tf.constant(1.02, tf.float32),
tf.nn.softmax(decode_logits))))
decode_logits_weighted = tf.multiply(decode_logits, inverse_class_weights)

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=decode_logits_weighted, labels=tf.squeeze(binary_label, squeeze_dims=[3]),
name='entropy_loss')

binary_segmenatation_loss = tf.reduce_mean(loss)

else:
# Compute the binary segmentation loss
binary_label_plain = tf.reshape(
binary_label,
shape=[binary_label.get_shape().as_list()[0] *
binary_label.get_shape().as_list()[1] *
binary_label.get_shape().as_list()[2]])
# Apply class weights
unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
counts = tf.cast(counts, tf.float32)
inverse_weights = tf.divide(1.0,
tf.log(tf.add(tf.divide(tf.constant(1.0), counts),
tf.constant(1.02))))
inverse_weights = tf.gather(inverse_weights, binary_label)
binary_segmenatation_loss = tf.losses.sparse_softmax_cross_entropy(
labels=binary_label, logits=decode_logits, weights=inverse_weights)
binary_segmenatation_loss = tf.reduce_mean(binary_segmenatation_loss)

# Compute the discriminative loss
decode_deconv = inference_ret['deconv']
# Pixel embedding
pix_embedding = self.conv2d(inputdata=decode_deconv, out_channel=4, kernel_size=1,
use_bias=False, name='pix_embedding_conv')
@@ -164,7 +175,7 @@ def compute_loss(self, input_tensor, binary_label, instance_label, name):
else:
l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
l2_reg_loss *= 0.001
total_loss = 0.5 * binary_segmenatation_loss + 0.5 * disc_loss + l2_reg_loss
total_loss = 1.0 * binary_segmenatation_loss + 0.5 * disc_loss + l2_reg_loss

ret = {
'total_loss': total_loss,
Expand All @@ -173,6 +184,7 @@ def compute_loss(self, input_tensor, binary_label, instance_label, name):
'binary_seg_loss': binary_segmenatation_loss,
'discriminative_loss': disc_loss
}
total_loss.shape = 1

return ret

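
In the ENet branch added above, the binary segmentation logits are weighted by bounded inverse class weights, w = 1 / log(1.02 + p), where p is the per-class softmax probability; the 1.02 constant bounds the weight by 1 / log(1.02), roughly 50. A short NumPy illustration with made-up probabilities for a background-heavy lane mask (not repository code):

import numpy as np

# Made-up softmax probabilities for one pixel: background vs. lane.
softmax_probs = np.array([0.95, 0.05])

# Bounded inverse class weights: the rarer class gets a larger weight,
# but never more than 1 / log(1.02) ~= 50.5.
weights = 1.0 / np.log(1.02 + softmax_probs)
print(weights)              # approx. [1.47, 14.78]
print(1.0 / np.log(1.02))   # upper bound, approx. 50.5
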
File renamed without changes.
34 changes: 17 additions & 17 deletions tools/generate_tusimple_dataset.py
@@ -48,10 +48,6 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
for line_index, line in enumerate(file):
info_dict = json.loads(line)

image_dir = ops.split(info_dict['raw_file'])[0]
image_dir_split = image_dir.split('/')[1:]
image_dir_split.append(ops.split(info_dict['raw_file'])[1])
image_name = '_'.join(image_dir_split)
image_path = ops.join(src_dir, info_dict['raw_file'])
assert ops.exists(image_path), '{:s} not exist'.format(image_path)

@@ -94,7 +90,7 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
cv2.imwrite(dst_instance_image_path, dst_instance_image)
cv2.imwrite(dst_rgb_image_path, src_image)

print('Process {:s} success'.format(image_name))
print('Process {:s} success'.format(image_path))


def gen_train_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir):
@@ -139,31 +135,35 @@ def process_tusimple_dataset(src_dir):
:param src_dir:
:return:
"""
traing_folder_path = ops.join(src_dir, 'training')
training_folder_path = ops.join(src_dir, 'training')
testing_folder_path = ops.join(src_dir, 'testing')

os.makedirs(traing_folder_path, exist_ok=True)
os.makedirs(testing_folder_path, exist_ok=True)
if not os.path.exists(training_folder_path):
os.makedirs(training_folder_path)
if not os.path.exists(testing_folder_path):
os.makedirs(testing_folder_path)

for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]

shutil.copyfile(json_label_path, ops.join(traing_folder_path, json_label_name))
shutil.copyfile(json_label_path, ops.join(training_folder_path, json_label_name))

for json_label_path in glob.glob('{:s}/test*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]

shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name))

gt_image_dir = ops.join(traing_folder_path, 'gt_image')
gt_binary_dir = ops.join(traing_folder_path, 'gt_binary_image')
gt_instance_dir = ops.join(traing_folder_path, 'gt_instance_image')
gt_image_dir = ops.join(training_folder_path, 'gt_image')
gt_binary_dir = ops.join(training_folder_path, 'gt_binary_image')
gt_instance_dir = ops.join(training_folder_path, 'gt_instance_image')

os.makedirs(gt_image_dir, exist_ok=True)
os.makedirs(gt_binary_dir, exist_ok=True)
os.makedirs(gt_instance_dir, exist_ok=True)
if not os.path.exists(gt_image_dir):
os.makedirs(gt_image_dir)
if not os.path.exists(gt_binary_dir):
os.makedirs(gt_binary_dir)
if not os.path.exists(gt_instance_dir):
os.makedirs(gt_instance_dir)

for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
for json_label_path in glob.glob('{:s}/*.json'.format(training_folder_path)):
process_json_file(json_label_path, src_dir, gt_image_dir, gt_binary_dir, gt_instance_dir)

gen_train_sample(src_dir, gt_binary_dir, gt_instance_dir, gt_image_dir)
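
The directory handling above replaces os.makedirs(path, exist_ok=True) with an explicit existence check; exist_ok requires Python 3.2 or newer, while check-then-create also works on Python 2. A tiny helper expressing the same pattern (mkdir_if_missing is an illustrative name, not from the repository):

import os


def mkdir_if_missing(path):
    # Create path (and parents) only if it does not already exist.
    if not os.path.exists(path):
        os.makedirs(path)


for folder in ('training/gt_image', 'training/gt_binary_image', 'training/gt_instance_image'):
    mkdir_if_missing(folder)
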
6 changes: 3 additions & 3 deletions tools/train_lanenet.py → train_lanenet.py
100644 → 100755
@@ -84,7 +84,7 @@ def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
CFG.TRAIN.IMG_WIDTH],
name='instance_input_label')
phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')
phase = tf.placeholder(dtype=tf.bool, shape=None, name='net_phase')

net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

@@ -299,15 +299,15 @@ def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
' mean_cost_time= {:5f}s '.
format(epoch + 1, c, binary_loss, instance_loss, train_accuracy,
np.mean(train_cost_time_mean)))
train_cost_time_mean.clear()
train_cost_time_mean = []

if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
log.info('Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
'instance_seg_loss= {:6f} accuracy= {:6f} '
'mean_cost_time= {:5f}s '.
format(epoch + 1, c_val, val_binary_seg_loss, val_instance_seg_loss, val_accuracy,
np.mean(val_cost_time_mean)))
val_cost_time_mean.clear()
val_cost_time_mean = []

if epoch % 2000 == 0:
saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
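
With net_phase now a tf.bool placeholder, training and validation steps feed Python booleans rather than the strings 'train' and 'test'. A minimal sketch of that feeding pattern, with a batch-normalized layer standing in for the network (the dense layer and input shape are made up for illustration; assumes TensorFlow 1.x):

import tensorflow as tf

x = tf.placeholder(dtype=tf.float32, shape=[None, 4], name='input')
phase = tf.placeholder(dtype=tf.bool, shape=None, name='net_phase')

# Stand-in network: batch norm uses batch statistics when phase is True
# and moving averages when phase is False.
logits = tf.layers.batch_normalization(tf.layers.dense(x, 2), training=phase)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = [[0.1, 0.2, 0.3, 0.4]]
    train_out = sess.run(logits, feed_dict={x: batch, phase: True})
    val_out = sess.run(logits, feed_dict={x: batch, phase: False})
    print(train_out, val_out)
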
