forked from chineseocr/chineseocr
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f63af03
Showing
55 changed files
with
37,704 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
## 本项目基于[yolo3](https://github.com/pjreddie/darknet.git) 与[crnn](https://github.com/meijieru/crnn.pytorch.git) 实现中文自然场景文字检测及识别 | ||
|
||
## 环境部署 | ||
python=3.6 pytorch=0.2.0 | ||
``` Bash | ||
git clone https://github.com/pjreddie/darknet.git | ||
cd darknet | ||
make | ||
cd .. | ||
git clone https://github.com/chineseocr/chineseocr.git | ||
cd chineseocr | ||
sh setup.sh ## CPU 环境请改用: sh setup-cpu.sh | ||
``` | ||
|
||
## 下载模型文件 | ||
|
||
模型文件将于近期更新。 | ||
|
||
``` Bash | ||
mv models.zip chineseocr | ||
unzip models.zip | ||
``` | ||
## web服务启动 | ||
|
||
``` Bash | ||
cd chineseocr ## 进入chineseocr目录 | ||
ipython app.py 8080 ##8080端口号,可以设置任意端口 | ||
``` | ||
|
||
## 识别结果展示 | ||
<div> | ||
<img width="300" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/img1.png"/> | ||
<img width="300" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/4.png"/> | ||
</div> | ||
|
||
## 访问服务 | ||
http://127.0.0.1:8080/ocr | ||
|
||
<img width="300" height="300" src="https://github.com/chineseocr/chineseocr/blob/master/test/demo.png"/> | ||
|
||
|
||
## 参考 | ||
1. yolo3 https://github.com/pjreddie/darknet.git | ||
2. crnn https://github.com/meijieru/crnn.pytorch.git | ||
3. ctpn https://github.com/eragonruan/text-detection-ctpn | ||
4. CTPN https://github.com/tianzhi0549/CTPN | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author: lywen | ||
""" | ||
import os | ||
from PIL import Image | ||
import json | ||
import time | ||
import uuid | ||
import numpy as np | ||
import sys | ||
import base64 | ||
|
||
import web | ||
|
||
# Turn on web.py's debug flag for this application.
web.config.debug = True | ||
from apphelper.image import convert_image,read_url_img,string_to_array,array_to_string,base64_to_array | ||
import model | ||
# Template renderer reading from the 'templates' directory, with 'base' passed
# as the base template (presumably the shared page layout - confirm in templates/).
render = web.template.render('templates', base='base') | ||
|
||
|
||
class OCR:
    """General-purpose OCR endpoint.

    GET renders the OCR page; POST runs the model on a base64-encoded image
    and returns the recognised text boxes as JSON.
    """

    def GET(self):
        """Render the OCR page with its default layout parameters."""
        post = {
            'postName': u'ocr',  # request path the page posts back to
            'height': 1000,
            'H': 1000,
            'width': 600,
            'W': 600,
            'uuid': str(uuid.uuid1()),
        }
        return render.ocr(post)

    def POST(self):
        """Recognise text in the posted image.

        The request body is JSON with an ``imgString`` field holding the
        image as base64, optionally preceded by a data-URL header ending in
        ``;base64,``.

        Returns:
            str: A JSON array (non-ASCII characters kept as-is) with one
            object per entry of the model result, each carrying the keys
            ``w``, ``h``, ``cx``, ``cy``, ``degree`` and ``text``.
        """
        # Local import: the surrounding file does not import io.
        from io import BytesIO

        data = json.loads(web.data())
        # Keep only the payload after an optional "...;base64," header.
        encoded = data['imgString'].encode().split(b';base64,')[-1]
        raw = base64.b64decode(encoded)

        # Decode in memory instead of round-tripping through /tmp/<uuid>.jpg:
        # no file is left behind when decoding or the model raises, and no
        # file handle stays open.
        img = Image.open(BytesIO(raw)).convert("RGB")

        config = dict(
            MAX_HORIZONTAL_GAP=200,
            MIN_V_OVERLAPS=0.6,
            MIN_SIZE_SIM=0.6,
            TEXT_PROPOSALS_MIN_SCORE=0.2,
            TEXT_PROPOSALS_NMS_THRESH=0.3,
            TEXT_LINE_NMS_THRESH=0.99,
            MIN_RATIO=1.0,
            LINE_MIN_SCORE=0.2,
            TEXT_PROPOSALS_WIDTH=5,
            MIN_NUM_PROPOSALS=0,
        )
        _, result, angle = model.model(
            img,
            detectAngle=True,
            config=config,
            leftAdjust=True,
            rightAdjust=True,
            alph=0.1,
        )

        # Expose only the geometry and text fields of each result entry.
        fields = ('w', 'h', 'cx', 'cy', 'degree', 'text')
        res = [{key: box[key] for key in fields} for box in result]
        return json.dumps(res, ensure_ascii=False)
|
||
|
||
|
||
|
||
# URL routing table passed to web.application: the '/ocr' path is paired with
# the handler class name 'OCR', looked up in this module's globals().
urls = (u'/ocr',u'OCR', | ||
) | ||
|
||
# Start the web.py application only when this file is run as a script.
if __name__ == "__main__": | ||
|
||
app = web.application(urls, globals()) | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author: lywen | ||
todo:通用函数文件 | ||
""" | ||
import datetime as dt | ||
|
||
def get_date():
    """Return the current local date formatted as ``YYYYMMDD``.

    Returns:
        str: Today's date, or the placeholder ``'000000'`` if the current
        time cannot be obtained or formatted.
    """
    try:
        return dt.datetime.now().strftime('%Y%m%d')
    except (ValueError, OSError, OverflowError):
        # Keep the historical placeholder for callers, but no longer swallow
        # KeyboardInterrupt/SystemExit as the bare ``except:`` did.
        return '000000'
|
||
def strdate_to_date(string, format='%Y-%m-%d %H:%M:%S'):
    """Parse *string* into a ``datetime`` using *format*.

    Args:
        string: Text to parse.
        format: ``strptime`` format; defaults to ``'%Y-%m-%d %H:%M:%S'``.

    Returns:
        datetime.datetime: The parsed value, or the current local time when
        *string* does not match *format* or is not a string.
    """
    try:
        return dt.datetime.strptime(string, format)
    except (ValueError, TypeError):
        # ValueError: text does not match the format; TypeError: non-string
        # input. Other exceptions are no longer silently swallowed.
        return dt.datetime.now()
|
||
def diff_time(beginDate, endDate, format='%Y-%m-%d %H:%M:%S'):
    """Return the number of seconds from *beginDate* to *endDate*.

    Both arguments are parsed with ``strdate_to_date`` using *format*; the
    result is negative when *endDate* is earlier than *beginDate*.
    """
    start, end = [strdate_to_date(stamp, format) for stamp in (beginDate, endDate)]
    return (end - start).total_seconds()
|
||
def get_now():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``.

    Returns:
        str: The current timestamp, or the placeholder
        ``'00-00-00 00:00:00'`` if the current time cannot be obtained or
        formatted.
    """
    try:
        return dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    except (ValueError, OSError, OverflowError):
        # Keep the historical placeholder for callers, but no longer swallow
        # KeyboardInterrupt/SystemExit as the bare ``except:`` did.
        return '00-00-00 00:00:00'
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
##图像相关函数 | ||
@author: lywen | ||
""" | ||
import base64 | ||
#import urllib2 | ||
import requests | ||
import numpy as np | ||
import cv2 | ||
from PIL import Image | ||
import sys | ||
import six | ||
import traceback | ||
import uuid | ||
|
||
|
||
|
||
from PIL import Image | ||
def read_img(path):
    """Load the image at *path* and return it as an RGB numpy array."""
    # The context manager closes the underlying file once the pixels have
    # been converted; the original left the handle open.
    with Image.open(path) as im:
        return np.array(im.convert('RGB'))
|
||
def convert_image(path=None, string=None):
    """Base64-encode an image given as a file path or as raw bytes.

    *path* takes precedence when both arguments are supplied. Returns the
    encoded bytes, or ``None`` when neither argument is given.
    """
    if path is not None:
        with open(path, 'rb') as handle:
            raw = handle.read()
        return base64.b64encode(raw)
    if string is None:
        return None
    return base64.b64encode(string)
|
||
def string_to_array(string):
    """Decode encoded image bytes into a numpy array.

    Returns ``None`` when ``check_image_is_valid`` rejects the input.
    """
    if not check_image_is_valid(string):
        return None
    stream = six.BytesIO(string)
    return np.array(Image.open(stream))
|
||
def check_image_is_valid(imageBin):
    """Check whether *imageBin* holds a decodable, non-empty image.

    Args:
        imageBin: Encoded image bytes, or ``None``.

    Returns:
        bool: ``True`` when OpenCV can decode the buffer into an image with a
        non-zero area; ``False`` for ``None``, empty input or undecodable data.
    """
    # Reject missing or empty input before it reaches cv2.imdecode.
    if imageBin is None or len(imageBin) == 0:
        return False

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False

    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
|
||
def base64_to_array(string):
    """Decode a base64-encoded image into a numpy array.

    Args:
        string: Base64 text or bytes of an encoded image.

    Returns:
        numpy.ndarray: The decoded image pixels.
    """
    base64_data = base64.b64decode(string)
    # Feed the *decoded* bytes to PIL; the original wrote the still-encoded
    # input into the buffer, so Image.open could never recognise the image.
    buf = six.BytesIO(base64_data)
    return np.array(Image.open(buf))
|
||
|
||
|
||
def read_url_img(url):
    """Download an image from *url* and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The image converted to RGB, or ``None`` when the response status is
        not 200, the payload is not a valid image, or any error occurs (the
        traceback is printed in that case).
    """
    try:
        req = requests.get(url, timeout=5)  # give up after 5 seconds
        if req.status_code != 200:
            return None
        imgString = req.content
        if not check_image_is_valid(imgString):
            return None
        return Image.open(six.BytesIO(imgString)).convert('RGB')
    except Exception:
        # Best-effort fetch: report and return None, but let
        # KeyboardInterrupt/SystemExit propagate (the bare except did not).
        traceback.print_exc()
        return None
|
||
|
||
def array_to_string(array):
    """Encode an image array as PNG and return the encoded bytes."""
    sink = six.BytesIO()
    Image.fromarray(array).save(sink, format='png')
    png_bytes = sink.getvalue()
    sink.close()
    return png_bytes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import os | ||
# Locations of the darknet checkout and of the model files.
# NOTE: pwd is the process's current working directory at import time, so the
# model paths below depend on where the program is started from.
darknetRoot = os.path.join("..","darknet")## yolo install directory | ||
pwd = os.getcwd() | ||
# YOLO text-detector files (config, weights, data) under ./models.
yoloCfg = os.path.join(pwd,"models","text.cfg") | ||
yoloWeights = os.path.join(pwd,"models","text.weights") | ||
yoloData = os.path.join(pwd,"models","text.data") | ||
# OCR checkpoint under ./models.
ocrModel = os.path.join(pwd,"models","ocr.pth") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#coding:utf-8 | ||
import sys | ||
sys.path.insert(1, "./crnn") | ||
import torch | ||
import torch.utils.data | ||
from torch.autograd import Variable | ||
from crnn import util | ||
from crnn import dataset | ||
from crnn.models import crnn as crnn | ||
from crnn import keys | ||
#from conf import crnnModelPath | ||
#from conf import GPU | ||
# Hard-coded device switch: the functions below only call .cuda() when this
# flag is True and torch.cuda.is_available(); with False everything runs on CPU.
GPU=False | ||
from collections import OrderedDict | ||
from config import ocrModel | ||
def crnnSource():
    """Build the CRNN recogniser and load its weights from ``ocrModel``.

    Returns:
        tuple: ``(model, converter)`` - the network in eval mode and the
        ``strLabelConverter`` built from ``keys.alphabet``.
    """
    alphabet = keys.alphabet
    converter = util.strLabelConverter(alphabet)

    # Construct once, then place on the selected device. The output size is
    # len(alphabet) + 1 (presumably the extra class is the CTC blank - confirm
    # against the crnn package).
    model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1)
    if torch.cuda.is_available() and GPU:
        model = model.cuda()
    else:
        model = model.cpu()

    # map_location keeps every tensor in CPU storage, so a checkpoint saved
    # on a GPU machine still loads on a CPU-only host; load_state_dict then
    # copies the values onto whatever device the model lives on.
    state_dict = torch.load(ocrModel, map_location=lambda storage, loc: storage)

    # Drop the `module.` prefix from parameter names.
    new_state_dict = OrderedDict(
        (k.replace('module.', ''), v) for k, v in state_dict.items()
    )

    model.load_state_dict(new_state_dict)
    model.eval()

    return model, converter
|
||
## Load the model once at import time; crnnOcr reads these module-level objects | ||
model,converter = crnnSource() | ||
|
||
def crnnOcr(image):
    """Recognise the text in one image with the module-level CRNN model.

    The image is resized to height 32 with its width scaled by the same
    ratio, run through ``model``, and the per-step argmax is decoded by
    ``converter``.

    Args:
        image: Object exposing ``size`` as (width, height) - presumably a PIL
            image already prepared for the recogniser; confirm with callers.

    Returns:
        The decoded text from ``converter.decode(..., raw=False)``.
    """
    src_w, src_h = image.size[0], image.size[1]
    ratio = src_h * 1.0 / 32
    target_w = int(src_w / ratio)

    resize = dataset.resizeNormalize((target_w, 32))
    tensor = resize(image)
    if torch.cuda.is_available() and GPU:
        tensor = tensor.cuda()
    else:
        tensor = tensor.cpu()

    # Add a leading batch dimension of size 1.
    batch = Variable(tensor.view(1, *tensor.size()))

    model.eval()
    _, best = model(batch).max(2)
    flat = best.transpose(1, 0).contiguous().view(-1)
    length = Variable(torch.IntTensor([flat.size(0)]))
    return converter.decode(flat.data, length.data, raw=False)
|
||
|
Oops, something went wrong.