Skip to content

Commit

Permalink
优化代码,替换cpython\python\gpunms 为opencv dnn.nms;
Browse files Browse the repository at this point in the history
  • Loading branch information
wenlihaoyu committed Aug 5, 2019
1 parent 9e6f83c commit 6332891
Show file tree
Hide file tree
Showing 275 changed files with 1,031 additions and 52,974 deletions.
Binary file added .gitignore.swp
Binary file not shown.
38 changes: 23 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
- [x] 文字方向检测 0、90、180、270度检测(支持dnn/tensorflow)
- [x] 支持(darknet/opencv dnn /keras)文字检测,支持darknet/keras训练
- [x] 不定长OCR训练(英文、中英文) crnn\dense ocr 识别及训练 ,新增pytorch转keras模型代码(tools/pytorch_to_keras.py)
- [x] 支持darknet 转keras, keras转darknet, pytorch 转keras模型
- [x] 新增对身份证/火车票结构化数据识别
- [ ] 新增语音模型修正OCR识别结果
- [ ] 新增CNN+ctc模型,支持DNN模块调用OCR,单行图像平均时间为0.02秒以下
- [ ] 优化CPU调用,识别速度与GPU接近(近期更新)
- [x] 支持darknet 转keras, keras转darknet, pytorch 转keras模型
- [x] 身份证/火车票结构化数据识别
- [x] 新增CNN+ctc模型,支持DNN模块调用OCR,单行图像平均时间为0.02秒以下
- [ ] CPU版本加速
- [ ] 支持基于用户字典OCR识别
- [ ] 新增语言模型修正OCR识别结果
- [ ] 支持树莓派实时识别方案


## 环境部署

Expand Down Expand Up @@ -36,7 +39,6 @@ lib = CDLL(root+"chineseocr/darknet/libdarknet.so", RTLD_GLOBAL)
## 下载模型文件
模型文件地址:
* [baidu pan](https://pan.baidu.com/s/1gTW9gwJR6hlwTuyB6nCkzQ)
* [google drive](https://drive.google.com/drive/folders/1XiT1FLFvokAdwfE9WSUSS1PnZA34WBzy?usp=sharing)

复制文件夹中的所有文件到models目录

Expand Down Expand Up @@ -65,11 +67,10 @@ pip install .
wget https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
mv zh_giga.no_cna_cmn.prune01244.klm chineseocr/models/
```
## web服务启动
## 模型选择
``` Bash
cd chineseocr## 进入chineseocr目录
ipython app.py 8080 ##8080端口号,可以设置任意端口
```
参考config.py文件
```

## 构建docker镜像
``` Bash
Expand All @@ -81,6 +82,18 @@ docker run -d -p 8080:8080 chineseocr /root/anaconda3/bin/python app.py

```

## web服务启动
``` Bash
cd chineseocr## 进入chineseocr目录
python app.py 8080 ##8080端口号,可以设置任意端口
```

## 访问服务
http://127.0.0.1:8080/ocr

<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo.png"/>



## 识别结果展示

Expand All @@ -90,11 +103,6 @@ docker run -d -p 8080:8080 chineseocr /root/anaconda3/bin/python app.py
<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/line-demo.png"/>


## 访问服务
http://127.0.0.1:8080/ocr

<img width="500" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo.png"/>


## 参考
1. yolo3 https://github.com/pjreddie/darknet.git
Expand Down
125 changes: 102 additions & 23 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,103 @@
@author: lywen
"""
import os
import cv2
import json
import time
import uuid
import base64
import web
import numpy as np
from PIL import Image
web.config.debug = True
import model

render = web.template.render('templates', base='base')
from config import DETECTANGLE
from apphelper.image import union_rbox,adjust_box_to_origin
from config import *
from apphelper.image import union_rbox,adjust_box_to_origin,base64_to_PIL
from application import trainTicket,idcard
if yoloTextFlag =='keras' or AngleModelFlag=='tf' or ocrFlag=='keras':
if GPU:
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
import tensorflow as tf
from keras import backend as K
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.per_process_gpu_memory_fraction = 0.3## GPU最大占用量
config.gpu_options.allow_growth = True##GPU是否可动态增加
K.set_session(tf.Session(config=config))
K.get_session().run(tf.global_variables_initializer())

else:
##CPU启动
os.environ["CUDA_VISIBLE_DEVICES"] = ''

if yoloTextFlag=='opencv':
scale,maxScale = IMGSIZE
from text.opencv_dnn_detect import text_detect
elif yoloTextFlag=='darknet':
scale,maxScale = IMGSIZE
from text.darknet_detect import text_detect
elif yoloTextFlag=='keras':
scale,maxScale = IMGSIZE[0],2048
from text.keras_detect import text_detect
else:
print( "err,text engine in keras\opencv\darknet")

from text.opencv_dnn_detect import angle_detect

if ocr_redis:
##多任务并发识别
from apphelper.redisbase import redisDataBase
ocr = redisDataBase().put_values
else:
from crnn.keys import alphabetChinese,alphabetEnglish
if ocrFlag=='keras':
from crnn.network_keras import CRNN
if chineseModel:
alphabet = alphabetChinese
if LSTMFLAG:
ocrModel = ocrModelKerasLstm
else:
ocrModel = ocrModelKerasDense
else:
ocrModel = ocrModelKerasEng
alphabet = alphabetEnglish
LSTMFLAG = True

elif ocrFlag=='torch':
from crnn.network_torch import CRNN
if chineseModel:
alphabet = alphabetChinese
if LSTMFLAG:
ocrModel = ocrModelTorchLstm
else:
ocrModel = ocrModelTorchDense

else:
ocrModel = ocrModelTorchEng
alphabet = alphabetEnglish
LSTMFLAG = True
elif ocrFlag=='opencv':
from crnn.network_dnn import CRNN
ocrModel = ocrModelOpencv
alphabet = alphabetChinese
else:
print( "err,ocr engine in keras\opencv\darknet")

nclass = len(alphabet)+1
if ocrFlag=='opencv':
crnn = CRNN(alphabet=alphabet)
else:
crnn = CRNN( 32, 1, nclass, 256, leakyRelu=False,lstmFlag=LSTMFLAG,GPU=GPU,alphabet=alphabet)
if os.path.exists(ocrModel):
crnn.load_weights(ocrModel)
else:
print("download model or tranform model with tools!")

ocr = crnn.predict_job


from main import TextOcrModel

model = TextOcrModel(ocr,text_detect,angle_detect)


billList = ['通用OCR','火车票','身份证']

Expand All @@ -30,7 +113,6 @@ def GET(self):
post['H'] = 1000
post['width'] = 600
post['W'] = 600
post['uuid'] = uuid.uuid1().__str__()
post['billList'] = billList
return render.ocr(post)

Expand All @@ -42,33 +124,32 @@ def POST(self):
textLine = data.get('textLine',False)##只进行单行识别

imgString = data['imgString'].encode().split(b';base64,')[-1]
imgString = base64.b64decode(imgString)
jobid = uuid.uuid1().__str__()
path = 'test/{}.jpg'.format(jobid)
with open(path,'wb') as f:
f.write(imgString)
img = cv2.imread(path)##GBR
img = base64_to_PIL(imgString)
if img is not None:
img = np.array(img)

H,W = img.shape[:2]
timeTake = time.time()
if textLine:
##单行识别
partImg = Image.fromarray(img)
text = model.crnnOcr(partImg.convert('L'))
text = ocr.predict(partImg.convert('L'))
res =[ {'text':text,'name':'0','box':[0,0,W,0,W,H,0,H]} ]
else:
detectAngle = textAngle
_,result,angle= model.model(img,
result,angle= model.model(img,
scale=scale,
maxScale=maxScale,
detectAngle=detectAngle,##是否进行文字方向检测,通过web传参控制
config=dict(MAX_HORIZONTAL_GAP=50,##字符之间的最大间隔,用于文本行的合并
MAX_HORIZONTAL_GAP=100,##字符之间的最大间隔,用于文本行的合并
MIN_V_OVERLAPS=0.6,
MIN_SIZE_SIM=0.6,
TEXT_PROPOSALS_MIN_SCORE=0.1,
TEXT_PROPOSALS_NMS_THRESH=0.3,
TEXT_LINE_NMS_THRESH = 0.7,##文本行之间测iou值
),
leftAdjust=True,##对检测的文本行进行向左延伸
rightAdjust=True,##对检测的文本行进行向右延伸
alph=0.01,##对检测的文本行进行向右、左延伸的倍数
TEXT_LINE_NMS_THRESH = 0.99,##文本行之间测iou值
LINE_MIN_SCORE=0.1,
leftAdjustAlph=0.01,##对检测的文本行进行向左延伸
rightAdjustAlph=0.01,##对检测的文本行进行向右延伸
)


Expand Down Expand Up @@ -101,8 +182,6 @@ def POST(self):

timeTake = time.time()-timeTake


os.remove(path)
return json.dumps({'res':res,'timeTake':round(timeTake,4)},ensure_ascii=False)


Expand Down
Loading

0 comments on commit 6332891

Please sign in to comment.