Skip to content

Commit

Permalink
优化代码,修复一些bug,新增python版本nms
Browse files Browse the repository at this point in the history
  • Loading branch information
wenlihaoyu committed Feb 26, 2019
1 parent 92779dd commit 807f882
Show file tree
Hide file tree
Showing 229 changed files with 985 additions and 24,523 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


## 环境部署

GPU部署 参考:setup.md
GPU部署 参考:setup-cpu.md

Expand Down
48 changes: 47 additions & 1 deletion apphelper/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

def sort_box_(box):
x1,y1,x2,y2,x3,y3,x4,y4 = box[:8]
pts = (x1,y1),(x2,y2),(x3,y3),(x4,y4)
pts = np.array(pts, dtype="float32")
(x1,y1),(x2,y2),(x3,y3),(x4,y4) = _order_points(pts)
"""
newBox = [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
## sort x
newBox = sorted(newBox,key=lambda x:x[0])
Expand All @@ -35,8 +39,42 @@ def sort_box_(box):
newBox = sorted(newBox,key=lambda x:-x[1])
x3,y3 = sorted(newBox[:2],key=lambda x:x[0])[0]
"""
return x1,y1,x2,y2,x3,y3,x4,y4


import numpy as np
from scipy.spatial import distance as dist
def _order_points(pts):
# 根据x坐标对点进行排序
"""
---------------------
作者:Tong_T
来源:CSDN
原文:https://blog.csdn.net/Tong_T/article/details/81907132
版权声明:本文为博主原创文章,转载请附上博文链接!
"""
x_sorted = pts[np.argsort(pts[:, 0]), :]

# 从排序中获取最左侧和最右侧的点
# x坐标点
left_most = x_sorted[:2, :]
right_most = x_sorted[2:, :]

# 现在,根据它们的y坐标对最左边的坐标进行排序,这样我们就可以分别抓住左上角和左下角
left_most = left_most[np.argsort(left_most[:, 1]), :]
(tl, bl) = left_most

# 现在我们有了左上角坐标,用它作为锚来计算左上角和右上角之间的欧氏距离;
# 根据毕达哥拉斯定理,距离最大的点将是我们的右下角
distance = dist.cdist(tl[np.newaxis], right_most, "euclidean")[0]
(br, tr) = right_most[np.argsort(distance)[::-1], :]

# 返回左上角,右上角,右下角和左下角的坐标
return np.array([tl, tr, br, bl], dtype="float32")



def solve(box):
"""
绕 cx,cy点 w,h 旋转 angle 的坐标
Expand Down Expand Up @@ -96,10 +134,18 @@ def read_voc_xml(p):
continue

angle = np.float(angle)

if abs(angle)>np.pi/2:
w,h = h,w
angle = abs(angle)%(np.pi/2)*np.sign(angle)

x1,y1,x2,y2,x3,y3,x4,y4 = xy_rotate_box(cx,cy,w,h,angle)
if angle>np.pi/4:
x1,y1,x2,y2,x3,y3,x4,y4 = sort_box_([x1,y1,x2,y2,x3,y3,x4,y4])
"""
if abs(angle)>np.pi/2:
##lableImg bug
x1,y1,x2,y2,x3,y3,x4,y4 = sort_box_([x1,y1,x2,y2,x3,y3,x4,y4])
"""
angle,w,h,cx,cy = solve([x1,y1,x2,y2,x3,y3,x4,y4])

else:
Expand Down
1 change: 1 addition & 0 deletions application/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
基于该项目实现财务票据、表格文字的识别(将于近期公布)
46 changes: 34 additions & 12 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,53 @@
import os
opencvFlag = 'keras'##keras,opencv,darknet
########################文字检测########################
##文字检测引擎 keras,opencv,darknet
pwd = os.getcwd()
opencvFlag = 'keras'
IMGSIZE = (608,608)## yolo3 输入图像尺寸
## keras 版本anchors
keras_anchors = '8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800'
class_names = ['none','text',]
GPU = True##OCR 是否启用GPU
GPUID=0##调用GPU序号
kerasTextModel=os.path.join(pwd,"models","text.h5")##keras版本模型权重文件

############## darknet yolo ##############
darknetRoot = os.path.join(os.path.curdir,"darknet")## yolo 安装目录
pwd = os.getcwd()
yoloCfg = os.path.join(pwd,"models","text.cfg")
yoloCfg = os.path.join(pwd,"models","text.cfg")
yoloWeights = os.path.join(pwd,"models","text.weights")
yoloData = os.path.join(pwd,"models","text.data")
yoloData = os.path.join(pwd,"models","text.data")
############## darknet yolo ##############

########################文字检测########################

## GPU选择及启动GPU序号
GPU = True##OCR 是否启用GPU
GPUID=0##调用GPU序号

kerasTextModel=os.path.join(pwd,"models","text.h5")##keras版本模型
##文字方向检测
## nms选择,支持cython,gpu,python
nmsFlag='gpu'## cython/gpu/python



##vgg文字方向检测模型
DETECTANGLE=True##是否进行文字方向检测
AngleModelPb = os.path.join(pwd,"models","Angle-model.pb")
AngleModelPbtxt = os.path.join(pwd,"models","Angle-model.pbtxt")
IMGSIZE = (608,608)## yolo3 输入图像尺寸


######################OCR模型######################
##是否启用LSTM crnn模型
DETECTANGLE=True##是否进行文字方向检测
LSTMFLAG = True##OCR模型是否调用LSTM层
chinsesModel = True##模型选择 True:中英文模型 False:英文模型
##OCR模型是否调用LSTM层
LSTMFLAG = True
##模型选择 True:中英文模型 False:英文模型

chinsesModel = True

if chinsesModel:
if LSTMFLAG:
ocrModel = os.path.join(pwd,"models","ocr-lstm.pth")
else:
ocrModel = os.path.join(pwd,"models","ocr-dense.pth")
else:
##纯英文模型
LSTMFLAG=True
ocrModel = os.path.join(pwd,"models","ocr-english.pth")
######################OCR模型######################
49 changes: 21 additions & 28 deletions crnn/crnn.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
#coding:utf-8
import sys
sys.path.insert(1, "./crnn")
import torch
import torch.utils.data
from torch.autograd import Variable
from crnn import util
from crnn import dataset
from crnn.models import crnn as crnn
from crnn.network import CRNN
from crnn import keys
from collections import OrderedDict
from config import ocrModel,LSTMFLAG,GPU
from config import chinsesModel
def crnnSource():
if chinsesModel:
alphabet = keys.alphabetChinese
alphabet = keys.alphabetChinese##中英文模型
else:
alphabet = keys.alphabetEnglish
alphabet = keys.alphabetEnglish##英文模型

converter = util.strLabelConverter(alphabet)
if torch.cuda.is_available() and GPU:
model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cuda()##LSTMFLAG=True crnn 否则 dense ocr
model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cuda()##LSTMFLAG=True crnn 否则 dense ocr
else:
model = crnn.CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cpu()
model = CRNN(32, 1, len(alphabet)+1, 256, 1,lstmFlag=LSTMFLAG).cpu()

state_dict = torch.load(ocrModel,map_location=lambda storage, loc: storage)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
trainWeights = torch.load(ocrModel,map_location=lambda storage, loc: storage)
modelWeights = OrderedDict()
for k, v in trainWeights.items():
name = k.replace('module.','') # remove `module.`
new_state_dict[name] = v
modelWeights[name] = v
# load params

model.load_state_dict(new_state_dict)
model.load_state_dict(modelWeights)
model.eval()

return model,converter

##加载模型
Expand All @@ -41,30 +39,25 @@ def crnnSource():
def crnnOcr(image):
"""
crnn模型,ocr识别
@@model,
@@converter,
@@im
@@text_recs:text box
image:PIL.Image.convert("L")
"""
scale = image.size[1]*1.0 / 32
w = image.size[0] / scale
w = int(w)
#print "im size:{},{}".format(image.size,w)
transformer = dataset.resizeNormalize((w, 32))
if torch.cuda.is_available() and GPU:
image = transformer(image).cuda()
image = transformer(image).cuda()
else:
image = transformer(image).cpu()
image = transformer(image).cpu()

image = image.view(1, *image.size())
image = Variable(image)
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

return sim_pred

Expand Down
13 changes: 0 additions & 13 deletions crnn/models/utils.py

This file was deleted.

File renamed without changes.
5 changes: 2 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
else:
## keras版本文字检测
from text import keras_detect as detect


print("Text detect engine:{}".format(opencvFlag))

def text_detect(img,
MAX_HORIZONTAL_GAP=30,
Expand All @@ -34,8 +35,6 @@ def text_detect(img,

):
boxes, scores = detect.text_detect(np.array(img))


boxes = np.array(boxes,dtype=np.float32)
scores = np.array(scores,dtype=np.float32)
textdetector = TextDetector(MAX_HORIZONTAL_GAP,MIN_V_OVERLAPS,MIN_SIZE_SIM)
Expand Down
1 change: 1 addition & 0 deletions setup-cpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ pip install keras==2.1.5 tensorflow==1.8
conda install pytorch torchvision -c pytorch
## linux
## conda install pytorch-cpu torchvision-cpu -c pytorch
## python版本nms无须执行下一步
pushd text/detector/utils && sh make-for-cpu.sh && popd

3 changes: 2 additions & 1 deletion setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ pip install -U pillow -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install keras==2.1.5 tensorflow==1.8 tensorflow-gpu==1.8
pip install web.py==0.40.dev0
conda install pytorch torchvision -c pytorch
## pip install torch torchvision
## pip install torch torchvision
## python版本nms无须执行下一步
pushd text/detector/utils && sh make.sh && popd


File renamed without changes.
38 changes: 23 additions & 15 deletions test.ipynb

Large diffs are not rendered by default.

46 changes: 34 additions & 12 deletions text/detector/detectors.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
#coding:utf-8
import numpy as np
from config import GPUID,GPU
from text.detector.utils.cython_nms import nms as cython_nms
from config import GPUID,GPU,nmsFlag
from text.detector.utils.python_nms import nms as python_nms ##python版本nms

from text.detector.text_proposal_connector import TextProposalConnector

##优先加载编译对GPU编译的gpu_nms 如果不想调用GPU,在程序启动执行os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if GPU:

##优先加载编译对GPU编译的gpu_nms

if nmsFlag=='gpu' and GPU and GPUID is not None:
try:
from text.detector.utils.gpu_nms import gpu_nms
except:
gpu_nms = None
cython_nms = None

elif nmsFlag=='python':
gpu_nms ==None
cython_nms = None

elif nmsFlag=='cython':
try:
from detector.utils.gpu_nms import gpu_nms
from text.detector.utils.cython_nms import nms as cython_nms
except:
gpu_nms =cython_nms
cython_nms = None
gpu_nms ==None
else:
gpu_nms =None
cython_nms = None

print("Nms engine gpu_nms:",gpu_nms,",cython_nms:",cython_nms,",python_nms:",python_nms)


def nms(dets, thresh):
if dets.shape[0] == 0:
return []

try:
if GPU and GPUID is not None:
return gpu_nms(dets, thresh, device_id=GPUID)
except:
pass
if gpu_nms is not None and GPUID is not None:
return gpu_nms(dets, thresh, device_id=GPUID)

elif cython_nms is not None:
return cython_nms(dets, thresh)
else:
return python_nms(dets, thresh, method='Union')

return cython_nms(dets, thresh)

def normalize(data):
if data.shape[0]==0:
Expand Down
7 changes: 1 addition & 6 deletions text/detector/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from . import cython_nms
try:
from . import gpu_nms
except:
gpu_nms = cython_nms


Loading

0 comments on commit 807f882

Please sign in to comment.