Skip to content

Commit

Permalink
Merge pull request NVIDIA#1907 from IsaacYangSLA/fix/siamese_tf
Browse files Browse the repository at this point in the history
Fix output_tensor in non-classification task
  • Loading branch information
IsaacYangSLA authored Dec 8, 2017
2 parents 531d0a2 + 433dd26 commit 6cfaef7
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions digits/tools/tensorflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,9 @@ def main(_):
Validation(sess, val_model, current_epoch)

if FLAGS.train_db:
output_tensor = train_model.towers[0].inference
out_name, _ = output_tensor.name.split(':')
if FLAGS.labels_list:
output_tensor = train_model.towers[0].inference
out_name, _ = output_tensor.name.split(':')

if FLAGS.train_db:
del train_model
Expand All @@ -718,22 +719,21 @@ def main(_):

del sess
if FLAGS.train_db:
path_frozen = os.path.join(FLAGS.save, 'frozen_model.pb')

print('Saving frozen model at path {}'.format(path_frozen))

freeze_graph.freeze_graph(
input_graph=graphdef_path,
input_saver='',
input_binary=True,
input_checkpoint=checkpoint_path,
output_node_names=out_name,
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=path_frozen,
clear_devices=True,
initializer_nodes="",
)
if FLAGS.labels_list:
path_frozen = os.path.join(FLAGS.save, 'frozen_model.pb')
print('Saving frozen model at path {}'.format(path_frozen))
freeze_graph.freeze_graph(
input_graph=graphdef_path,
input_saver='',
input_binary=True,
input_checkpoint=checkpoint_path,
output_node_names=out_name,
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=path_frozen,
clear_devices=True,
initializer_nodes="",
)

logging.info('END')

Expand Down

0 comments on commit 6cfaef7

Please sign in to comment.