-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
executable file
·121 lines (105 loc) · 5.17 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# ==============================================================================
# Copyright 2017 Robert Cottrell. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import input
import model
import os
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
# Command-line flags consumed by run_predictions().
tf.app.flags.DEFINE_string(
    'logdir', 'logs',
    'Directory to restore checkpoints.')
tf.app.flags.DEFINE_string(
    'predictdir', 'predict/',
    'Directory with images to predict labels for')

# Shared flag container; values are parsed when tf.app.run() starts the script.
FLAGS = tf.app.flags.FLAGS
def _summarize_prediction(log_probs, picks):
    """Build the combined label and its joint confidence for one image.

    Args:
        log_probs: The six arrays fetched from ``model.predict``; each one is
            indexed as ``[0][class_index]`` and is exponentiated to obtain a
            probability (presumably log-softmax outputs -- confirm in model).
        picks: The argmax class index for each entry of ``log_probs``.
            ``picks[0]`` is the counter head, ``picks[1:]`` the digit heads.

    Returns:
        A ``(label, label_pct)`` tuple. ``label`` is ``'X'`` when the counter
        is 0, ``'+'`` when it is 6, the first ``counter`` digits joined
        together when it is 1..5, and ``'?'`` for any other counter value.
        ``label_pct`` is ``exp`` of the summed scores of the counter and of
        every digit that contributes to the label.
    """
    count = picks[0]
    if 1 <= count <= 5:
        n_digits = int(count)
        label = ''.join(str(d) for d in picks[1:n_digits + 1])
    else:
        # No digit heads contribute; only the counter's own score is used.
        n_digits = 0
        label = {0: 'X', 6: '+'}.get(int(count), '?')

    # Accumulate in the same left-to-right order as the per-head scores so
    # the floating-point result does not depend on the label length branch.
    joint = log_probs[0][0][count]
    for k in range(1, n_digits + 1):
        joint = joint + log_probs[k][0][picks[k]]
    return label, np.exp(joint)


def run_predictions():
    """Predict labels for new images.

    Builds a single-image pipeline (read file, decode JPEG with 3 channels,
    crop/pad to 54x54, per-image standardization), restores the latest
    checkpoint found in ``FLAGS.logdir`` and runs the model on every regular
    file in ``FLAGS.predictdir`` (in sorted filename order). One row per image
    is written to ``predict.csv`` in the current working directory.

    If no checkpoint is found a message is printed and nothing is written.
    """
    with tf.Graph().as_default():
        # Create images pipeline.
        path = tf.placeholder(tf.string)
        image = tf.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize_image_with_crop_or_pad(image, 54, 54)
        image = tf.image.per_image_standardization(image)

        # Build a Graph that computes predictions from the inference model.
        logits = model.inference(image, 1.0)

        # The Op to return predictions.
        predict = model.predict(logits)

        # The six heads fetched per image: the counter followed by five
        # digits. Built once because it does not change between images.
        run_ops = [predict[k] for k in range(6)]

        # The op for initializing the variables.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Create a saver to restore the model from the latest checkpoint.
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # Initialize the session.
            sess.run(init_op)

            # Restore the latest model snapshot.
            ckpt = tf.train.get_checkpoint_state(FLAGS.logdir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found!')
                return
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model restored: {}'.format(ckpt.model_checkpoint_path))

            with open('predict.csv', 'w') as out:
                out.write('image,label,label_pct,counter,counter_pct,y1,y1_pct,y2,y2_pct,y3,y3_pct,y4,y4_pct,y5,y5_pct\n')

                # Loop through files in the predict directory. Sorting makes
                # the row order reproducible across runs and platforms.
                for filename in sorted(os.listdir(FLAGS.predictdir)):
                    # os.path.join works whether or not the flag value ends
                    # with a path separator.
                    image_path = os.path.join(FLAGS.predictdir, filename)
                    if not os.path.isfile(image_path):
                        # Sub-directories cannot be read as images.
                        continue

                    fetched = sess.run(run_ops, feed_dict={path: image_path})
                    picks = [np.argmax(p) for p in fetched]
                    label, label_pct = _summarize_prediction(fetched, picks)

                    row = [image_path, label, '%0.1f' % (100 * label_pct)]
                    for scores, pick in zip(fetched, picks):
                        row.append(str(pick))
                        row.append('%0.1f' % (100 * np.exp(scores[0][pick])))
                    out.write(','.join(row))
                    out.write('\n')
def main(argv=None):
    """Script entry point: run the prediction pass.

    Args:
        argv: Command-line arguments; accepted for the caller's benefit but
            not used here.
    """
    del argv  # Unused.
    run_predictions()
# When executed as a script, let TensorFlow parse the flags and call main().
if __name__ == "__main__":
    tf.app.run(main=main)