Skip to content

Commit

Permalink
refactor: add loading of the model, saved as protobuf, and make predi…
Browse files Browse the repository at this point in the history
…ction on single image, using it
  • Loading branch information
vbezgachev committed Nov 18, 2017
1 parent 4c35614 commit 82b11e7
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion svnh_semi_supervised_model_loaded_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def load_test_images():
return test_images, test_labels


def main(_):
def load_and_predict_with_checkpoints():
'''
Loads saved model checkpoints and make prediction on test images
'''
# load test images and labels
test_images, test_labels = load_test_images()

Expand Down Expand Up @@ -59,5 +62,37 @@ def main(_):
print("Predicted classes: {}".format(pred_class))


def load_and_predict_with_saved_model():
'''
Loads saved as protobuf model and make prediction on a single image
'''
with tf.Session(graph=tf.Graph()) as sess:
# restore save model
export_dir = './gan-export/1'
model = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
# print(model)
loaded_graph = tf.get_default_graph()

# get necessary tensors by name
input_tensor_name = model.signature_def['predict_images'].inputs['images'].name
input_tensor = loaded_graph.get_tensor_by_name(input_tensor_name)
output_tensor_name = model.signature_def['predict_images'].outputs['scores'].name
output_tensor = loaded_graph.get_tensor_by_name(output_tensor_name)

# make prediction
image_file_name = './svnh_test_images/image_3.jpg'
with open(image_file_name, 'rb') as f:
image = f.read()
scores = sess.run(output_tensor, {input_tensor: [image]})

# print results
print("Scores: {}".format(scores))


def main(_):
# load_and_predict_with_checkpoints()
load_and_predict_with_saved_model()


if __name__ == '__main__':
tf.app.run()

0 comments on commit 82b11e7

Please sign in to comment.