Skip to content

Commit

Permalink
add train,infer,val codes
Browse files Browse the repository at this point in the history
  • Loading branch information
AstarLight committed Nov 11, 2018
1 parent cce1303 commit a35ccde
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 1 deletion.
162 changes: 162 additions & 0 deletions detector/common/dataset_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@


import os
import codecs
import cv2
import draw_image
import lmdb
import sys
sys.path.append("..")
import ctpn.Net as Net
from torch.utils.data import Dataset


def read_gt_file(path, have_BOM=False):
result = []
if have_BOM:
fp = codecs.open(path, 'r', 'utf-8-sig')
else:
fp = open(path, 'r')
for line in fp.readlines():
pt = line.split(',')
if have_BOM:
box = [int(pt[i]) for i in range(8)]
else:
box = [int(pt[i]) for i in range(8)]
result.append(box)
fp.close()
return result


def create_dataset_icdar2015(img_root, gt_root, output_path):
im_list = os.listdir(img_root)
im_path_list = []
gt_list = []
for im in im_list:
name, _ = os.path.splitext(im)
gt_name = 'gt_' + name + '.txt'
gt_path = os.path.join(gt_root, gt_name)
if not os.path.exists(gt_path):
print('Ground truth file of image {0} not exists.'.format(im))
im_path_list.append(os.path.join(img_root, im))
gt_list.append(gt_path)
assert len(im_path_list) == len(gt_list)
create_dataset(output_path, im_path_list, gt_list)


def scale_img(img, gt, shortest_side=600):
height = img.shape[0]
width = img.shape[1]
scale = float(shortest_side)/float(min(height, width))
img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
if img.shape[0] < img.shape[1] and img.shape[0] != 600:
img = cv2.resize(img, (600, img.shape[1]))
elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
img = cv2.resize(img, (img.shape[0], 600))
elif img.shape[0] != 600:
img = cv2.resize(img, (600, 600))
h_scale = float(img.shape[0])/float(height)
w_scale = float(img.shape[1])/float(width)
scale_gt = []
for box in gt:
scale_box = []
for i in range(len(box)):
if i % 2 == 0:
scale_box.append(int(int(box[i]) * w_scale))
else:
scale_box.append(int(int(box[i]) * h_scale))
scale_gt.append(scale_box)
return img, scale_gt


def check_img(img):
if img is None:
return False
height, width = img.shape[0], img.shape[1]
if height * width == 0:
return False
return True


def write_cache(env, data):
with env.begin(write=True) as e:
for i, l in data.iteritems():
e.put(i, l)


def box_list2str(l):
result = []
for box in l:
if not len(box) % 8 == 0:
return '', False
result.append(','.join(box))
return '|'.join(result), True


def create_dataset(output_path, img_list, gt_list):
assert len(img_list) == len(gt_list)
net = Net.VGG_16()
num = len(img_list)
if not os.path.exists(output_path):
os.makedirs(output_path)
env = lmdb.open(output_path, map_size=1099511627776)
cache = {}
counter = 1
for i in range(num):
img_path = img_list[i]
gt = gt_list[i]
if not os.path.exists(img_path):
print("{0} is not exist.".format(img_path))
continue

if len(gt) == 0:
print("Ground truth of {0} is not exist.".format(img_path))
continue

img = cv2.imread(img_path)
if not check_img(img):
print('Image {0} is not valid.'.format(img_path))
continue

img, gt = scale_img(img, gt)
gt_str = box_list2str(gt)
if not gt_str[1]:
print("Ground truth of {0} is not valid.".format(img_path))
continue

img_key = 'image-%09d' % counter
gt_key = 'gt-%09d' % counter
cache[img_key] = draw_image.np_img2base64(img, img_path)
cache[gt_key] = gt_str[0]
counter += 1
if counter % 100 == 0:
write_cache(env, cache)
cache.clear()
print('Written {0}/{1}'.format(counter, num))
cache['num'] = str(counter - 1)
write_cache(env, cache)
print('Create dataset with {0} image.'.format(counter - 1))


class LmdbDataset(Dataset):
def __init__(self, root, transformer=None):
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
if not self.env:
print("Cannot create lmdb from root {0}.".format(root))
with self.env.begin(write=False) as e:
self.data_num = int(e.get('num'))
self.transformer = transformer

def __len__(self):
return self.data_num

def __getitem__(self, index):
assert index <= len(self), 'Index out of range.'
index += 1
with self.env.begin(write=False) as e:
img_key = 'image-%09d' % index
img_base64 = e.get(img_key)
img = draw_image.base642np_image(img_base64)
gt_key = 'gt-%09d' % index
gt = str(e.get(gt_key))
return img, gt
79 changes: 79 additions & 0 deletions detector/ctpn/Net/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
import cv2
import Dataset.port
import Net
import numpy as np
import os
import time
import random


def val(net, criterion, batch_num, using_cuda, logger):
img_root = '../dataset/OCR_dataset/ctpn/test_im'
gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
img_list = os.listdir(img_root)
total_loss = 0
total_cls_loss = 0
total_v_reg_loss = 0
total_o_reg_loss = 0
start_time = time.time()
for im in random.sample(img_list, batch_num):
name, _ = os.path.splitext(im)
gt_name = 'gt_' + name + '.txt'
gt_path = os.path.join(gt_root, gt_name)
if not os.path.exists(gt_path):
print('Ground truth file of image {0} not exists.'.format(im))
continue

gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
img = cv2.imread(os.path.join(img_root, im))
img, gt_txt = Dataset.scale_img(img, gt_txt)
tensor_img = img[np.newaxis, :, :, :]
tensor_img = tensor_img.transpose((0, 3, 1, 2))
if using_cuda:
tensor_img = torch.FloatTensor(tensor_img).cuda()
else:
tensor_img = torch.FloatTensor(tensor_img)

vertical_pred, score, side_refinement = net(tensor_img)
del tensor_img
positive = []
negative = []
vertical_reg = []
side_refinement_reg = []
for box in gt_txt:
gt_anchor = Dataset.generate_gt_anchor(img, box)
positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
positive += positive1
negative += negative1
vertical_reg += vertical_reg1
side_refinement_reg += side_refinement_reg1

if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
batch_num -= 1
continue

loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
negative, vertical_reg, side_refinement_reg)
total_loss += loss
total_cls_loss += cls_loss
total_v_reg_loss += v_reg_loss
total_o_reg_loss += o_reg_loss
end_time = time.time()
total_time = end_time - start_time
print('#################### Start evaluate ####################')
print('loss: {0}'.format(total_loss / float(batch_num)))
logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
print('##################### Evaluate end #####################')
print('\n')
40 changes: 39 additions & 1 deletion detector/ctpn/infer.py
Original file line number Diff line number Diff line change
@@ -1 +1,39 @@
import os
import cv2
import numpy as np
import other
import base64
import os
import copy
import Dataset
import Dataset.port as port
import torch
import Net
import torchvision.models
anchor_height = [11, 16, 22, 32, 46, 66, 94, 134, 191, 273]


if __name__ == '__main__':
net = Net.CTPN()
net.load_state_dict(torch.load('./model/ctpn-9-end.model'))
print(net)
net.eval()
im = cv2.imread('../dataset/OCR_dataset/ctpn/test_im/img_0059.jpg')
img = copy.deepcopy(im)
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :]
img = torch.Tensor(img)
v, score, side = net(img, val=True)
result = []
for i in range(score.shape[0]):
for j in range(score.shape[1]):
for k in range(score.shape[2]):
if score[i, j, k, 1] > 0.6:
result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))
# print(result)
for box in result:
im = other.draw_box_h_and_c(im, box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
gt = Dataset.port.read_gt_file('../dataset/OCR_dataset/ctpn/test_gt/gt_img_0059.txt')
for gt_box in gt:
im = other.draw_box_4pt(im, gt_box, (255, 0, 0))

cv2.imwrite("./test_result/test.jpg", im)
19 changes: 19 additions & 0 deletions detector/ctpn/lib/create_config_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import ConfigParser


if __name__ == '__main__':
cp = ConfigParser.ConfigParser()
cp.add_section('global')
cp.set('global', 'using_cuda', 'True')
cp.set('global', 'epoch', '30')
cp.set('global', 'gpu_id', '6')
cp.set('global', 'display_file_name', 'False')
cp.set('global', 'display_iter', '1')
cp.set('global', 'val_iter', '30')
cp.set('global', 'save_iter', '100')
cp.add_section('parameter')
cp.set('parameter', 'lr_front', '0.001')
cp.set('parameter', 'lr_behind', '0.0001')
cp.set('parameter', 'change_epoch', '9')
with open('../config', 'w') as fp:
cp.write(fp)
Loading

0 comments on commit a35ccde

Please sign in to comment.