# Reconstructed from git format-patch a35ccdeb9aed63d9cad37631e13081735993c352
# ("add train,infer,val codes", AstarLight, Sun 11 Nov 2018).  The original
# patch text was whitespace-mangled (newlines collapsed); the five added files
# are reproduced below in readable form, separated by banner comments.
# Concrete defects fixed in this reconstruction:
#   * write_cache: py2-only dict.iteritems(); lmdb needs bytes keys/values
#   * box_list2str: ','.join() over ints raises TypeError
#   * scale_img: cv2.resize dsize must be (width, height); 600 hard-coded
#   * create_dataset_icdar2015: missing `continue`; passed gt paths where
#     create_dataset expects parsed box lists
#   * LmdbDataset.__getitem__: off-by-one bound check; str keys to lmdb
#   * evaluate.val: mismatched logger metric names; tensor loss accumulation
#   * train.py: `val_func.val` NameError (module imported as `evaluate`)
#   * ConfigParser / unused imports: modernized / removed

import os
import sys
import time
import copy
import random
import codecs
import logging
import datetime

import numpy as np

try:
    import configparser                     # Python 3
except ImportError:                         # pragma: no cover - py2 fallback
    import ConfigParser as configparser

# the original dataset_handler.py did sys.path.append("..") to reach the
# project packages; kept so the guarded imports below can resolve
sys.path.append("..")

# Third-party / project modules are optional at import time so this module
# stays importable everywhere; functions that need them fail at call time.
try:
    import cv2
except ImportError:
    cv2 = None
try:
    import lmdb
except ImportError:
    lmdb = None
try:
    import torch
    import torch.optim as optim
    from torch.utils.data import Dataset as TorchDataset
except ImportError:
    torch = None
    optim = None
    TorchDataset = object
try:
    import draw_image                       # project: np <-> base64 image codecs
except ImportError:
    draw_image = None
try:
    import Dataset                          # project data package (port, scale_img, ...)
    import Dataset.port
except ImportError:
    Dataset = None
try:
    import Net                              # project: CTPN model, loss, tag_anchor
except ImportError:
    Net = None
try:
    import other                            # project: drawing / weight-init helpers
except ImportError:
    other = None
try:
    import evaluate                         # project: the val() defined below lives here
except ImportError:
    evaluate = None


# ======================================================================
# detector/common/dataset_handler.py
# ======================================================================

def read_gt_file(path, have_BOM=False):
    """Parse one ICDAR-style ground-truth file into a list of 8-int boxes.

    Each line looks like ``x1,y1,x2,y2,x3,y3,x4,y4[,text]``; only the first
    eight comma-separated fields are used.

    :param path: path of the ground-truth text file.
    :param have_BOM: open with utf-8-sig to strip a leading byte-order mark.
    :return: list of ``[x1, y1, ..., y4]`` integer lists, one per line.
    """
    # The original duplicated identical parsing code in both BOM branches;
    # only the way the file is opened depends on the flag.
    if have_BOM:
        fp = codecs.open(path, 'r', 'utf-8-sig')
    else:
        fp = open(path, 'r')
    result = []
    try:
        for line in fp:
            pt = line.split(',')
            result.append([int(pt[i]) for i in range(8)])
    finally:
        fp.close()
    return result


def create_dataset_icdar2015(img_root, gt_root, output_path):
    """Build an lmdb dataset from an ICDAR2015-style image/gt directory pair.

    Fixes vs. the original: images whose ground-truth file is missing are now
    skipped (the original warned but still appended the bogus path), and the
    gt files are parsed here because create_dataset() expects box lists, not
    file paths.
    """
    im_path_list = []
    gt_list = []
    for im in os.listdir(img_root):
        name, _ = os.path.splitext(im)
        gt_path = os.path.join(gt_root, 'gt_' + name + '.txt')
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue
        im_path_list.append(os.path.join(img_root, im))
        # NOTE(review): ICDAR2015 gt files normally carry a UTF-8 BOM
        # (evaluate.val reads them with have_BOM=True) — confirm for this set.
        gt_list.append(read_gt_file(gt_path, have_BOM=True))
    assert len(im_path_list) == len(gt_list)
    create_dataset(output_path, im_path_list, gt_list)


def scale_img(img, gt, shortest_side=600):
    """Resize `img` so its shortest side equals `shortest_side`, scaling `gt`.

    cv2.resize takes dsize as (width, height); the original swapped the two
    in every fix-up branch, distorting the image.  The branches also
    hard-coded 600 instead of using the parameter.

    :param img: HxWxC image array.
    :param gt: list of 8-int boxes (x at even indices, y at odd indices).
    :return: (resized image, boxes scaled into the new coordinate frame).
    """
    height, width = img.shape[0], img.shape[1]
    scale = float(shortest_side) / float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # force the shortest side to exactly `shortest_side`; the fx/fy resize
    # can leave it off by one pixel due to rounding
    if img.shape[0] < img.shape[1] and img.shape[0] != shortest_side:
        img = cv2.resize(img, (img.shape[1], shortest_side))
    elif img.shape[0] > img.shape[1] and img.shape[1] != shortest_side:
        img = cv2.resize(img, (shortest_side, img.shape[0]))
    elif img.shape[0] != shortest_side:
        img = cv2.resize(img, (shortest_side, shortest_side))
    h_scale = float(img.shape[0]) / float(height)
    w_scale = float(img.shape[1]) / float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i, coord in enumerate(box):
            # even indices are x coordinates, odd indices are y coordinates
            factor = w_scale if i % 2 == 0 else h_scale
            scale_box.append(int(int(coord) * factor))
        scale_gt.append(scale_box)
    return img, scale_gt


def check_img(img):
    """Return True when `img` is a non-empty image array."""
    if img is None:
        return False
    height, width = img.shape[0], img.shape[1]
    if height * width == 0:
        return False
    return True


def write_cache(env, data):
    """Write every key/value pair of `data` into the lmdb env in one txn.

    lmdb stores bytes only; str keys/values are utf-8 encoded here.  The
    original used Python-2-only dict.iteritems() and passed raw str objects.
    """
    with env.begin(write=True) as txn:
        for key, value in data.items():
            if isinstance(key, str):
                key = key.encode('utf-8')
            if isinstance(value, str):
                value = value.encode('utf-8')
            txn.put(key, value)


def box_list2str(l):
    """Serialize boxes to ``'x1,y1,...|x1,y1,...'`` form.

    :return: (serialized string, ok flag).  ok is False when any box's
        coordinate count is not a multiple of 8.  Coordinates are converted
        with str(); the original joined raw ints, which raises TypeError.
    """
    result = []
    for box in l:
        if len(box) % 8 != 0:
            return '', False
        result.append(','.join(map(str, box)))
    return '|'.join(result), True


def create_dataset(output_path, img_list, gt_list):
    """Write (image, gt) records into an lmdb database at `output_path`.

    :param img_list: image file paths.
    :param gt_list: parsed box lists, parallel to `img_list`.

    The unused ``Net.VGG_16()`` instantiation of the original was dropped.
    Records are stored under 1-based keys ``image-%09d`` / ``gt-%09d`` with
    the total count under key ``num``.
    """
    assert len(img_list) == len(gt_list)
    num = len(img_list)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    env = lmdb.open(output_path, map_size=1099511627776)
    cache = {}
    counter = 1
    for img_path, gt in zip(img_list, gt_list):
        if not os.path.exists(img_path):
            print("{0} is not exist.".format(img_path))
            continue
        if len(gt) == 0:
            print("Ground truth of {0} is not exist.".format(img_path))
            continue
        img = cv2.imread(img_path)
        if not check_img(img):
            print('Image {0} is not valid.'.format(img_path))
            continue
        img, gt = scale_img(img, gt)
        gt_str = box_list2str(gt)
        if not gt_str[1]:
            print("Ground truth of {0} is not valid.".format(img_path))
            continue
        cache['image-%09d' % counter] = draw_image.np_img2base64(img, img_path)
        cache['gt-%09d' % counter] = gt_str[0]
        counter += 1
        # flush periodically so the cache dict stays small
        if counter % 100 == 0:
            write_cache(env, cache)
            cache.clear()
            print('Written {0}/{1}'.format(counter, num))
    cache['num'] = str(counter - 1)
    write_cache(env, cache)
    print('Create dataset with {0} image.'.format(counter - 1))


class LmdbDataset(TorchDataset):
    """torch Dataset over an lmdb database written by create_dataset()."""

    def __init__(self, root, transformer=None):
        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        if not self.env:
            print("Cannot create lmdb from root {0}.".format(root))
        with self.env.begin(write=False) as txn:
            # keys are stored utf-8 encoded (see write_cache)
            self.data_num = int(txn.get('num'.encode('utf-8')))
        self.transformer = transformer

    def __len__(self):
        return self.data_num

    def __getitem__(self, index):
        # original allowed index == len(self), which reads a missing record
        assert index < len(self), 'Index out of range.'
        index += 1  # records are stored 1-based
        with self.env.begin(write=False) as txn:
            img_base64 = txn.get(('image-%09d' % index).encode('utf-8'))
            img = draw_image.base642np_image(img_base64)
            gt = txn.get(('gt-%09d' % index).encode('utf-8'))
            if isinstance(gt, bytes):
                gt = gt.decode('utf-8')
        return img, gt


# ======================================================================
# detector/ctpn/Net/evaluate.py
# ======================================================================

def val(net, criterion, batch_num, using_cuda, logger):
    """Evaluate `net` on `batch_num` random test images and log mean losses.

    :param net: CTPN network in eval mode.
    :param criterion: CTPN_Loss-style callable returning
        (loss, cls_loss, v_reg_loss, o_reg_loss).
    :param batch_num: number of random test images to evaluate.
    :param using_cuda: move input tensors to the GPU when True.
    :param logger: logging.Logger receiving the averaged metrics.
    """
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_path = os.path.join(gt_root, 'gt_' + name + '.txt')
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :].transpose((0, 3, 1, 2))
        tensor_img = torch.FloatTensor(tensor_img)
        if using_cuda:
            tensor_img = tensor_img.cuda()

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img

        # convert bbox ground truth into anchor-level targets
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            pos1, neg1, v_reg1, o_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += pos1
            negative += neg1
            vertical_reg += v_reg1
            side_refinement_reg += o_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            # image contributed no usable anchors; shrink the divisor
            batch_num -= 1
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(
            score, vertical_pred, side_refinement, positive,
            negative, vertical_reg, side_refinement_reg)
        # accumulate plain floats, not graph-holding tensors
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)
    total_time = time.time() - start_time

    if batch_num <= 0:
        # every sampled image was skipped; avoid division by zero
        print('No valid evaluation batch.')
        return

    print('#################### Start evaluate ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

    # the original logged the wrong metric name on the next three pairs
    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('##################### Evaluate end #####################')
    print('\n')


# ======================================================================
# detector/ctpn/infer.py
# (originally a script; its __main__ body is wrapped in infer_main so the
#  three scripts in this reconstruction do not all run on import)
# ======================================================================

# fixed anchor heights used by the CTPN detector, indexed by anchor id
anchor_height = [11, 16, 22, 32, 46, 66, 94, 134, 191, 273]


def infer_main():
    """Run a trained CTPN model on one test image and draw the detections.

    Unused ``base64`` / ``torchvision.models`` imports of the original were
    dropped.
    """
    net = Net.CTPN()
    net.load_state_dict(torch.load('./model/ctpn-9-end.model'))
    print(net)
    net.eval()
    im = cv2.imread('../dataset/OCR_dataset/ctpn/test_im/img_0059.jpg')
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)
    # collect (col, row, anchor_id, confidence) for confident anchors
    result = []
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > 0.6:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))
    for box in result:
        # 16 px anchor stride; +7.5 centers the anchor horizontally
        im = other.draw_box_h_and_c(im, box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
    gt = Dataset.port.read_gt_file('../dataset/OCR_dataset/ctpn/test_gt/gt_img_0059.txt')
    for gt_box in gt:
        im = other.draw_box_4pt(im, gt_box, (255, 0, 0))
    cv2.imwrite("./test_result/test.jpg", im)


# ======================================================================
# detector/ctpn/lib/create_config_file.py
# ======================================================================

def create_config_main():
    """Write the default training config file (originally a __main__ script)."""
    cp = configparser.ConfigParser()
    cp.add_section('global')
    cp.set('global', 'using_cuda', 'True')
    cp.set('global', 'epoch', '30')
    cp.set('global', 'gpu_id', '6')
    cp.set('global', 'display_file_name', 'False')
    cp.set('global', 'display_iter', '1')
    cp.set('global', 'val_iter', '30')
    cp.set('global', 'save_iter', '100')
    cp.add_section('parameter')
    cp.set('parameter', 'lr_front', '0.001')
    cp.set('parameter', 'lr_behind', '0.0001')
    cp.set('parameter', 'change_epoch', '9')
    with open('../config', 'w') as fp:
        cp.write(fp)


# ======================================================================
# detector/ctpn/train.py
# ======================================================================

def train_main():
    """CTPN training loop (originally train.py's __main__ script).

    Fix vs. the original: periodic evaluation called ``val_func.val`` but the
    module was imported as ``evaluate`` — a NameError at runtime.
    """
    cf = configparser.ConfigParser()
    cf.read('./config')

    log_dir = './logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.DEBUG)
    log_file_name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + '.log'
    log_handler = logging.FileHandler(os.path.join(log_dir, log_file_name), 'w')
    # original bound the formatter to two names (log_format = formatter = ...)
    log_handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s'))
    logger.addHandler(log_handler)

    gpu_id = cf.get('global', 'gpu_id')
    epoch = cf.getint('global', 'epoch')
    logger.info('Total epoch: {0}'.format(epoch))

    using_cuda = cf.getboolean('global', 'using_cuda')
    display_img_name = cf.getboolean('global', 'display_file_name')
    display_iter = cf.getint('global', 'display_iter')
    val_iter = cf.getint('global', 'val_iter')
    save_iter = cf.getint('global', 'save_iter')

    lr_front = cf.getfloat('parameter', 'lr_front')
    lr_behind = cf.getfloat('parameter', 'lr_behind')
    change_epoch = cf.getint('parameter', 'change_epoch') - 1
    logger.info('Learning rate: {0}, {1}, change epoch: {2}'.format(
        lr_front, lr_behind, change_epoch + 1))
    print('Using gpu id(available if use cuda): {0}'.format(gpu_id))
    print('Train epoch: {0}'.format(epoch))
    print('Use CUDA: {0}'.format(using_cuda))

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    # the first VGG conv block stays frozen during training
    no_grad = [
        'cnn.VGG_16.convolution1_1.weight',
        'cnn.VGG_16.convolution1_1.bias',
        'cnn.VGG_16.convolution1_2.weight',
        'cnn.VGG_16.convolution1_2.bias',
    ]

    net = Net.CTPN()
    for name, value in net.named_parameters():
        value.requires_grad = name not in no_grad
    net.load_state_dict(torch.load('./other/vgg16.model'))
    # NOTE(review): init_weight() runs after load_state_dict — confirm it only
    # initializes the non-VGG layers and does not clobber the loaded weights.
    other.init_weight(net)
    if using_cuda:
        net.cuda()
    net.train()
    print(net)

    criterion = Net.CTPN_Loss(using_cuda=using_cuda)

    img_root = '../dataset/OCR_dataset/ctpn/train_im2'   # icdar15
    gt_root = '../dataset/OCR_dataset/ctpn/train_gt2'
    img_root1 = '../dataset/OCR_dataset/ctpn/train_im'   # MSRA_TD500
    gt_root1 = '../dataset/OCR_dataset/ctpn/train_gt'

    # im_list[0] = MSRA_TD500 images, im_list[1] = icdar15 images
    im_list = [os.listdir(img_root1), os.listdir(img_root)]
    total_iter = len(im_list[0]) + len(im_list[1])

    for i in range(epoch):
        # lower learning rate after change_epoch
        lr = lr_behind if i >= change_epoch else lr_front
        optimizer = optim.SGD(net.parameters(), lr=lr,
                              momentum=0.9, weight_decay=0.0005)
        iteration = 1
        total_loss = 0
        total_cls_loss = 0
        total_v_reg_loss = 0
        total_o_reg_loss = 0
        start_time = time.time()
        for j in range(len(im_list)):
            random.shuffle(im_list[j])
            for im in im_list[j]:
                name, _ = os.path.splitext(im)
                gt_name = 'gt_' + name + '.txt'
                if j == 1:
                    gt_path = os.path.join(gt_root, gt_name)
                else:
                    gt_path = os.path.join(gt_root1, gt_name)
                if not os.path.exists(gt_path):
                    print('Ground truth file of image {0} not exists.'.format(im))
                    continue

                if j == 1:
                    # icdar gt files carry a BOM (original comment said MSRA)
                    gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
                    print("processing image %s" % os.path.join(img_root, im))
                    img = cv2.imread(os.path.join(img_root, im))
                    if display_img_name:
                        print(os.path.join(img_root, im))
                else:
                    gt_txt = Dataset.port.read_gt_file(gt_path)
                    img = cv2.imread(os.path.join(img_root1, im))
                    if display_img_name:
                        print(os.path.join(img_root1, im))
                img, gt_txt = Dataset.scale_img(img, gt_txt)
                tensor_img = img[np.newaxis, :, :, :].transpose((0, 3, 1, 2))
                tensor_img = torch.FloatTensor(tensor_img)
                if using_cuda:
                    tensor_img = tensor_img.cuda()

                vertical_pred, score, side_refinement = net(tensor_img)
                del tensor_img

                # transform bbox gt into anchor-level targets for training
                positive = []
                negative = []
                vertical_reg = []
                side_refinement_reg = []
                for box in gt_txt:
                    gt_anchor = Dataset.generate_gt_anchor(img, box)
                    pos1, neg1, v_reg1, o_reg1 = Net.tag_anchor(gt_anchor, score, box)
                    positive += pos1
                    negative += neg1
                    vertical_reg += v_reg1
                    side_refinement_reg += o_reg1

                if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
                    iteration += 1
                    continue
                optimizer.zero_grad()
                loss, cls_loss, v_reg_loss, o_reg_loss = criterion(
                    score, vertical_pred, side_refinement, positive,
                    negative, vertical_reg, side_refinement_reg)
                loss.backward()
                optimizer.step()
                iteration += 1
                # accumulate plain floats, not graph-holding tensors
                total_loss += float(loss)
                total_cls_loss += float(cls_loss)
                total_v_reg_loss += float(v_reg_loss)
                total_o_reg_loss += float(o_reg_loss)

                if iteration % display_iter == 0:
                    total_time = time.time() - start_time
                    print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, '
                          'cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.format(
                              iteration, total_iter, i, epoch,
                              total_loss / display_iter,
                              total_cls_loss / display_iter,
                              total_v_reg_loss / display_iter,
                              total_o_reg_loss / display_iter,
                              os.path.join(img_root1, im)))
                    logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(
                        iteration, total_iter, i, epoch))
                    logger.info('loss: {0}'.format(total_loss / display_iter))
                    logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
                    logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
                    logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))
                    total_loss = 0
                    total_cls_loss = 0
                    total_v_reg_loss = 0
                    total_o_reg_loss = 0
                    start_time = time.time()

                if iteration % val_iter == 0:
                    net.eval()
                    logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
                    # original called undefined `val_func.val` (NameError)
                    evaluate.val(net, criterion, 10, using_cuda, logger)
                    logger.info('End evaluate.')
                    net.train()
                    start_time = time.time()

                if iteration % save_iter == 0:
                    print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
                    torch.save(net.state_dict(), './model/ctpn-{0}-{1}.model'.format(i, iteration))

        print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
        torch.save(net.state_dict(), './model/ctpn-{0}-end.model'.format(i))


if __name__ == '__main__':
    train_main()