build release crnn
AstarLight committed Feb 1, 2019
1 parent 9806bbe commit ea965f4
Showing 6 changed files with 43 additions and 176 deletions.
11 changes: 6 additions & 5 deletions recognizer/crnn/Config.py
@@ -8,17 +8,18 @@
 using_cuda = True
 keep_ratio = False
 gpu_id = '5'
-model_dir = './bs256_model'
-data_worker = 4
-batch_size = 256
+model_dir = './w160_bs64_model'
+data_worker = 5
+batch_size = 64
 img_height = 32
-img_width = 100
+img_width = 160
 alphabet = alphabets.alphabet
 epoch = 20
 display_interval = 20
-save_interval = 10000
+save_interval = 4000
 test_interval = 2000
 test_disp = 20
 test_batch_num = 32
 lr = 0.0001
 beta1 = 0.5
+infer_img_w = 160
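
Note: img_width/img_height set the fixed input size used when samples are resized for training and validation, while the new infer_img_w only feeds the width heuristic in infer.py (see that file's diff below). A minimal sketch of what the resize produces, assuming resizeNormalize behaves like the usual CRNN transform (resize to (W, H), convert to a tensor, normalize to roughly [-1, 1]); that equivalence is an assumption, since the body of lib/dataset.py's resizeNormalize is not shown in this commit:

import torch
import torchvision.transforms as transforms
from PIL import Image

img_height, img_width = 32, 160   # Config.py values after this commit

# Assumed stand-in for lib.dataset.resizeNormalize((img_width, img_height)):
to_input = transforms.Compose([
    transforms.Resize((img_height, img_width)),   # torchvision takes (H, W)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

crop = Image.new('L', (300, 48))   # stand-in for a grayscale text-line crop
print(to_input(crop).shape)        # torch.Size([1, 32, 160])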
58 changes: 0 additions & 58 deletions recognizer/crnn/evaluate.py

This file was deleted.

43 changes: 20 additions & 23 deletions recognizer/crnn/infer.py
@@ -3,42 +3,35 @@
 import torch
 import os
 from torch.autograd import Variable
-import lib.utils
+import lib.convert
 import lib.dataset
 from PIL import Image
 import Net.net as Net
 import alphabets
 import numpy as np
 import cv2
+import sys
+import Config

-crnn_model_path = './model/mixed_second_finetune_acc97p7.pth'
+os.environ['CUDA_VISIBLE_DEVICES'] = "4"
+
+crnn_model_path = './w160_bs64_model/netCRNN_4_32000.pth'
 IMG_ROOT = './test_images'
 running_mode = 'gpu'
 alphabet = alphabets.alphabet
 nclass = len(alphabet) + 1

-# Testing images are scaled to have height 32. Widths are
-# proportionally scaled with heights, but at least 100 pixels
-def scale_img_para(img, min_width=100, fixed_height=32):
-    height = img.size[1]
-    width = img.size[0]
-    scale = float(fixed_height)/height
-    w = int(width * scale)
-    if w < min_width:
-        w = min_width
-
-    return w, fixed_height
-
-
 def crnn_recognition(cropped_image, model):
-    converter = lib.utils.strLabelConverter(alphabet)
+    converter = lib.convert.strLabelConverter(alphabet)

     image = cropped_image.convert('L')

-    ##
-    #w = int(image.size[0] / (280 * 1.0 / 160))
-    w, h = scale_img_para(image)
-    transformer = lib.dataset.resizeNormalize((w, h))
+    ### Testing images are scaled to have height 32. Widths are
+    # proportionally scaled with heights, but at least 100 pixels
+    w = int(image.size[0] / (280 * 1.0 / Config.infer_img_w))
+    #scale = image.size[1] * 1.0 / Config.img_height
+    #w = int(image.size[0] / scale)
+
+    transformer = lib.dataset.resizeNormalize((w, Config.img_height))
     image = transformer(image)
     if torch.cuda.is_available():
         image = image.cuda()
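
The hunk above swaps the width rule: the deleted scale_img_para scaled each crop proportionally to a 32-pixel height with a 100-pixel floor, while the new code applies a fixed horizontal scale of infer_img_w / 280 and leaves the height resize to resizeNormalize. A quick worked comparison in plain Python (no repo imports; the 280 x 32 crop is just an example size):

crop_w, crop_h = 280, 32                   # nominal crop size the new rule is written around
img_height, infer_img_w = 32, 160          # Config values after this commit

# Old rule (deleted scale_img_para): keep the aspect ratio against a fixed
# 32-pixel height, but never go below 100 pixels wide.
old_w = max(int(crop_w * (32.0 / crop_h)), 100)    # -> 280

# New rule: a fixed horizontal scale of infer_img_w / 280, independent of crop height.
new_w = int(crop_w / (280 * 1.0 / infer_img_w))    # -> 160

print(old_w, new_w)   # 280 160; the crop is then resized to (new_w, img_height)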
@@ -68,13 +61,17 @@ def crnn_recognition(cropped_image, model):

     print('loading pretrained model from {0}'.format(crnn_model_path))

-    files = os.listdir(IMG_ROOT)
+    files = sorted(os.listdir(IMG_ROOT))
     for file in files:
         started = time.time()
         full_path = os.path.join(IMG_ROOT, file)
         print("=============================================")
         print("ocr image is %s" % full_path)
         image = Image.open(full_path)
+
         crnn_recognition(image, model)
         finished = time.time()
-        print('elapsed time: {0}'.format(finished - started))
+        print('elapsed time: {0}'.format(finished - started))
+
+    sys.exit()

89 changes: 8 additions & 81 deletions recognizer/crnn/lib/dataset.py
@@ -1,18 +1,16 @@
-
 #!/usr/bin/python
 # encoding: utf-8

 import random
 import torch
 from torch.utils.data import Dataset
 from torch.utils.data import sampler
 import torchvision.transforms as transforms
+import lmdb
 import six
 import sys
 from PIL import Image
 import numpy as np
-import lmdb
-import cv2
-import os

+
 class lmdbDataset(Dataset):
@@ -31,7 +29,9 @@ def __init__(self, root=None, transform=None, target_transform=None):
             sys.exit(0)

         with self.env.begin(write=False) as txn:
-            nSamples = int(txn.get('num-samples'.encode()))
+
+            str = 'num-samples'
+            nSamples = int(txn.get(str.encode()))
             self.nSamples = nSamples

         self.transform = transform
@@ -60,7 +60,7 @@ def __getitem__(self, index):
                 img = self.transform(img)

             label_key = 'label-%09d' % index
-            label = str(txn.get(label_key.encode()))
+            label = txn.get(label_key.encode())

             if self.target_transform is not None:
                 label = self.target_transform(label)
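
The label change above is the Python 3 fix in this file: txn.get returns bytes, and wrapping that in str() yields the bytes repr rather than the decoded text. Where the decode now happens (for example in the label converter) is not shown in this diff; a minimal illustration of the difference:

raw = '你好'.encode('utf-8')    # what txn.get(...) hands back for a label
print(str(raw))                 # b'\xe4\xbd\xa0\xe5\xa5\xbd'  -- the old str(...) behaviour, mangled
print(raw.decode('utf-8'))      # 你好                          -- what the label pipeline actually wants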
@@ -110,16 +110,14 @@ def __len__(self):

 class alignCollate(object):

-    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
+    def __init__(self, imgH=32, imgW=256, keep_ratio=False, min_ratio=1):
         self.imgH = imgH
         self.imgW = imgW
         self.keep_ratio = keep_ratio
         self.min_ratio = min_ratio

     def __call__(self, batch):
         images, labels = zip(*batch)
-        # print(Image._show(images[0]))
-
         imgH = self.imgH
         imgW = self.imgW
         if self.keep_ratio:
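
The body of the keep_ratio branch is collapsed above. In CRNN codebases derived from crnn.pytorch it usually widens the whole batch to fit its widest image at the target height; the sketch below is modelled on that upstream logic and is an assumption about this repository, not a quote from it:

import math

def batch_width(images, imgH=32, min_ratio=1):
    # Choose a batch width wide enough for the widest image once every image
    # is scaled to imgH pixels tall, but never narrower than min_ratio * imgH.
    ratios = [w / float(h) for (w, h) in (im.size for im in images)]
    imgW = int(math.floor(max(ratios) * imgH))
    return max(imgH * min_ratio, imgW)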
@@ -139,76 +137,5 @@ def __call__(self, batch):
         return images, labels


-def checkImageIsValid(imageBin):  # check image and transform image to gray scale
-    if imageBin is None:
-        return False
-    try:
-        imageBuf = np.fromstring(imageBin, dtype=np.uint8)
-        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
-    except cv2.error:
-        print('Error!')
-        return False
-    if img is None:
-        return False
-    imgH, imgW = img.shape[0], img.shape[1]
-    if imgH * imgW == 0:
-        return False
-    return True
-
-
-def writeCache(env, cache):
-    with env.begin(write=True) as txn:
-        for k, v in cache.iteritems():
-            txn.put(k, v)
-
-
-def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
-    """
-    Create LMDB dataset for CRNN training.
-    ARGS:
-        outputPath    : LMDB output path
-        imagePathList : list of image path
-        labelList     : list of corresponding groundtruth texts
-        lexiconList   : (optional) list of lexicon lists
-        checkValid    : if true, check the validity of every image
-    """
-    assert(len(imagePathList) == len(labelList))
-    nSamples = len(imagePathList)
-    env = lmdb.open(outputPath, map_size=1099511627776)
-    cache = {}
-    cnt = 1
-    for i in xrange(nSamples):
-        imagePath = imagePathList[i]
-        label = labelList[i]
-        if not os.path.exists(imagePath):
-            print('%s does not exist' % imagePath)
-            continue
-        with open(imagePath, 'r') as f:
-            imageBin = f.read()
-        if checkValid:
-            if not checkImageIsValid(imageBin):
-                print('%s is not a valid image' % imagePath)
-                continue
-
-        imageKey = 'image-%09d' % cnt
-        labelKey = 'label-%09d' % cnt
-        cache[imageKey] = imageBin
-        cache[labelKey] = label
-        if lexiconList:
-            lexiconKey = 'lexicon-%09d' % cnt
-            cache[lexiconKey] = ' '.join(lexiconList[i])
-        if cnt % 1000 == 0:
-            writeCache(env, cache)
-            cache = {}
-            print('Written %d / %d' % (cnt, nSamples))
-        cnt += 1
-    nSamples = cnt-1
-    cache['num-samples'] = str(nSamples)
-    writeCache(env, cache)
-    print('Created dataset with %d samples' % nSamples)
-
-
 def loadData(v, data):
-    v.data.resize_(data.size()).copy_(data)
-
+    v.data.resize_(data.size()).copy_(data)
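
The deleted helpers above were the Python 2-era LMDB builder (xrange, dict.iteritems, text-mode file reads); the key layout they wrote, 1-based image-%09d / label-%09d entries plus a num-samples counter, is still what lmdbDataset reads. For reference, a Python 3 sketch of writing that same layout; this is illustrative only, not the project's replacement tooling:

import lmdb

def write_toy_lmdb(output_path, samples):
    # samples: list of (image_bytes, label_str) pairs
    env = lmdb.open(output_path, map_size=1 << 30)
    with env.begin(write=True) as txn:
        for i, (image_bin, label) in enumerate(samples, start=1):
            txn.put(('image-%09d' % i).encode(), image_bin)
            txn.put(('label-%09d' % i).encode(), label.encode('utf-8'))
        txn.put('num-samples'.encode(), str(len(samples)).encode())
    env.close()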
18 changes: 9 additions & 9 deletions recognizer/crnn/train.py
@@ -13,20 +13,17 @@
 from torch.autograd import Variable
 import Net.net as Net
 import torch.optim as optim
-import dataset
-import evaluate


-def val(net, dataset, criterion, max_iter=10):
+def val(net, da, criterion, max_iter=100):
     print('Start val')

     for p in net.parameters():
         p.requires_grad = False

     net.eval()
     data_loader = torch.utils.data.DataLoader(
-        dataset, shuffle=True, batch_size=Config.batch_size, num_workers=int(Config.data_worker),
-        collate_fn=lib.dataset.alignCollate(imgH=Config.img_height, imgW=Config.img_width, keep_ratio=True))
+        da, shuffle=True, batch_size=Config.batch_size, num_workers=int(Config.data_worker))
     val_iter = iter(data_loader)

     i = 0
@@ -92,6 +89,9 @@ def trainBatch(net, criterion, optimizer, train_iter):
 if not os.path.exists(Config.model_dir):
     os.mkdir(Config.model_dir)

+print("image scale: [%s,%s]\nmodel_save_path: %s\ngpu_id: %s\nbatch_size: %s" %
+      (Config.img_height, Config.img_width, Config.model_dir, Config.gpu_id, Config.batch_size))
+
 random.seed(Config.random_seed)
 np.random.seed(Config.random_seed)
 torch.manual_seed(Config.random_seed)
@@ -106,11 +106,11 @@ def trainBatch(net, criterion, optimizer, train_iter):
     cuda = False
     print('Using cpu mode')

-train_dataset = dataset.lmdbDataset(root=Config.train_data)
-test_dataset = dataset.lmdbDataset(root=Config.test_data)
+train_dataset = lib.dataset.lmdbDataset(root=Config.train_data)
+test_dataset = lib.dataset.lmdbDataset(root=Config.test_data, transform=lib.dataset.resizeNormalize((Config.img_width, Config.img_height)))
 assert train_dataset

-# images will be resize to 32*160
+# images will be resize to 32*100
 train_loader = torch.utils.data.DataLoader(
     train_dataset, batch_size=Config.batch_size,
     shuffle=True,
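
Taken together with the first hunk, this is why val() no longer needs an alignCollate collate_fn: test_dataset now applies resizeNormalize per sample, so every item already carries a fixed-size image tensor and the DataLoader's default collate can batch it. A small stand-alone check (the (image, label) item structure mirrors lmdbDataset; the (1, 32, 160) shape assumes resizeNormalize yields a tensor of that size, as sketched under Config.py above):

import torch
from torch.utils.data.dataloader import default_collate

items = [(torch.zeros(1, 32, 160), b'label') for _ in range(4)]   # stand-ins for dataset items
images, labels = default_collate(items)
print(images.shape)   # torch.Size([4, 1, 32, 160])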
@@ -170,7 +170,7 @@ def trainBatch(net, criterion, optimizer, train_iter):
             loss_avg.reset()

         if i % Config.test_interval == 0:
-            val(net, test_dataset, criterion, max_iter=Config.test_batch_num)
+            val(net, test_dataset, criterion)

         # do checkpointing
         if i % Config.save_interval == 0:
Binary file not shown.
