forked from AstarLight/Lets_OCR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cce1303
commit a35ccde
Showing
5 changed files
with
498 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
|
||
|
||
import os | ||
import codecs | ||
import cv2 | ||
import draw_image | ||
import lmdb | ||
import sys | ||
sys.path.append("..") | ||
import ctpn.Net as Net | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def read_gt_file(path, have_BOM=False):
    """Read a comma-separated ground-truth file into a list of boxes.

    Every non-blank line must start with at least eight integer fields;
    anything after the eighth field on a line is ignored.

    Args:
        path: Path of the ground-truth text file.
        have_BOM: When true, the file is opened as ``utf-8-sig`` so that a
            leading byte-order mark is stripped before parsing.

    Returns:
        A list of boxes, each box being a list of eight ints.

    Raises:
        ValueError: If one of the first eight fields is not an integer.
        IndexError: If a non-blank line has fewer than eight fields.
    """
    if have_BOM:
        fp = codecs.open(path, 'r', 'utf-8-sig')
    else:
        fp = open(path, 'r')
    result = []
    # The context manager closes the file even when a line fails to parse.
    with fp:
        for line in fp:
            if not line.strip():
                # Tolerate blank lines (typically a trailing newline).
                continue
            pt = line.split(',')
            result.append([int(pt[i]) for i in range(8)])
    return result
|
||
|
||
def create_dataset_icdar2015(img_root, gt_root, output_path):
    """Build an LMDB dataset from an ICDAR-2015 style directory layout.

    Every file ``<name>.<ext>`` found in ``img_root`` is paired with
    ``gt_<name>.txt`` in ``gt_root``. Images without a ground-truth file are
    reported and skipped.

    Args:
        img_root: Directory containing the images.
        gt_root: Directory containing the ``gt_*.txt`` annotation files.
        output_path: Directory of the LMDB environment to create.
    """
    im_path_list = []
    gt_list = []
    for im in os.listdir(img_root):
        name, _ = os.path.splitext(im)
        gt_path = os.path.join(gt_root, 'gt_' + name + '.txt')
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            # An image without annotations cannot be stored; skip it instead
            # of queueing a path that does not exist.
            continue
        im_path_list.append(os.path.join(img_root, im))
        # create_dataset() scales and serialises the box coordinates, so it
        # needs the parsed boxes rather than the path of the annotation file.
        # 'utf-8-sig' decoding also accepts files that carry no BOM.
        gt_list.append(read_gt_file(gt_path, have_BOM=True))
    create_dataset(output_path, im_path_list, gt_list)
|
||
|
||
def scale_img(img, gt, shortest_side=600):
    """Resize an image so its shorter side equals ``shortest_side``.

    The ground-truth boxes are scaled by the same horizontal and vertical
    factors as the image.

    Args:
        img: Image array whose first two dimensions are height and width.
        gt: Iterable of boxes; each box is a flat sequence of alternating
            x and y coordinates (even index = x, odd index = y).
        shortest_side: Target length in pixels of the shorter image side.

    Returns:
        A tuple ``(resized_img, scaled_gt)`` where ``scaled_gt`` is a list of
        boxes with int coordinates.
    """
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side) / float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)

    # Rounding inside cv2.resize may leave the short side off by a pixel;
    # snap it to the exact target. Note that dsize is (width, height).
    new_h, new_w = img.shape[0], img.shape[1]
    if new_h < new_w and new_h != shortest_side:
        img = cv2.resize(img, (new_w, shortest_side))
    elif new_h > new_w and new_w != shortest_side:
        img = cv2.resize(img, (shortest_side, new_h))
    elif new_h == new_w and new_h != shortest_side:
        img = cv2.resize(img, (shortest_side, shortest_side))

    # Derive the factors from the final size so the boxes follow any snap.
    h_scale = float(img.shape[0]) / float(height)
    w_scale = float(img.shape[1]) / float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i, coord in enumerate(box):
            factor = w_scale if i % 2 == 0 else h_scale
            scale_box.append(int(int(coord) * factor))
        scale_gt.append(scale_box)
    return img, scale_gt
|
||
|
||
def check_img(img):
    """Tell whether ``img`` is a loaded image with a non-zero pixel area.

    Args:
        img: Object exposing ``shape`` (height first, width second), or
            ``None`` when loading failed.

    Returns:
        ``False`` for ``None`` or when height or width is zero, else ``True``.
    """
    if img is None:
        return False
    dims = img.shape
    return dims[0] * dims[1] != 0
|
||
|
||
def _to_bytes(value):
    """Return ``value`` unchanged if it is bytes, else its UTF-8 encoding."""
    return value if isinstance(value, bytes) else value.encode('utf-8')


def write_cache(env, data):
    """Store every key/value pair of ``data`` in ``env`` in one transaction.

    Args:
        env: Object whose ``begin(write=True)`` returns a context manager
            yielding a transaction with a ``put(key, value)`` method.
        data: Mapping of keys to values; text entries are encoded to bytes
            before being written.
    """
    with env.begin(write=True) as txn:
        # dict.items() exists on both Python 2 and 3 (iteritems() does not).
        for key, value in data.items():
            txn.put(_to_bytes(key), _to_bytes(value))
|
||
|
||
def box_list2str(l):
    """Serialise a list of boxes into a single string.

    Coordinates inside a box are joined with ``,`` and boxes are joined
    with ``|``.

    Args:
        l: Iterable of boxes; each box is a sequence of coordinates whose
            length must be a multiple of eight.

    Returns:
        ``(text, True)`` on success, or ``('', False)`` as soon as a box with
        an invalid length is found.
    """
    result = []
    for box in l:
        if len(box) % 8 != 0:
            return '', False
        # Coordinates arrive as ints after scaling; str.join needs strings.
        result.append(','.join(str(value) for value in box))
    return '|'.join(result), True
|
||
|
||
def create_dataset(output_path, img_list, gt_list):
    """Write images and their ground-truth boxes into an LMDB environment.

    Each accepted sample is rescaled with ``scale_img`` and stored under the
    keys ``image-%09d`` and ``gt-%09d`` (numbered from 1). The number of
    stored samples is written under the key ``num``. Samples whose image is
    missing or unreadable, or whose ground truth is empty or malformed, are
    reported and skipped.

    Args:
        output_path: Directory of the LMDB environment; created if absent.
        img_list: Paths of the images to store.
        gt_list: Ground truth per image, each a list of boxes; must have the
            same length as ``img_list``.
    """
    assert len(img_list) == len(gt_list)
    num = len(img_list)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    # map_size is the maximum database size: 2**40 bytes (1 TiB).
    env = lmdb.open(output_path, map_size=1099511627776)
    cache = {}
    written = 0
    try:
        for img_path, gt in zip(img_list, gt_list):
            if not os.path.exists(img_path):
                print("{0} is not exist.".format(img_path))
                continue

            if len(gt) == 0:
                print("Ground truth of {0} is not exist.".format(img_path))
                continue

            img = cv2.imread(img_path)
            if not check_img(img):
                print('Image {0} is not valid.'.format(img_path))
                continue

            img, gt = scale_img(img, gt)
            gt_text, gt_ok = box_list2str(gt)
            if not gt_ok:
                print("Ground truth of {0} is not valid.".format(img_path))
                continue

            # Keys are 1-based so that 'num' equals the highest index stored.
            index = written + 1
            cache['image-%09d' % index] = draw_image.np_img2base64(img, img_path)
            cache['gt-%09d' % index] = gt_text
            written += 1
            # Flush in batches of 100 samples to bound memory use.
            if written % 100 == 0:
                write_cache(env, cache)
                cache.clear()
                print('Written {0}/{1}'.format(written, num))
        cache['num'] = str(written)
        write_cache(env, cache)
    finally:
        # Release the environment even when a sample raises.
        env.close()
    print('Create dataset with {0} image.'.format(written))
|
||
|
||
class LmdbDataset(Dataset):
    """Read-only dataset backed by an LMDB environment.

    The environment is expected to hold the sample count under the key
    ``num`` and each sample under ``image-%09d`` / ``gt-%09d`` with indices
    starting at 1.
    """

    def __init__(self, root, transformer=None):
        """Open the LMDB environment at ``root`` and read the sample count.

        Args:
            root: Directory of the LMDB environment.
            transformer: Stored on the instance as ``self.transformer``; it
                is not applied by ``__getitem__``.
        """
        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
        # NOTE(review): lmdb.open presumably raises on failure, which would
        # make this branch unreachable -- confirm before removing it.
        if not self.env:
            print("Cannot create lmdb from root {0}.".format(root))
        with self.env.begin(write=False) as e:
            # lmdb keys are bytes; a bytes literal works on Python 2 and 3.
            self.data_num = int(e.get(b'num'))
        self.transformer = transformer

    def __len__(self):
        """Return the number of samples recorded under the ``num`` key."""
        return self.data_num

    def __getitem__(self, index):
        """Return the ``(image, ground_truth_text)`` pair at ``index``.

        Args:
            index: Zero-based sample index, ``0 <= index < len(self)``.

        Raises:
            IndexError: If ``index`` is outside the valid range.
        """
        if not 0 <= index < len(self):
            raise IndexError('Index out of range.')
        # Stored keys are 1-based.
        record = index + 1
        with self.env.begin(write=False) as e:
            img_base64 = e.get(('image-%09d' % record).encode('ascii'))
            img = draw_image.base642np_image(img_base64)
            raw_gt = e.get(('gt-%09d' % record).encode('ascii'))
        if isinstance(raw_gt, bytes) and not isinstance(raw_gt, str):
            # Python 3: decode instead of producing the "b'...'" repr.
            gt = raw_gt.decode('utf-8')
        else:
            gt = str(raw_gt)
        return img, gt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import torch | ||
import cv2 | ||
import Dataset.port | ||
import Net | ||
import numpy as np | ||
import os | ||
import time | ||
import random | ||
|
||
|
||
def val(net, criterion, batch_num, using_cuda, logger):
    """Evaluate ``net`` on randomly sampled test images and report mean losses.

    Up to ``batch_num`` images are drawn from the hard-coded test image
    directory. Images without a ground-truth file, or for which no anchors
    could be tagged, are skipped and do not count towards the averages.

    Args:
        net: Callable returning ``(vertical_pred, score, side_refinement)``
            for an image tensor.
        criterion: Callable returning ``(loss, cls_loss, v_reg_loss,
            o_reg_loss)``.
        batch_num: Maximum number of images to sample.
        using_cuda: When true, the input tensor is moved to the GPU.
        logger: Object with an ``info(message)`` method.
    """
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    evaluated = 0
    start_time = time.time()
    # random.sample raises ValueError when asked for more items than exist.
    sample_size = min(batch_num, len(img_list))
    for im in random.sample(img_list, sample_size):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        # NOTE(review): assumes the Dataset package exposes scale_img and
        # generate_gt_anchor at package level -- confirm.
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        # Add a batch axis, then reorder (N, H, W, C) -> (N, C, H, W).
        tensor_img = img[np.newaxis, :, :, :].transpose((0, 3, 1, 2))
        tensor_img = torch.FloatTensor(tensor_img)
        if using_cuda:
            tensor_img = tensor_img.cuda()

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += loss
        total_cls_loss += cls_loss
        total_v_reg_loss += v_reg_loss
        total_o_reg_loss += o_reg_loss
        evaluated += 1
    end_time = time.time()
    total_time = end_time - start_time

    print('#################### Start evaluate ####################')
    if evaluated == 0:
        # Nothing to average; avoid dividing by zero.
        print('No sample could be evaluated.')
        logger.info('Evaluate: no sample could be evaluated.')
        print('##################### Evaluate end #####################')
        print('\n')
        return

    # Average over the samples that actually contributed a loss.
    count = float(evaluated)
    print('loss: {0}'.format(total_loss / count))
    logger.info('Evaluate loss: {0}'.format(total_loss / count))

    print('classification loss: {0}'.format(total_cls_loss / count))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / count))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / count))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / count))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / count))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / count))

    print('{1} iterations for {0} seconds.'.format(total_time, evaluated))
    print('##################### Evaluate end #####################')
    print('\n')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,39 @@ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
import other | ||
import base64 | ||
import os | ||
import copy | ||
import Dataset | ||
import Dataset.port as port | ||
import torch | ||
import Net | ||
import torchvision.models | ||
anchor_height = [11, 16, 22, 32, 46, 66, 94, 134, 191, 273] | ||
|
||
|
||
if __name__ == '__main__':
    # Demo: run a trained CTPN model on one test image, draw every anchor
    # scoring above the threshold plus the ground-truth boxes, and save it.
    image_path = '../dataset/OCR_dataset/ctpn/test_im/img_0059.jpg'
    output_path = "./test_result/test.jpg"

    net = Net.CTPN()
    net.load_state_dict(torch.load('./model/ctpn-9-end.model'))
    print(net)
    net.eval()

    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Cannot read image {0}.'.format(image_path))

    # Work on a copy so the original stays available for drawing.
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)

    # Keep every (row, column, anchor index, score) whose score exceeds 0.6.
    result = []
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > 0.6:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    for box in result:
        # NOTE(review): 16 and 7.5 presumably map a feature-map row to the
        # pixel centre of a 16-pixel stride -- confirm against the network.
        im = other.draw_box_h_and_c(im, box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
    gt = Dataset.port.read_gt_file('../dataset/OCR_dataset/ctpn/test_gt/gt_img_0059.txt')
    for gt_box in gt:
        im = other.draw_box_4pt(im, gt_box, (255, 0, 0))

    # cv2.imwrite does not create missing directories and signals failure
    # only through its return value.
    out_dir = os.path.dirname(output_path)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    if not cv2.imwrite(output_path, im):
        raise IOError('Cannot write image {0}.'.format(output_path))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import ConfigParser | ||
|
||
|
||
if __name__ == '__main__':
    # Generate the default training configuration file at '../config'.
    # Sections and options are listed in the order they are written.
    defaults = (
        ('global', (
            ('using_cuda', 'True'),
            ('epoch', '30'),
            ('gpu_id', '6'),
            ('display_file_name', 'False'),
            ('display_iter', '1'),
            ('val_iter', '30'),
            ('save_iter', '100'),
        )),
        ('parameter', (
            ('lr_front', '0.001'),
            ('lr_behind', '0.0001'),
            ('change_epoch', '9'),
        )),
    )
    cp = ConfigParser.ConfigParser()
    for section, options in defaults:
        cp.add_section(section)
        for option, value in options:
            cp.set(section, option, value)
    with open('../config', 'w') as fp:
        cp.write(fp)
Oops, something went wrong.