forked from pqpo/SmartCropper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request pqpo#84 from pqpo/feature/hed_net
Feature/hed net
- Loading branch information
Showing
16 changed files
with
706 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/usr/bin/python | ||
#coding=utf8 | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
image_height = 256 | ||
image_width = 256 | ||
use_batch_norm = True | ||
use_kernel_regularizer = False | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#!/usr/bin/python | ||
# coding=utf8 | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from util import * | ||
from hed_net import * | ||
import os | ||
|
||
from tensorflow import flags | ||
|
||
flags.DEFINE_string('input_img', 'test_image/test2.jpg', | ||
'Image path to run hed, must be jpg image.') | ||
flags.DEFINE_string('checkpoint_dir', './checkpoint', | ||
'Checkpoint directory.') | ||
flags.DEFINE_string('output_img', 'test_image/result.jpg', | ||
'Output image path.') | ||
flags.DEFINE_float('output_threshold', 0.0, 'output threshold, default: 0.0') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
if not os.path.exists(FLAGS.input_img): | ||
print('--input_img invalid') | ||
exit() | ||
|
||
if FLAGS.output_img == '': | ||
print('--output_img invalid') | ||
exit() | ||
|
||
if __name__ == "__main__": | ||
image_path_placeholder = tf.placeholder(tf.string) | ||
|
||
feed_dict_to_use = {image_path_placeholder: FLAGS.input_img} | ||
|
||
image_tensor = tf.read_file(image_path_placeholder) | ||
image_tensor = tf.image.decode_jpeg(image_tensor, channels=3) | ||
origin_tensor = tf.image.resize_images(image_tensor, [const.image_height, const.image_width]) | ||
image_float = tf.to_float(origin_tensor) | ||
image_float = image_float / 255.0 | ||
image_float = tf.expand_dims(image_float, axis=0) | ||
|
||
dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_float, False) | ||
# dsn_fuse = tf.reshape(dsn_fuse, shape=(const.image_height, const.image_width)) | ||
|
||
global_init = tf.global_variables_initializer() | ||
|
||
# Saver | ||
hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed') | ||
saver = tf.train.Saver(hed_weights) | ||
|
||
with tf.Session() as sess: | ||
sess.run(global_init) | ||
|
||
latest_ck_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) | ||
if latest_ck_file: | ||
print('restore from latest checkpoint file : {}'.format(latest_ck_file)) | ||
saver.restore(sess, latest_ck_file) | ||
else: | ||
print('no checkpoint file to restore, exit()') | ||
exit() | ||
|
||
_dsn_fuse = sess.run(dsn_fuse, feed_dict=feed_dict_to_use) | ||
|
||
dsn_fuse_image = np.where(_dsn_fuse[0] > FLAGS.output_threshold, [255], [0]) | ||
save_img(FLAGS.output_img, dsn_fuse_image.reshape([256, 256])) | ||
print('done! output image: {}'.format(FLAGS.output_img)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#!/usr/bin/python | ||
# coding=utf8 | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from hed_net import * | ||
import numpy as np | ||
import os | ||
import util | ||
from generate_batch_data import generate_batch_data | ||
|
||
from tensorflow import flags | ||
|
||
flags.DEFINE_string('finetuning_dir', './finetuning_model', | ||
'finetuning directory.') | ||
flags.DEFINE_string('checkpoint_dir', './checkpoint', | ||
'Checkpoint directory.') | ||
flags.DEFINE_string('image', 'test_image/IMG_20190704_143127.jpg', 'fine tuning image') | ||
flags.DEFINE_string('annotation', 'test_image/annotation_IMG_20190704_143127.png_threshold.jpg', 'fine tuning annotation') | ||
flags.DEFINE_string('csv', '', 'fine tuning csv') | ||
flags.DEFINE_integer('batch_size', 4, 'batch size') | ||
flags.DEFINE_float('lr', 0.0005, 'learning rate') | ||
flags.DEFINE_integer('iterations', 15, | ||
'Number of iterations') | ||
flags.DEFINE_float('output_threshold', 0.0, 'output threshold') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
hed_ckpt_file_path = os.path.join(FLAGS.checkpoint_dir, 'hed.ckpt') | ||
|
||
train_layer = ['block0_1', 'block1_0', 'block2_1', 'block3_2', 'block4_3', 'block5_2'] | ||
|
||
if not ((os.path.exists(FLAGS.image)) and (os.path.exists(FLAGS.annotation)) or (os.path.exists(FLAGS.csv))): | ||
print('please add input, --img, --annotation or --csv') | ||
exit() | ||
|
||
images = [] | ||
annotations = [] | ||
|
||
batch_size = FLAGS.batch_size | ||
|
||
if os.path.exists(FLAGS.image) and os.path.exists(FLAGS.annotation): | ||
images.append(FLAGS.image) | ||
annotations.append(FLAGS.annotation) | ||
|
||
if os.path.exists(FLAGS.csv): | ||
csv_img, csv_ann = util.load_sample_from_csv(FLAGS.csv) | ||
if csv_img is not None and len(csv_img) > 0: | ||
images.extend(csv_img) | ||
annotations.extend(csv_ann) | ||
|
||
if len(images) == 0: | ||
print('Samples is empty, exit()') | ||
exit() | ||
|
||
print('fine tuning images size: {}'.format(len(images))) | ||
print('fine tuning annotations size: {}'.format(len(annotations))) | ||
|
||
assert len(images) == len(annotations) | ||
|
||
image_tensor, annotation_tensor = generate_batch_data(images, annotations, batch_size=batch_size) | ||
|
||
if __name__ == "__main__": | ||
|
||
is_training = tf.placeholder(tf.bool) | ||
|
||
dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_tensor, is_training) | ||
|
||
hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed') | ||
|
||
cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor) | ||
# cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor) \ | ||
# + class_balanced_sigmoid_cross_entropy(dsn1, annotation_tensor)\ | ||
# + class_balanced_sigmoid_cross_entropy(dsn2, annotation_tensor)\ | ||
# + class_balanced_sigmoid_cross_entropy(dsn3, annotation_tensor)\ | ||
# + class_balanced_sigmoid_cross_entropy(dsn4, annotation_tensor)\ | ||
# + class_balanced_sigmoid_cross_entropy(dsn5, annotation_tensor) | ||
|
||
var_list = [v for v in tf.trainable_variables() if v.name.split('/')[2] in train_layer] | ||
gradients = tf.gradients(cost, var_list) | ||
gradients = list(zip(gradients, var_list)) | ||
|
||
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): | ||
train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).apply_gradients(gradients) | ||
# train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(cost) | ||
|
||
global_init = tf.global_variables_initializer() | ||
|
||
# Saver | ||
saver = tf.train.Saver(hed_weights) | ||
|
||
with tf.Session() as sess: | ||
sess.run(global_init) | ||
|
||
latest_ck_file = tf.train.latest_checkpoint(FLAGS.finetuning_dir) | ||
if latest_ck_file: | ||
print('restore from latest checkpoint file : {}'.format(latest_ck_file)) | ||
saver.restore(sess, latest_ck_file) | ||
else: | ||
print('no checkpoint file to restore, exit()') | ||
exit() | ||
|
||
for epoch in range(FLAGS.iterations): | ||
for step in range(len(images)): | ||
feed_dict_to_use = {is_training: True} | ||
_dsn_fuse, _ = sess.run([dsn_fuse, train_step], feed_dict=feed_dict_to_use) | ||
if epoch == FLAGS.iterations - 1: | ||
feed_dict_to_use[is_training] = False | ||
dsn_fuse_evaluate = sess.run(dsn_fuse, feed_dict=feed_dict_to_use) | ||
dsn_fuse_image = np.where(dsn_fuse_evaluate[0] > FLAGS.output_threshold, [255], [0]) | ||
dsn_fuse_image_path = os.path.join('./test_image', 'fine_tuning_output_img.png') | ||
util.save_img(dsn_fuse_image_path, dsn_fuse_image.reshape([256, 256])) | ||
saver.save(sess, hed_ckpt_file_path, global_step=0) | ||
|
||
print("Train Finished!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_checkpoint_path: "hed.ckpt-finetuning" | ||
all_model_checkpoint_paths: "hed.ckpt-finetuning" |
Binary file added
BIN
+1.12 MB
edge_detection/finetuning_model/hed.ckpt-finetuning.data-00000-of-00001
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/python | ||
#coding=utf8 | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from hed_net import * | ||
|
||
from tensorflow import flags | ||
flags.DEFINE_string('checkpoint_dir', './checkpoint', | ||
'Checkpoint directory.') | ||
flags.DEFINE_string('output_file', './hed_lite_model_quantize.tflite', | ||
'Output file') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
if __name__ == "__main__": | ||
|
||
image_input = tf.placeholder(tf.float32, shape=(const.image_height, const.image_width, 3), name='hed_input') | ||
image_float = image_input / 255.0 | ||
image_float = tf.expand_dims(image_float, axis=0) | ||
|
||
print('###1 input shape is: {}, name is: {}'.format(image_input.get_shape(), image_input.name)) | ||
dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = mobilenet_v2_style_hed(image_float, False) | ||
img_output = tf.reshape(dsn_fuse, shape=(const.image_height, const.image_width), name="img_output") | ||
print('###2 output shape is: {}, name is: {}'.format(img_output.get_shape(), img_output.name)) | ||
|
||
# Saver | ||
hed_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='hed') | ||
saver = tf.train.Saver(hed_weights) | ||
|
||
global_init = tf.global_variables_initializer() | ||
|
||
with tf.Session() as sess: | ||
sess.run(global_init) | ||
|
||
latest_ck_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) | ||
if latest_ck_file: | ||
print('restore from latest checkpoint file : {}'.format(latest_ck_file)) | ||
saver.restore(sess, latest_ck_file) | ||
else: | ||
print('no checkpoint file to restore, exit()') | ||
exit() | ||
|
||
converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [image_input], [img_output]) | ||
converter.post_training_quantize = True | ||
tflite_model = converter.convert() | ||
open(FLAGS.output_file, 'wb').write(tflite_model) | ||
print('finished') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#!/usr/bin/python | ||
#coding=utf8 | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import multiprocessing as mt | ||
|
||
import tensorflow as tf | ||
import const | ||
|
||
|
||
def generate_batch_data(images, annotations, batch_size): | ||
|
||
def __map_fun(image_path, annotation_path): | ||
image_tensor = tf.read_file(image_path) | ||
image_tensor = tf.image.decode_jpeg(image_tensor, channels=3) | ||
image_tensor = tf.image.resize_images(image_tensor, [const.image_height, const.image_width]) | ||
image_float = tf.to_float(image_tensor) | ||
image_float = image_float / 255.0 | ||
|
||
annotation_content = tf.read_file(annotation_path) | ||
annotation_tensor = tf.image.decode_png(annotation_content, channels=1) | ||
annotation_tensor = tf.image.resize_images(annotation_tensor, [const.image_height, const.image_width]) | ||
annotation_float = tf.to_float(annotation_tensor) | ||
annotation_float = annotation_float / 255.0 | ||
|
||
return image_float, annotation_float | ||
|
||
data_set = tf.data.Dataset.from_tensor_slices((images, annotations))\ | ||
.shuffle(100).repeat().map(__map_fun, num_parallel_calls=mt.cpu_count()).batch(batch_size) | ||
iterator = data_set.make_one_shot_iterator() | ||
return iterator.get_next() | ||
|
||
|
||
|
Oops, something went wrong.