Skip to content

Commit

Permalink
Merge pull request pqpo#84 from pqpo/feature/hed_net
Browse files Browse the repository at this point in the history
Feature/hed net
  • Loading branch information
pqpo authored Aug 1, 2019
2 parents a2f6477 + 4077048 commit 4feeba6
Show file tree
Hide file tree
Showing 16 changed files with 706 additions and 0 deletions.
12 changes: 12 additions & 0 deletions edge_detection/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/python
#coding=utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

image_height = 256
image_width = 256
use_batch_norm = True
use_kernel_regularizer = False

Binary file added edge_detection/const.pyc
Binary file not shown.
68 changes: 68 additions & 0 deletions edge_detection/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/python
# coding=utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from util import *
from hed_net import *
import os

from tensorflow import flags

flags.DEFINE_string('input_img', 'test_image/test2.jpg',
'Image path to run hed, must be jpg image.')
flags.DEFINE_string('checkpoint_dir', './checkpoint',
'Checkpoint directory.')
flags.DEFINE_string('output_img', 'test_image/result.jpg',
'Output image path.')
flags.DEFINE_float('output_threshold', 0.0, 'output threshold, default: 0.0')

FLAGS = flags.FLAGS

if not os.path.exists(FLAGS.input_img):
print('--input_img invalid')
exit()

if FLAGS.output_img == '':
print('--output_img invalid')
exit()

if __name__ == "__main__":
image_path_placeholder = tf.placeholder(tf.string)

feed_dict_to_use = {image_path_placeholder: FLAGS.input_img}

image_tensor = tf.read_file(image_path_placeholder)
image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)
origin_tensor = tf.image.resize_images(image_tensor, [const.image_height, const.image_width])
image_float = tf.to_float(origin_tensor)
image_float = image_float / 255.0
image_float = tf.expand_dims(image_float, axis=0)

dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_float, False)
# dsn_fuse = tf.reshape(dsn_fuse, shape=(const.image_height, const.image_width))

global_init = tf.global_variables_initializer()

# Saver
hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed')
saver = tf.train.Saver(hed_weights)

with tf.Session() as sess:
sess.run(global_init)

latest_ck_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_ck_file:
print('restore from latest checkpoint file : {}'.format(latest_ck_file))
saver.restore(sess, latest_ck_file)
else:
print('no checkpoint file to restore, exit()')
exit()

_dsn_fuse = sess.run(dsn_fuse, feed_dict=feed_dict_to_use)

dsn_fuse_image = np.where(_dsn_fuse[0] > FLAGS.output_threshold, [255], [0])
save_img(FLAGS.output_img, dsn_fuse_image.reshape([256, 256]))
print('done! output image: {}'.format(FLAGS.output_img))
117 changes: 117 additions & 0 deletions edge_detection/finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/python
# coding=utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from hed_net import *
import numpy as np
import os
import util
from generate_batch_data import generate_batch_data

from tensorflow import flags

flags.DEFINE_string('finetuning_dir', './finetuning_model',
'finetuning directory.')
flags.DEFINE_string('checkpoint_dir', './checkpoint',
'Checkpoint directory.')
flags.DEFINE_string('image', 'test_image/IMG_20190704_143127.jpg', 'fine tuning image')
flags.DEFINE_string('annotation', 'test_image/annotation_IMG_20190704_143127.png_threshold.jpg', 'fine tuning annotation')
flags.DEFINE_string('csv', '', 'fine tuning csv')
flags.DEFINE_integer('batch_size', 4, 'batch size')
flags.DEFINE_float('lr', 0.0005, 'learning rate')
flags.DEFINE_integer('iterations', 15,
'Number of iterations')
flags.DEFINE_float('output_threshold', 0.0, 'output threshold')

FLAGS = flags.FLAGS

hed_ckpt_file_path = os.path.join(FLAGS.checkpoint_dir, 'hed.ckpt')

train_layer = ['block0_1', 'block1_0', 'block2_1', 'block3_2', 'block4_3', 'block5_2']

if not ((os.path.exists(FLAGS.image)) and (os.path.exists(FLAGS.annotation)) or (os.path.exists(FLAGS.csv))):
print('please add input, --img, --annotation or --csv')
exit()

images = []
annotations = []

batch_size = FLAGS.batch_size

if os.path.exists(FLAGS.image) and os.path.exists(FLAGS.annotation):
images.append(FLAGS.image)
annotations.append(FLAGS.annotation)

if os.path.exists(FLAGS.csv):
csv_img, csv_ann = util.load_sample_from_csv(FLAGS.csv)
if csv_img is not None and len(csv_img) > 0:
images.extend(csv_img)
annotations.extend(csv_ann)

if len(images) == 0:
print('Samples is empty, exit()')
exit()

print('fine tuning images size: {}'.format(len(images)))
print('fine tuning annotations size: {}'.format(len(annotations)))

assert len(images) == len(annotations)

image_tensor, annotation_tensor = generate_batch_data(images, annotations, batch_size=batch_size)

if __name__ == "__main__":

is_training = tf.placeholder(tf.bool)

dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_tensor, is_training)

hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed')

cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor)
# cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor) \
# + class_balanced_sigmoid_cross_entropy(dsn1, annotation_tensor)\
# + class_balanced_sigmoid_cross_entropy(dsn2, annotation_tensor)\
# + class_balanced_sigmoid_cross_entropy(dsn3, annotation_tensor)\
# + class_balanced_sigmoid_cross_entropy(dsn4, annotation_tensor)\
# + class_balanced_sigmoid_cross_entropy(dsn5, annotation_tensor)

var_list = [v for v in tf.trainable_variables() if v.name.split('/')[2] in train_layer]
gradients = tf.gradients(cost, var_list)
gradients = list(zip(gradients, var_list))

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).apply_gradients(gradients)
# train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(cost)

global_init = tf.global_variables_initializer()

# Saver
saver = tf.train.Saver(hed_weights)

with tf.Session() as sess:
sess.run(global_init)

latest_ck_file = tf.train.latest_checkpoint(FLAGS.finetuning_dir)
if latest_ck_file:
print('restore from latest checkpoint file : {}'.format(latest_ck_file))
saver.restore(sess, latest_ck_file)
else:
print('no checkpoint file to restore, exit()')
exit()

for epoch in range(FLAGS.iterations):
for step in range(len(images)):
feed_dict_to_use = {is_training: True}
_dsn_fuse, _ = sess.run([dsn_fuse, train_step], feed_dict=feed_dict_to_use)
if epoch == FLAGS.iterations - 1:
feed_dict_to_use[is_training] = False
dsn_fuse_evaluate = sess.run(dsn_fuse, feed_dict=feed_dict_to_use)
dsn_fuse_image = np.where(dsn_fuse_evaluate[0] > FLAGS.output_threshold, [255], [0])
dsn_fuse_image_path = os.path.join('./test_image', 'fine_tuning_output_img.png')
util.save_img(dsn_fuse_image_path, dsn_fuse_image.reshape([256, 256]))
saver.save(sess, hed_ckpt_file_path, global_step=0)

print("Train Finished!")
2 changes: 2 additions & 0 deletions edge_detection/finetuning_model/checkpoint
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
model_checkpoint_path: "hed.ckpt-finetuning"
all_model_checkpoint_paths: "hed.ckpt-finetuning"
Binary file not shown.
Binary file not shown.
Binary file not shown.
51 changes: 51 additions & 0 deletions edge_detection/freeze_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/python
#coding=utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from hed_net import *

from tensorflow import flags
flags.DEFINE_string('checkpoint_dir', './checkpoint',
'Checkpoint directory.')
flags.DEFINE_string('output_file', './hed_lite_model_quantize.tflite',
'Output file')

FLAGS = flags.FLAGS

if __name__ == "__main__":

image_input = tf.placeholder(tf.float32, shape=(const.image_height, const.image_width, 3), name='hed_input')
image_float = image_input / 255.0
image_float = tf.expand_dims(image_float, axis=0)

print('###1 input shape is: {}, name is: {}'.format(image_input.get_shape(), image_input.name))
dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_float, False)
img_output = tf.reshape(dsn_fuse, shape=(const.image_height, const.image_width), name="img_output")
print('###2 output shape is: {}, name is: {}'.format(img_output.get_shape(), img_output.name))

# Saver
hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed')
saver = tf.train.Saver(hed_weights)

global_init = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(global_init)

latest_ck_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_ck_file:
print('restore from latest checkpoint file : {}'.format(latest_ck_file))
saver.restore(sess, latest_ck_file)
else:
print('no checkpoint file to restore, exit()')
exit()

converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [image_input], [img_output])
converter.post_training_quantize = True
tflite_model = converter.convert()
open(FLAGS.output_file, 'wb').write(tflite_model)
print('finished')

37 changes: 37 additions & 0 deletions edge_detection/generate_batch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/python
#coding=utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing as mt

import tensorflow as tf
import const


def generate_batch_data(images, annotations, batch_size):

def __map_fun(image_path, annotation_path):
image_tensor = tf.read_file(image_path)
image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)
image_tensor = tf.image.resize_images(image_tensor, [const.image_height, const.image_width])
image_float = tf.to_float(image_tensor)
image_float = image_float / 255.0

annotation_content = tf.read_file(annotation_path)
annotation_tensor = tf.image.decode_png(annotation_content, channels=1)
annotation_tensor = tf.image.resize_images(annotation_tensor, [const.image_height, const.image_width])
annotation_float = tf.to_float(annotation_tensor)
annotation_float = annotation_float / 255.0

return image_float, annotation_float

data_set = tf.data.Dataset.from_tensor_slices((images, annotations))\
.shuffle(100).repeat().map(__map_fun, num_parallel_calls=mt.cpu_count()).batch(batch_size)
iterator = data_set.make_one_shot_iterator()
return iterator.get_next()



Loading

0 comments on commit 4feeba6

Please sign in to comment.