Skip to content

Commit

Permalink
修正部分bug,新增ocr训练
Browse files Browse the repository at this point in the history
  • Loading branch information
wenlihaoyu committed Nov 13, 2018
1 parent ae7339c commit 736c16c
Show file tree
Hide file tree
Showing 18 changed files with 2,088 additions and 17 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
*.so
*.egg
*.egg-info
*.pth
*.pb
*.pbtxt
*.weights
dist
buil
.DS_Store*
.ipynb_checkpoints
__pycache__

darknet
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
DETECTANGLE=True##是否进行文字方向检测
LSTMFLAG = True##OCR模型是否调用LSTM层
GPU = True##OCR 是否启用GPU
GPUID=0##调用GPU序号
chinsesModel = True##模型选择 True:中英文模型 False:英文模型

if chinsesModel:
if LSTMFLAG:
ocrModel = os.path.join(pwd,"models","ocr-lstm.pth")
Expand Down
3 changes: 1 addition & 2 deletions crnn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class strLabelConverter(object):

def __init__(self, alphabet):
self.alphabet = alphabet + u'-' # for `-1` index
self.alphabet = alphabet + ' # for `-1` index
self.dict = {}
for i, char in enumerate(alphabet):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
Expand All @@ -19,7 +19,6 @@ def encode(self, text, depth=0):
length = []
result=[]
for str in text:
str = unicode(str,"utf8")
length.append(len(str))
for char in str:
#print(char)
Expand Down
12 changes: 8 additions & 4 deletions darknet_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
pwd = os.getcwd()
import numpy as np
from PIL import Image
from config import yoloCfg,yoloWeights,yoloData,yoloData,darknetRoot

from config import yoloCfg,yoloWeights,yoloData,yoloData,darknetRoot,GPU,GPUID
os.chdir(darknetRoot)
sys.path.append('python')
import darknet as dn
Expand Down Expand Up @@ -52,13 +53,16 @@ def to_box(r):


import pdb
#dn.set_gpu(0)
if GPU:
try:
dn.set_gpu(GPUID)
except:
pass
net = dn.load_net(yoloCfg.encode('utf-8'), yoloWeights.encode('utf-8'), 0)
meta = dn.load_meta(yoloData.encode('utf-8'))
os.chdir(pwd)
def text_detect(img):
inputBlob = cv2.dnn.blobFromImage(img, scalefactor=0.00390625, size=(608, 608),swapRB=True ,crop=False);

r = detect_np(net, meta, img,thresh=0.1, hier_thresh=0.5, nms=0.8)
r = detect_np(net, meta, img,thresh=0, hier_thresh=0.5, nms=None)##输出所有box,与opencv dnn统一
bboxes = to_box(r)
return bboxes
15 changes: 9 additions & 6 deletions detector/detectors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
#coding:utf-8
from detector.other import normalize
import numpy as np
import numpy as np
from config import GPUID,GPU
from detector.utils.cython_nms import nms as cython_nms
try:
from detector.utils.gpu_nms import gpu_nms
except:
gpu_nms =cython_nms
##优先加载编译对GPU编译的gpu_nms 如果不想调用GPU,在程序启动执行os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if GPU:
try:
from detector.utils.gpu_nms import gpu_nms
except:
gpu_nms =cython_nms

def nms(dets, thresh):
if dets.shape[0] == 0:
return []

try:
return gpu_nms(dets, thresh, device_id=0)
if GPU and GPUID is not None:
return gpu_nms(dets, thresh, device_id=GPUID)
except:
return cython_nms(dets, thresh)

Expand Down
Loading

0 comments on commit 736c16c

Please sign in to comment.