forked from zhudelong/ocr-rcnn-v2
Commit 28b9f51: merge two graphs and add optimized graphs
deron committed Mar 25, 2019 · 1 parent d827bef
Showing 9 changed files with 398 additions and 16 deletions.
button_recognition.py (+205 lines):
@@ -0,0 +1,205 @@
#!/usr/bin/env python
import os
import imageio
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont
from utils.ops import native_crop_and_resize
from utils import visualization_utils as vis_util
import tensorflow.contrib.tensorrt as trt


charset = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5,
           '6': 6, '7': 7, '8': 8, '9': 9, 'A': 10, 'B': 11,
           'C': 12, 'D': 13, 'E': 14, 'F': 15, 'G': 16, 'H': 17,
           'I': 18, 'J': 19, 'K': 20, 'L': 21, 'M': 22, 'N': 23,
           'O': 24, 'P': 25, 'R': 26, 'S': 27, 'T': 28, 'U': 29,
           'V': 30, 'X': 31, 'Z': 32, '<': 33, '>': 34, '(': 35,
           ')': 36, '$': 37, '#': 38, '^': 39, 's': 40, '-': 41,
           '*': 42, '%': 43, '?': 44, '!': 45, '+': 46}  # <nul> = '+'
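# The '+' entry (index 46) is the <nul> placeholder emitted for empty character
# slots; decode_text() below skips it when assembling the final label text.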


class ButtonRecognizer:
    def __init__(self, rcnn_path=None, ocr_path=None, use_trt=False, precision='FP16', use_optimized=False):
        self.ocr_graph_path = ocr_path
        self.rcnn_graph_path = rcnn_path
        self.use_trt = use_trt
        self.precision = precision  # one of 'INT8', 'FP16', 'FP32'
        self.use_optimized = use_optimized
        self.session = None

        self.ocr_input = None
        self.ocr_output = None
        self.rcnn_input = None
        self.rcnn_output = None

        self.class_num = 1
        self.image_size = [480, 640]
        self.recognition_size = [180, 180]
        self.category_index = {1: {'id': 1, 'name': u'button'}}
        self.idx_lbl = {}
        for key in charset.keys():
            self.idx_lbl[charset[key]] = key
        self.load_and_merge_graphs()
        print('Button recognizer initialized!')

    def __del__(self):
        self.clear_session()

    def optimize_rcnn(self, input_graph_def):
        trt_graph = trt.create_inference_graph(
            input_graph_def=input_graph_def,
            outputs=['detection_boxes', 'detection_scores', 'detection_classes', 'num_detections'],
            max_batch_size=1,
            # max_workspace_size_bytes=(2 << 10) << 20,
            precision_mode=self.precision)
        return trt_graph

    def optimize_ocr(self, input_graph_def):
        output_graph_def = trt.create_inference_graph(
            input_graph_def=input_graph_def,
            outputs=['predicted_chars', 'predicted_scores'],
            max_batch_size=1,
            # max_workspace_size_bytes=(2 << 10) << 20,
            precision_mode=self.precision)
        return output_graph_def
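
    # The "_optimized.pb" graphs loaded when use_optimized=True can be produced
    # once by serializing the TensorRT output and reusing the file -- a minimal
    # sketch, assuming the default frozen-model paths:
    #
    #   graph_def = tf.GraphDef()
    #   with tf.gfile.GFile('./frozen_model/ocr_graph.pb', 'rb') as f:
    #       graph_def.ParseFromString(f.read())
    #   trt_graph = trt.create_inference_graph(
    #       input_graph_def=graph_def,
    #       outputs=['predicted_chars', 'predicted_scores'],
    #       max_batch_size=1, precision_mode='FP16')
    #   with tf.gfile.GFile('./frozen_model/ocr_graph_optimized.pb', 'wb') as f:
    #       f.write(trt_graph.SerializeToString())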

    def load_and_merge_graphs(self):
        # check graph paths
        if self.ocr_graph_path is None:
            self.ocr_graph_path = './frozen_model/ocr_graph.pb'
        if self.rcnn_graph_path is None:
            self.rcnn_graph_path = './frozen_model/detection_graph_640x480.pb'
        if self.use_optimized:
            # str.replace returns a new string, so the result must be assigned back
            self.ocr_graph_path = self.ocr_graph_path.replace('.pb', '_optimized.pb')
            self.rcnn_graph_path = self.rcnn_graph_path.replace('.pb', '_optimized.pb')
        assert os.path.exists(self.ocr_graph_path) and os.path.exists(self.rcnn_graph_path)

        # merge the frozen graphs
        ocr_rcnn_graph = tf.Graph()
        with ocr_rcnn_graph.as_default():

            # load button detection graph definition
            with tf.gfile.GFile(self.rcnn_graph_path, 'rb') as fid:
                detection_graph_def = tf.GraphDef()
                serialized_graph = fid.read()
                detection_graph_def.ParseFromString(serialized_graph)
                # for node in detection_graph_def.node:
                #     print(node.name)
                if self.use_trt:
                    detection_graph_def = self.optimize_rcnn(detection_graph_def)
                tf.import_graph_def(detection_graph_def, name='detection')

            # load character recognition graph definition
            with tf.gfile.GFile(self.ocr_graph_path, 'rb') as fid:
                recognition_graph_def = tf.GraphDef()
                serialized_graph = fid.read()
                recognition_graph_def.ParseFromString(serialized_graph)
                if self.use_trt:
                    recognition_graph_def = self.optimize_ocr(recognition_graph_def)
                tf.import_graph_def(recognition_graph_def, name='recognition')

            # retrieve detection tensors
            rcnn_input = ocr_rcnn_graph.get_tensor_by_name('detection/image_tensor:0')
            rcnn_boxes = ocr_rcnn_graph.get_tensor_by_name('detection/detection_boxes:0')
            rcnn_scores = ocr_rcnn_graph.get_tensor_by_name('detection/detection_scores:0')
            rcnn_number = ocr_rcnn_graph.get_tensor_by_name('detection/num_detections:0')

            # crop and resize valid boxes (only valid when the rcnn input has a known shape)
            rcnn_number = tf.to_int32(rcnn_number)
            valid_boxes = tf.slice(rcnn_boxes, [0, 0, 0], [1, rcnn_number[0], 4])
            ocr_boxes = native_crop_and_resize(rcnn_input, valid_boxes, self.recognition_size)
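            # The crop happens inside the merged graph: each detected box is cut
            # out of the original input tensor and resized to recognition_size
            # (180x180), so predict() gets OCR-ready patches from a single run.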

            # retrieve recognition tensors
            ocr_input = ocr_rcnn_graph.get_tensor_by_name('recognition/ocr_input:0')
            ocr_chars = ocr_rcnn_graph.get_tensor_by_name('recognition/predicted_chars:0')
            ocr_beliefs = ocr_rcnn_graph.get_tensor_by_name('recognition/predicted_scores:0')

        self.rcnn_input = rcnn_input
        self.rcnn_output = [rcnn_boxes, rcnn_scores, rcnn_number, ocr_boxes]
        self.ocr_input = ocr_input
        self.ocr_output = [ocr_chars, ocr_beliefs]

        self.session = tf.Session(graph=ocr_rcnn_graph)
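        # A single session hosts both imported graphs, so detection and
        # recognition run back to back without a second session or graph copy.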

    def clear_session(self):
        if self.session is not None:
            self.session.close()

    def decode_text(self, codes, scores):
        score_ave = 0
        text = ''
        for char, score in zip(codes, scores):
            if not self.idx_lbl[char] == '+':
                score_ave += score
                text += self.idx_lbl[char]
        # guard against an all-<nul> prediction, which would leave text empty
        score_ave = score_ave / len(text) if text else 0.0
        return text, score_ave

    def predict(self, image_np, draw=False):
        # input data
        assert image_np.shape == (480, 640, 3)
        img_in = np.expand_dims(image_np, axis=0)

        # output data
        recognition_list = []
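        # each entry: [normalized box, detection score, decoded text, mean char score]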

        # perform detection and recognition
        boxes, scores, number, ocr_boxes = self.session.run(self.rcnn_output, feed_dict={self.rcnn_input: img_in})
        boxes, scores, number = [np.squeeze(x) for x in [boxes, scores, number]]

        for i in range(number):
            if scores[i] < 0.5:
                continue
            chars, beliefs = self.session.run(self.ocr_output, feed_dict={self.ocr_input: ocr_boxes[:, i]})
            chars, beliefs = [np.squeeze(x) for x in [chars, beliefs]]
            text, belief = self.decode_text(chars, beliefs)
            recognition_list.append([boxes[i], scores[i], text, belief])

        if draw:
            classes = [1] * len(boxes)
            self.draw_detection_result(image_np, boxes, classes, scores, self.category_index)
            self.draw_recognition_result(image_np, recognition_list)

        return recognition_list

    @staticmethod
    def draw_detection_result(image_np, boxes, classes, scores, category, predict_chars=None):
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category,
            max_boxes_to_draw=100,
            use_normalized_coordinates=True,
            line_thickness=5,
            predict_chars=predict_chars
        )

    def draw_recognition_result(self, image_np, recognitions):
        for item in recognitions:
            # crop button patches
            y_min = int(item[0][0] * self.image_size[0])
            x_min = int(item[0][1] * self.image_size[1])
            y_max = int(item[0][2] * self.image_size[0])
            x_max = int(item[0][3] * self.image_size[1])
            button_patch = image_np[y_min: y_max, x_min: x_max]
            # generate image layer for drawing
            img_pil = Image.fromarray(button_patch)
            img_show = ImageDraw.Draw(img_pil)
            # draw at a proper location
            x_center = (x_max - x_min) / 2.0
            y_center = (y_max - y_min) / 2.0
            font_size = min(x_center, y_center) * 1.1
            text_center = int(x_center - 0.5 * font_size), int(y_center - 0.5 * font_size)
            # macOS font path; substitute a font available on the target platform
            font = ImageFont.truetype('/Library/Fonts/Arial.ttf', int(font_size))
            img_show.text(text_center, text=item[2], font=font, fill=(255, 0, 255))
            # img_pil.show()
            image_np[y_min: y_max, x_min: x_max] = np.array(img_pil)


if __name__ == '__main__':
    recognizer = ButtonRecognizer(use_optimized=True)
    image = imageio.imread('./test_panels/1.jpg')
    recognition_list = recognizer.predict(image, draw=True)
    image = Image.fromarray(image)
    image.show()
    recognizer.clear_session()
Second file, a panel test script (+56 lines):
@@ -0,0 +1,56 @@
#!/usr/bin/env python
from __future__ import print_function
import os
import cv2
import imageio
import PIL.Image
import PIL.ImageOps as ImageOps
import numpy as np
import tensorflow as tf
from button_recognition import ButtonRecognizer

DRAW = False


def get_image_name_list(target_path):
    assert os.path.exists(target_path)
    image_name_list = []
    file_set = os.walk(target_path)
    for root, dirs, files in file_set:
        for image_name in files:
            image_name_list.append(image_name.split('.')[0])
    return image_name_list

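# The first predict() call pays one-time initialization costs (graph setup and,
# with use_trt=True, TensorRT engine builds), so it is excluded from the timings.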
def warm_up(model):
    assert isinstance(model, ButtonRecognizer)
    image = imageio.imread('./test_panels/1.jpg')
    model.predict(image)


if __name__ == '__main__':
    data_dir = './test_panels'
    data_list = get_image_name_list(data_dir)
    recognizer = ButtonRecognizer(use_optimized=True)
    warm_up(recognizer)
    overall_time = 0
    for data in data_list:
        img_path = os.path.join(data_dir, data + '.jpg')
        image = PIL.Image.open(tf.gfile.GFile(img_path, 'rb'))
        # resize to 640x480 with aspect ratio kept; thumbnail works in place and returns None
        image.thumbnail((640, 480), PIL.Image.ANTIALIAS)
        delta_w, delta_h = 640 - image.size[0], 480 - image.size[1]
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        new_im = ImageOps.expand(image, padding)
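        # e.g. a 1000x600 source thumbnails to 640x384, giving delta_h = 96 and
        # padding (0, 48, 0, 48) to letterbox it back to exactly 640x480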
        img_np = np.copy(np.asarray(new_im))
        # perform button recognition
        t0 = cv2.getTickCount()
        recognizer.predict(img_np, draw=DRAW)
        t1 = cv2.getTickCount()
        elapsed = (t1 - t0) / cv2.getTickFrequency()
        overall_time += elapsed
        print('Time elapsed: {}'.format(elapsed))
        if DRAW:
            image = PIL.Image.fromarray(img_np)
            image.show()

    average_time = overall_time / len(data_list)
    print('Average time: {}'.format(average_time))
    recognizer.clear_session()