Skip to content

Commit

Permalink
随便
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyufei committed Aug 12, 2021
1 parent 55866df commit 9add799
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ PytorchOCR旨在打造一套训练,推理,部署一体的OCR引擎库
* [x] 服务器端识别模型文件
* [x] DB通用模型
* [ ] 手机端部署
* [ ] With Triton
* [x] With Triton,[推荐使用Savior](https://github.com/novioleo/Savior)

## 环境配置

Expand Down
14 changes: 14 additions & 0 deletions doc/检测+识别推理.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## 检测和识别

1. 检测和识别
```shell script
CUDA_VISIBLE_DEVICES=0 --det_path "" --rec_path "" --img_path ""
```
2.模型运行时间分析
```shell script
CUDA_VISIBLE_DEVICES=0 --det_path "" --rec_path "" --img_path "" -time_profile
```
3.模型运行内存分析
```shell script
CUDA_VISIBLE_DEVICES=0 --det_path "" --rec_path "" --img_path "" -mem_profile
```
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.18.4
Pillow==8.1.1
Pillow==8.2.0
tqdm==4.46.0
opencv-python==4.2.0.34
addict==2.2.1
Expand All @@ -11,3 +11,6 @@ torchvision>=0.8.0
python-Levenshtein>=0.12.0
lmdb>=0.98
imgaug>=0.4.0
line-profiler==3.2.6
memory-profiler==0.58.0

8 changes: 5 additions & 3 deletions tools/det_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, model_path):
self.model.to(self.device)
self.model.eval()
self.resize = ResizeFixedSize(736, False)
self.post_proess = build_post_process(cfg['post_process'])
self.post_process = build_post_process(cfg['post_process'])
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=cfg['dataset']['train']['dataset']['mean'], std=cfg['dataset']['train']['dataset']['std'])
Expand All @@ -45,8 +45,10 @@ def predict(self, img, is_output_polygon=False):
tensor = self.transform(data['img'])
tensor = tensor.unsqueeze(dim=0)
tensor = tensor.to(self.device)
out = self.model(tensor)
box_list, score_list = self.post_proess(out, data['shape'])
with torch.no_grad():
out = self.model(tensor)
out = out.cpu().numpy()
box_list, score_list = self.post_process(out, data['shape'])
box_list, score_list = box_list[0], score_list[0]
if len(box_list) > 0:
idx = [x.sum() > 0 for x in box_list]
Expand Down
113 changes: 113 additions & 0 deletions tools/ocr_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from det_infer import DetInfer
from rec_infer import RecInfer
import argparse
from line_profiler import LineProfiler
from memory_profiler import profile
from torchocr.utils.vis import draw_ocr_box_txt
import numpy as np

def get_rotate_crop_image(img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
points = points.astype(np.float32)
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img


class OCRInfer(object):
def __init__(self, det_path, rec_path, rec_batch_size=16, time_profile=False, mem_profile=False ,**kwargs):
super().__init__()
self.det_model = DetInfer(det_path)
self.rec_model = RecInfer(rec_path, rec_batch_size)
assert not(time_profile and mem_profile),"can not profile memory and time at the same time"
self.line_profiler = None
if time_profile:
self.line_profiler = LineProfiler()
self.predict = self.predict_time_profile
if mem_profile:
self.predict = self.predict_mem_profile

def do_predict(self, img):
box_list, score_list = self.det_model.predict(img)
if len(box_list) == 0:
return [], [], img
draw_box_list = [tuple(map(tuple, box)) for box in box_list]
imgs =[get_rotate_crop_image(img, box) for box in box_list]
texts = self.rec_model.predict(imgs)
texts = [txt[0][0] for txt in texts]
debug_img = draw_ocr_box_txt(img, draw_box_list, texts)
return box_list, score_list, debug_img

def predict(self, img):
return self.do_predict(img)

def predict_mem_profile(self, img):
wapper = profile(self.do_predict)
return wapper(img)

def predict_time_profile(self, img):
# run multi time
for i in range(8):
print("*********** {} profile time *************".format(i))
lp = LineProfiler()
lp_wrapper = lp(self.do_predict)
ret = lp_wrapper(img)
lp.print_stats()
return ret


def init_args():
import argparse
parser = argparse.ArgumentParser(description='OCR infer')
parser.add_argument('--det_path', required=True, type=str, help='det model path')
parser.add_argument('--rec_path', required=True, type=str, help='rec model path')
parser.add_argument('--img_path', required=True, type=str, help='img path for predict')
parser.add_argument('--rec_batch_size', type=int, help='rec batch_size', default=16)
parser.add_argument('-time_profile', action='store_true', help='enable time profile mode')
parser.add_argument('-mem_profile', action='store_true', help='enable memory profile mode')
args = parser.parse_args()
return vars(args)


if __name__ == '__main__':
import cv2
args = init_args()
img = cv2.imread(args['img_path'])
model = OCRInfer(**args)
txts, boxes, debug_img = model.predict(img)
h,w,_, = debug_img.shape
raido = 1
if w > 1200:
raido = 600.0/w
debug_img = cv2.resize(debug_img, (int(w*raido), int(h*raido)))
if not(args['mem_profile'] or args['time_profile']):
cv2.imshow("debug", debug_img)
cv2.waitKey()

38 changes: 27 additions & 11 deletions tools/rec_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

# 将 torchocr路径加到python陆经里
__dir__ = pathlib.Path(os.path.abspath(__file__))

import numpy as np

sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

Expand All @@ -18,7 +21,7 @@


class RecInfer:
def __init__(self, model_path):
def __init__(self, model_path, batch_size=16):
ckpt = torch.load(model_path, map_location='cpu')
cfg = ckpt['cfg']
self.model = build_model(cfg['model'])
Expand All @@ -33,18 +36,31 @@ def __init__(self, model_path):

self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
self.batch_size = batch_size

def predict(self, img):
def predict(self, imgs):
# 预处理根据训练来
img = self.process.resize_with_specific_height(img)
# img = self.process.width_pad_img(img, 120)
img = self.process.normalize_img(img)
tensor = torch.from_numpy(img.transpose([2, 0, 1])).float()
tensor = tensor.unsqueeze(dim=0)
tensor = tensor.to(self.device)
out = self.model(tensor)
txt = self.converter.decode(out.softmax(dim=2).detach().cpu().numpy())
return txt
if not isinstance(imgs,list):
imgs = [imgs]
imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
widths = np.array([img.shape[1] for img in imgs])
idxs = np.argsort(widths)
txts = []
for idx in range(0, len(imgs), self.batch_size):
batch_idxs = idxs[idx:min(len(imgs), idx+self.batch_size)]
batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
batch_imgs = np.stack(batch_imgs)
tensor = torch.from_numpy(batch_imgs.transpose([0,3, 1, 2])).float()
tensor = tensor.to(self.device)
with torch.no_grad():
out = self.model(tensor)
out = out.softmax(dim=2)
out = out.cpu().numpy()
txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
#按输入图像的顺序排序
idxs = np.argsort(idxs)
out_txts = [txts[idx] for idx in idxs]
return out_txts


def init_args():
Expand Down
2 changes: 1 addition & 1 deletion torchocr/postprocess/DBPostProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_min_area_bbox(_image, _contour, _scale_ratio=1.0):
Returns: 最小面积矩形的相关信息
"""
h, w = _image.shape[:2]
if _scale_ratio != 1:
if abs(_scale_ratio -1) > 0.001:
reshaped_contour = _contour.reshape(-1, 2)
current_polygon = Polygon(reshaped_contour)
distance = current_polygon.area * _scale_ratio / current_polygon.length
Expand Down
26 changes: 20 additions & 6 deletions torchocr/utils/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
import math
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import os


def draw_ocr_box_txt(image, boxes, txts = None):
def get_font_file():
searchs = ["./doc/田氏颜体大字库2.0.ttf", "../doc/田氏颜体大字库2.0.ttf"]
for path in searchs:
if os.path.exists(path):
return path
assert False,"can't find 田氏颜体大字库2.0.ttf"


def draw_ocr_box_txt(image, boxes, txts = None, pos="horizontal"):
if isinstance(image,np.ndarray):
image = Image.fromarray(image)
h, w = image.height, image.width
Expand All @@ -31,21 +40,26 @@ def draw_ocr_box_txt(image, boxes, txts = None):
box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype("./doc/田氏颜体大字库2.0.ttf", font_size, encoding="utf-8")
font = ImageFont.truetype(get_font_file(), font_size, encoding="utf-8")
cur_y = box[0][1]
for c in txt:
char_size = font.getsize(c)
draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
cur_y += char_size[1]
else:
font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype("./doc/田氏颜体大字库2.0.ttf", font_size, encoding="utf-8")
font = ImageFont.truetype(get_font_file(), font_size, encoding="utf-8")
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
if txts is not None:
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
if pos == "horizontal":
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
else:
img_show = Image.new('RGB', (w, h * 2), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (0, h, w , h * 2))
else:
img_show = np.array(img_left)
return np.array(img_show)
Expand Down

0 comments on commit 9add799

Please sign in to comment.