Skip to content

Commit

Permalink
add dataset ctw and tt
Browse files Browse the repository at this point in the history
  • Loading branch information
RoseSakurai committed Apr 1, 2021
1 parent 2501952 commit 5c45dfa
Show file tree
Hide file tree
Showing 16 changed files with 278 additions and 59 deletions.
58 changes: 58 additions & 0 deletions config/pan/pan_r50_ctw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Network definition: a 'resnet50' backbone (pretrained), an 'FPN' neck and a
# 'PSENet_Head' detection head trained with two Dice losses.
# NOTE(review): this file sits under config/pan/ but selects the PSENet model
# and head types -- confirm that this is intended.
model = {
    'type': 'PSENet',
    'backbone': {
        'type': 'resnet50',
        'pretrained': True,
    },
    'neck': {
        'type': 'FPN',
        'in_channels': (256, 512, 1024, 2048),
        'out_channels': 128,
    },
    'detection_head': {
        'type': 'PSENet_Head',
        'in_channels': 1024,
        'hidden_dim': 256,
        # Same value as kernel_num in the data and test settings of this file.
        'num_classes': 7,
        # The text loss is weighted 0.7 and the kernel loss 0.3.
        'loss_text': {'type': 'DiceLoss', 'loss_weight': 0.7},
        'loss_kernel': {'type': 'DiceLoss', 'loss_weight': 0.3},
    },
}
# Dataset settings: both splits use the 'PSENET_CTW' dataset type and read
# images through the 'cv2' backend.
data = {
    'batch_size': 16,
    'train': {
        'type': 'PSENET_CTW',
        'split': 'train',
        'is_transform': True,
        'img_size': 736,
        'short_size': 736,
        'kernel_num': 7,
        # presumably the scale of the smallest kernel -- TODO confirm
        'min_scale': 0.7,
        'read_type': 'cv2',
    },
    'test': {
        'type': 'PSENET_CTW',
        'split': 'test',
        'short_size': 736,
        'read_type': 'cv2',
    },
}
# Optimisation settings: SGD, learning rate 1e-3, 600 epochs.
# `schedule` presumably lists the epochs at which the learning rate is
# changed -- TODO confirm against the training script.
train_cfg = {
    'lr': 1e-3,
    'schedule': (200, 400),
    'epoch': 600,
    'optimizer': 'SGD',
}
# Post-processing / evaluation settings; detections are written to
# result_path.
# NOTE(review): config/psenet/psenet_r50_ctw.py switches bbox_type from 'rect'
# to 'poly' in the same commit, while this CTW config uses 'rect' -- confirm
# which one is wanted here.
test_cfg = {
    'min_score': 0.85,
    'min_area': 16,
    'kernel_num': 7,
    'bbox_type': 'rect',
    'result_path': 'outputs/submit_ctw.zip',
}
2 changes: 1 addition & 1 deletion config/psenet/psenet_r50_ctw.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@
min_score=0.85,
min_area=16,
kernel_num=7,
bbox_type='rect',
bbox_type='poly',
result_path='outputs/submit_ctw.zip'
)
58 changes: 58 additions & 0 deletions config/psenet/psenet_r50_ic15_1024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Network definition: a 'resnet50' backbone (pretrained), an 'FPN' neck and a
# 'PSENet_Head' detection head trained with two Dice losses.
model = {
    'type': 'PSENet',
    'backbone': {
        'type': 'resnet50',
        'pretrained': True,
    },
    'neck': {
        'type': 'FPN',
        'in_channels': (256, 512, 1024, 2048),
        'out_channels': 128,
    },
    'detection_head': {
        'type': 'PSENet_Head',
        'in_channels': 1024,
        'hidden_dim': 256,
        # Same value as kernel_num in the data and test settings of this file.
        'num_classes': 7,
        # The text loss is weighted 0.7 and the kernel loss 0.3.
        'loss_text': {'type': 'DiceLoss', 'loss_weight': 0.7},
        'loss_kernel': {'type': 'DiceLoss', 'loss_weight': 0.3},
    },
}
# Dataset settings: both splits use the 'PSENET_IC15' dataset type with
# short_size 1024 and read images through the 'cv2' backend.
data = {
    'batch_size': 16,
    'train': {
        'type': 'PSENET_IC15',
        'split': 'train',
        'is_transform': True,
        'img_size': 736,
        'short_size': 1024,
        'kernel_num': 7,
        # presumably the scale of the smallest kernel -- TODO confirm
        'min_scale': 0.4,
        'read_type': 'cv2',
    },
    'test': {
        'type': 'PSENET_IC15',
        'split': 'test',
        'short_size': 1024,
        'read_type': 'cv2',
    },
}
# Optimisation settings: SGD, learning rate 1e-3, 580 epochs.
# `schedule` presumably lists the epochs at which the learning rate is
# changed -- TODO confirm against the training script.
train_cfg = {
    'lr': 1e-3,
    'schedule': (200, 400),
    'epoch': 580,
    'optimizer': 'SGD',
}
# Post-processing / evaluation settings; detections are written to
# result_path as rectangles (bbox_type='rect').
test_cfg = {
    'min_score': 0.85,
    'min_area': 16,
    'kernel_num': 7,
    'bbox_type': 'rect',
    'result_path': 'outputs/submit_ic15.zip',
}
84 changes: 84 additions & 0 deletions dataset/psenet/check_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from psenet_ctw import PSENET_CTW
import torch
import numpy as np
import cv2
import random
import os

# Seed the torch (CPU and CUDA), NumPy and stdlib `random` generators with one
# shared value, presumably so that repeated runs of this check are repeatable.
_SEED = 123456
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(_SEED)


def to_rgb(img):
    """Turn a single-channel map into a 3-channel image scaled by 255.

    The input is reshaped to ``(img.shape[0], img.shape[1], 1)``, copied three
    times along the last axis, and every value is multiplied by 255.
    """
    rows, cols = img.shape[0], img.shape[1]
    plane = img.reshape(rows, cols, 1)
    return np.concatenate([plane] * 3, axis=2) * 255


def save(img_path, imgs):
    """Write ``imgs`` side by side as one image file under ``vis/``.

    Args:
        img_path: Either a path string, or an indexable of path strings whose
            first entry is used. Only the part after the last ``'/'`` is kept
            as the output file name.
        imgs: Sequence of images. Each one gets a 3-pixel constant border and
            the results are concatenated along axis 1. The caller's sequence
            is left unmodified.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('vis/', exist_ok=True)

    # Build a new list rather than overwriting the entries of the caller's list.
    bordered = [
        cv2.copyMakeBorder(img, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
        for img in imgs
    ]
    res = np.concatenate(bordered, axis=1)

    first_path = img_path if isinstance(img_path, str) else img_path[0]
    img_name = first_path.split('/')[-1]

    # cv2.imwrite reports failure through its return value, so only claim
    # success once the file has actually been written.
    if cv2.imwrite('vis/' + img_name, res):
        print('saved %s.' % img_name)
    else:
        print('failed to save %s.' % img_name)



# Dataset under inspection. The commented-out constructors are other loaders
# this script has been pointed at; none of their classes is imported by this
# file, so an import must be added before re-enabling one of them.
# data_loader = SynthLoader(split='train', is_transform=True, img_size=640, kernel_scale=0.5, short_size=640,
#                          for_rec=True)
# data_loader = IC15Loader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736,
#                          for_rec=True)
# data_loader = CombineLoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736,
#                             for_rec=True)
# data_loader = TTLoader(split='train', is_transform=True, img_size=640, kernel_scale=0.8, short_size=640,
#                        for_rec=True, read_type='pil')
# data_loader = CombineAllLoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736,
#                                for_rec=True)
# data_loader = MSRALoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736,
#                          for_rec=True)
# data_loader = CTWv2Loader(split='train', is_transform=True, img_size=640, kernel_scale=0.7, short_size=640,
#                           for_rec=True)
# data_loader = IC15(split='train', is_transform=True, img_size=640,)
data_loader = PSENET_CTW(split='test', is_transform=True, img_size=736)

# One sample per batch, in dataset order, loaded in the main process.
train_loader = torch.utils.data.DataLoader(
    data_loader, batch_size=1, shuffle=False, num_workers=0, drop_last=True)

# Per-channel std / mean used to undo the input normalisation; the values match
# the transforms.Normalize call in dataset/psenet/psenet_ctw.py.
norm_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
norm_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)

for batch_idx, imgs in enumerate(train_loader):
    # Only the first 101 batches (indices 0..100) are dumped.
    if batch_idx > 100:
        break
    # image_name = data_loader.img_paths[batch_idx].split('/')[-1].split('.')[0]

    # print('%d/%d %s'%(batch_idx, len(train_loader), data_loader.img_paths[batch_idx]))
    print('%d/%d' % (batch_idx, len(train_loader)))

    # NOTE(review): imgs[0] is assumed to be a (3, H, W) image tensor, i.e. the
    # dataset is expected to yield a bare tensor. The visible prepare_test_data
    # builds a dict(imgs=..., img_metas=...) -- confirm what it returns.
    chw = imgs[0].numpy()
    denormalised = ((chw * norm_std + norm_mean) * 255).astype(np.uint8)
    # Channels-first -> channels-last, then reverse the channel order for cv2.
    bgr = np.transpose(denormalised, (1, 2, 0))[:, :, ::-1].copy()

    # Ground-truth visualisation for the training split (disabled):
    # gt_text = to_rgb(gt_texts[0].numpy())
    # gt_kernel_0 = to_rgb(gt_kernels[0, 0].numpy())
    # gt_kernel_1 = to_rgb(gt_kernels[0, 1].numpy())
    # gt_kernel_2 = to_rgb(gt_kernels[0, 2].numpy())
    # gt_kernel_3 = to_rgb(gt_kernels[0, 3].numpy())
    # gt_kernel_4 = to_rgb(gt_kernels[0, 4].numpy())
    # gt_kernel_5 = to_rgb(gt_kernels[0, 5].numpy())
    # gt_text_mask = to_rgb(training_masks[0].numpy().astype(np.uint8))

    # save('%d.png' % batch_idx, [img, gt_text, gt_kernel_0, gt_kernel_1, gt_kernel_2, gt_kernel_3, gt_kernel_4, gt_kernel_5, gt_text_mask])
    save('%d_test.png' % batch_idx, [bgr])
6 changes: 3 additions & 3 deletions dataset/psenet/psenet_ctw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import scipy.io as scio
import mmcv

ctw_root_dir = './data/CTW1500/'
ctw_root_dir = './data/ctw1500/'
ctw_train_data_dir = ctw_root_dir + 'train/text_image/'
ctw_train_gt_dir = ctw_root_dir + 'train/text_label_curve/'
ctw_test_data_dir = ctw_root_dir + 'test/text_image/'
Expand Down Expand Up @@ -195,7 +195,7 @@ def __init__(self,
img_size=None,
short_size=736,
kernel_num=7,
min_scale=0.7,
min_scale=0.4,
read_type='pil',
report_speed=False):
self.split = split
Expand Down Expand Up @@ -318,6 +318,7 @@ def prepare_train_data(self, index):
)

return data
# return img, gt_text, gt_kernels, training_mask

def prepare_test_data(self, index):
img_path = self.img_paths[index]
Expand All @@ -336,7 +337,6 @@ def prepare_test_data(self, index):
img = img.convert('RGB')
img = transforms.ToTensor()(img)
img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

data = dict(
imgs=img,
img_metas=img_meta
Expand Down
2 changes: 1 addition & 1 deletion eval/ctw/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
project_root = '../../'

pred_root = project_root + 'outputs/submit_ctw'
gt_root = project_root + 'data/CTW1500/test/text_label_circum/'
gt_root = project_root + 'data/ctw1500/test/text_label_circum/'


def get_pred(path):
Expand Down
2 changes: 1 addition & 1 deletion eval/eval_ic15.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15_2.zip && cd ..
cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15.zip && cd ..
4 changes: 2 additions & 2 deletions eval/ic15/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def evaluation_imports():
"""
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
"""
return {
'Polygon':'plg',
Expand Down Expand Up @@ -75,7 +75,7 @@ def polygon_from_points(points):
# resBoxes[0,3]=int(points[6])
# resBoxes[0,7]=int(points[7])
# pointMat = resBoxes[0].reshape([2,4]).T
# return plg.Polygon( pointMat)
# return plg.Polygon( pointMat)

p = np.array(points)
p = p.reshape(p.shape[0]//2, 2)
Expand Down
57 changes: 38 additions & 19 deletions models/head/psenet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,49 @@ def get_results(self, out, img_meta, cfg):
start = time.time()

score = torch.sigmoid(out[:, 0, :, :])
# out = (torch.sign(out - 1) + 1) / 2 # 0 1
#
# text_mask = out[:, 0, :, :]
# kernels = out[:, 1:cfg.test_cfg.kernel_num, :, :] * text_mask

kernels = out[:, :cfg.test_cfg.kernel_num, :, :] > 0
text_mask = kernels[:, :1, :, :]
kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask

score = score.data.cpu().numpy()[0].astype(np.float32)
kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)
# kernel_1 = kernels[1]
# kernel_2 = kernels[2]
# kernel_3 = kernels[3]
# kernel_4 = kernels[4]
# kernel_5 = kernels[5]
# kernel_6 = kernels[6]
#
# kernel_1 = kernel_1.reshape(736, 1120, 1)
# kernel_2 = kernel_2.reshape(736, 1120, 1)
# kernel_3 = kernel_3.reshape(736, 1120, 1)
# kernel_4 = kernel_4.reshape(736, 1120, 1)
# kernel_5 = kernel_5.reshape(736, 1120, 1)
# kernel_6 = kernel_6.reshape(736, 1120, 1)
#
# kernel_1 = np.concatenate((kernel_1, kernel_1, kernel_1), axis=2) * 255
# kernel_2 = np.concatenate((kernel_2, kernel_2, kernel_2), axis=2) * 255
# kernel_3 = np.concatenate((kernel_3, kernel_3, kernel_3), axis=2) * 255
# kernel_4 = np.concatenate((kernel_4, kernel_4, kernel_4), axis=2) * 255
# kernel_5 = np.concatenate((kernel_5, kernel_5, kernel_5), axis=2) * 255
# kernel_6 = np.concatenate((kernel_6, kernel_6, kernel_6), axis=2) * 255
#
# kernel_1 = cv2.copyMakeBorder(kernel_1, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
# kernel_2 = cv2.copyMakeBorder(kernel_2, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
# kernel_3 = cv2.copyMakeBorder(kernel_3, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
# kernel_4 = cv2.copyMakeBorder(kernel_4, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
# kernel_5 = cv2.copyMakeBorder(kernel_5, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
# kernel_6 = cv2.copyMakeBorder(kernel_6, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0])
#
# res = np.concatenate((kernel_1, kernel_2, kernel_3, kernel_4, kernel_5, kernel_6), axis=1)
# print('saved kernels.')
# cv2.imwrite('vis_kernels.png', res)
# exit()

label = pse(kernels, cfg.test_cfg.min_area)

Expand All @@ -75,12 +112,6 @@ def get_results(self, out, img_meta, cfg):
scale = (float(org_img_size[1]) / float(img_size[1]),
float(org_img_size[0]) / float(img_size[0]))

with_rec = hasattr(cfg.model, 'recognition_head')

if with_rec:
bboxes_h = np.zeros((1, label_num, 4), dtype=np.int32)
instances = [[]]

bboxes = []
scores = []
for i in range(1, label_num):
Expand All @@ -96,19 +127,13 @@ def get_results(self, out, img_meta, cfg):
label[ind] = 0
continue

if with_rec:
tl = np.min(points, axis=0)
br = np.max(points, axis=0) + 1
bboxes_h[0, i] = (tl[0], tl[1], br[0], br[1])
instances[0].append(i)

if cfg.test_cfg.bbox_type == 'rect':
rect = cv2.minAreaRect(points[:, ::-1])
bbox = cv2.boxPoints(rect) * scale
elif cfg.test_cfg.bbox_type == 'poly':
binary = np.zeros(label.shape, dtype='uint8')
binary[ind] = 1
_, contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
bbox = contours[0] * scale

bbox = bbox.astype('int32')
Expand All @@ -119,12 +144,6 @@ def get_results(self, out, img_meta, cfg):
bboxes=bboxes,
scores=scores
))
if with_rec:
outputs.update(dict(
label=label,
bboxes_h=bboxes_h,
instances=instances
))

return outputs

Expand Down
Loading

0 comments on commit 5c45dfa

Please sign in to comment.