forked from chineseocr/chineseocr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTextOcrModel.py
96 lines (80 loc) · 3.34 KB
/
TextOcrModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from text.detector.detectors import TextDetector
from apphelper.image import rotate_cut_img, sort_box
import numpy as np
from PIL import Image
class TextOcrModel(object):
    """Chain an angle model, a text-detection model and an OCR model.

    The three models are callables supplied by the caller. They are invoked
    as ``angleModel(img)``, ``textModel(img, scale, maxScale)`` and
    ``ocrModel(boxes)`` respectively.
    """

    def __init__(self, ocrModel, textModel, angleModel):
        """Store the three model callables on the instance."""
        self.angleModel = angleModel
        self.textModel = textModel
        self.ocrModel = ocrModel

    def detect_angle(self, img):
        """
        Detect the text angle and rotate the image to match.

        @@img: np.array

        Returns ``(img, angle)`` where ``angle`` is whatever the angle model
        reported. For 90, 180 or 270 the image is transposed with the PIL
        constant of the same name (``Image.ROTATE_90`` and so on); for any
        other value the input image is returned unchanged.
        """
        angle = self.angleModel(img)
        if angle == 90:
            method = Image.ROTATE_90
        elif angle == 180:
            method = Image.ROTATE_180
        elif angle == 270:
            method = Image.ROTATE_270
        else:
            # Nothing to rotate: hand the input back untouched.
            return img, angle
        rotated = Image.fromarray(img).transpose(method)
        return np.array(rotated), angle

    def detect_box(self, img, scale=600, maxScale=900):
        """
        Run the text-detection model on an image.

        @@img: np.array

        ``scale`` and ``maxScale`` are forwarded to the model as-is.
        Returns the ``(boxes, scores)`` pair produced by the model.
        """
        found_boxes, found_scores = self.textModel(img, scale, maxScale)
        return found_boxes, found_scores

    def box_cluster(self, img, boxes, scores, **args):
        """
        Post-process raw detections with ``TextDetector``.

        Keyword options (with the defaults used when absent):
        MAX_HORIZONTAL_GAP (100), MIN_V_OVERLAPS (0.6), MIN_SIZE_SIM (0.6),
        TEXT_PROPOSALS_MIN_SCORE (0.7), TEXT_PROPOSALS_NMS_THRESH (0.3),
        TEXT_LINE_NMS_THRESH (0.3), LINE_MIN_SCORE (0.8). Unknown keys are
        ignored.

        ``scores`` is passed on as ``scores[:, np.newaxis]`` and the first
        two entries of ``img.shape`` are passed as the image size.
        Returns the ``(boxes, scores)`` pair produced by the detector.
        """
        option = args.get
        detector = TextDetector(
            option('MAX_HORIZONTAL_GAP', 100),
            option('MIN_V_OVERLAPS', 0.6),
            option('MIN_SIZE_SIM', 0.6),
        )
        image_size = img.shape[:2]
        proposal_min_score = option('TEXT_PROPOSALS_MIN_SCORE', 0.7)
        proposal_nms_thresh = option('TEXT_PROPOSALS_NMS_THRESH', 0.3)
        line_nms_thresh = option('TEXT_LINE_NMS_THRESH', 0.3)
        line_min_score = option('LINE_MIN_SCORE', 0.8)
        merged_boxes, merged_scores = detector.detect(
            boxes,
            scores[:, np.newaxis],
            image_size,
            proposal_min_score,
            proposal_nms_thresh,
            line_nms_thresh,
            line_min_score,
        )
        return merged_boxes, merged_scores

    def ocr_batch(self, img, boxes, leftAdjustAlph=0.0, rightAdjustAlph=0.0):
        """
        Crop every box out of the image and run OCR on the whole batch.

        Each box is cut with ``rotate_cut_img``; the crop is converted with
        ``.convert('L')`` and stored under the ``'img'`` key of the box
        returned by that helper. The resulting list is handed to the OCR
        model in one call and its result is returned.
        """
        pil_img = Image.fromarray(img)
        batch = []
        for raw_box in boxes:
            crop, cut_box = rotate_cut_img(
                pil_img, raw_box, leftAdjustAlph, rightAdjustAlph
            )
            cut_box['img'] = crop.convert('L')
            batch.append(cut_box)
        return self.ocrModel(batch)

    def model(self, img, **args):
        """
        Run the full pipeline: optional angle fix, detection, clustering, OCR.

        Keyword options read here (with defaults): detectAngle (False),
        scale (608), maxScale (608), leftAdjustAlph (0), rightAdjustAlph (0).
        All keyword options are also forwarded to ``box_cluster``.
        Note that the scale defaults used here differ from the defaults of
        ``detect_box`` (600 / 900).

        Returns ``(res, angle)``: the OCR model output and the detected
        angle, which is 0 when angle detection is disabled.
        """
        angle = 0
        if args.get('detectAngle', False):
            img, angle = self.detect_angle(img)

        # Text detection
        boxes, scores = self.detect_box(
            img, args.get('scale', 608), args.get('maxScale', 608)
        )
        boxes, scores = self.box_cluster(img, boxes, scores, **args)
        ordered_boxes = sort_box(boxes)

        left_adjust = args.get('leftAdjustAlph', 0)
        right_adjust = args.get('rightAdjustAlph', 0)
        result = self.ocr_batch(img, ordered_boxes, left_adjust, right_adjust)
        return result, angle