forked from whai362/PSENet
Showing 84 changed files with 24,734 additions and 0 deletions.
@@ -0,0 +1,2 @@
from dataset.icdar2015_loader import IC15Loader
from dataset.icdar2015_test_loader import IC15TestLoader
@@ -0,0 +1,240 @@
# dataloader: adds the 3.0 random-scale option
# dataloader: adds filtering of ignored ('###') text regions
import numpy as np
from PIL import Image
from torch.utils import data
import util
import cv2
import random
import torchvision.transforms as transforms
import torch
import pyclipper
import Polygon as plg

ic15_root_dir = './data/ICDAR2015/Challenge4/'
ic15_train_data_dir = ic15_root_dir + 'ch4_training_images/'
ic15_train_gt_dir = ic15_root_dir + 'ch4_training_localization_transcription_gt/'
ic15_test_data_dir = ic15_root_dir + 'ch4_test_images/'
ic15_test_gt_dir = ic15_root_dir + 'ch4_test_localization_transcription_gt/'

random.seed(123456)


def get_img(img_path):
    try:
        img = cv2.imread(img_path)
        img = img[:, :, [2, 1, 0]]  # BGR -> RGB
    except Exception:
        print(img_path)
        raise
    return img


def get_bboxes(img, gt_path):
    h, w = img.shape[0:2]
    lines = util.io.read_lines(gt_path)
    bboxes = []
    tags = []
    for line in lines:
        line = util.str.remove_all(line, '\xef\xbb\xbf')  # strip UTF-8 BOM
        gt = util.str.split(line, ',')
        if gt[-1][0] == '#':
            # a '###' transcription marks a "do not care" region
            tags.append(False)
        else:
            tags.append(True)
        box = [int(gt[i]) for i in range(8)]
        # normalize the 4 corner points to [0, 1] by image width / height
        box = np.asarray(box) / ([w * 1.0, h * 1.0] * 4)
        bboxes.append(box)
    return np.array(bboxes), tags


def random_horizontal_flip(imgs):
    if random.random() < 0.5:
        for i in range(len(imgs)):
            imgs[i] = np.flip(imgs[i], axis=1).copy()
    return imgs


def random_rotate(imgs):
    max_angle = 10
    angle = random.random() * 2 * max_angle - max_angle  # uniform in [-10, 10] degrees
    for i in range(len(imgs)):
        img = imgs[i]
        # note: img.shape[:2] is (rows, cols), so these names are swapped,
        # but the (h, w) usages below are swapped consistently with them
        w, h = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
        img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
        imgs[i] = img_rotation
    return imgs


def scale(img, long_size=2240):
    h, w = img.shape[0:2]
    scale = long_size * 1.0 / max(h, w)
    img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
    return img


def random_scale(img, min_size):
    h, w = img.shape[0:2]
    if max(h, w) > 1280:
        scale = 1280.0 / max(h, w)
        img = cv2.resize(img, dsize=None, fx=scale, fy=scale)

    h, w = img.shape[0:2]
    random_scale = np.array([0.5, 1.0, 2.0, 3.0])
    scale = np.random.choice(random_scale)
    if min(h, w) * scale <= min_size:
        scale = (min_size + 10) * 1.0 / min(h, w)
    img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
    return img


def random_crop(imgs, img_size):
    h, w = imgs[0].shape[0:2]
    th, tw = img_size
    if w == tw and h == th:
        return imgs

    # with probability 5/8, bias the crop window so it overlaps text pixels
    if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
        tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size
        tl[tl < 0] = 0
        br = np.max(np.where(imgs[1] > 0), axis=1) - img_size
        br[br < 0] = 0
        br[0] = min(br[0], h - th)
        br[1] = min(br[1], w - tw)

        i = random.randint(tl[0], br[0])
        j = random.randint(tl[1], br[1])
    else:
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)

    for idx in range(len(imgs)):
        if len(imgs[idx].shape) == 3:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
        else:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw]
    return imgs


def dist(a, b):
    return np.sqrt(np.sum((a - b) ** 2))


def perimeter(bbox):
    peri = 0.0
    for i in range(bbox.shape[0]):
        peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
    return peri


def shrink(bboxes, rate, max_shr=20):
    rate = rate * rate
    shrinked_bboxes = []
    for bbox in bboxes:
        area = plg.Polygon(bbox).area()
        peri = perimeter(bbox)

        pco = pyclipper.PyclipperOffset()
        pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)

        shrinked_bbox = pco.Execute(-offset)
        if len(shrinked_bbox) == 0:
            shrinked_bboxes.append(bbox)
            continue

        shrinked_bbox = np.array(shrinked_bbox)[0]
        if shrinked_bbox.shape[0] <= 2:
            shrinked_bboxes.append(bbox)
            continue

        shrinked_bboxes.append(shrinked_bbox)

    return np.array(shrinked_bboxes)

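# Note (explanatory comment, not in the original commit): the offset above is
# the PSENet kernel-shrinking margin d = Area(P) * (1 - r^2) / Perimeter(P),
# with r the shrink rate (rate is squared before use). Worked example for a
# 100 x 100 axis-aligned square with rate r = 0.5:
#   area = 10000, peri = 400
#   offset = min(int(10000 * (1 - 0.25) / 400.001 + 0.5), 20) = min(19, 20) = 19
# so every edge moves about 19 px inward before pyclipper returns the polygon.
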
class IC15Loader(data.Dataset):
    def __init__(self, is_transform=False, img_size=None, kernel_num=7, min_scale=0.4):
        self.is_transform = is_transform

        self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
        self.kernel_num = kernel_num
        self.min_scale = min_scale

        data_dirs = [ic15_train_data_dir]
        gt_dirs = [ic15_train_gt_dir]

        self.img_paths = []
        self.gt_paths = []

        for data_dir, gt_dir in zip(data_dirs, gt_dirs):
            img_names = util.io.ls(data_dir, '.jpg')
            img_names.extend(util.io.ls(data_dir, '.png'))
            # img_names.extend(util.io.ls(data_dir, '.gif'))

            img_paths = []
            gt_paths = []
            for idx, img_name in enumerate(img_names):
                img_path = data_dir + img_name
                img_paths.append(img_path)

                gt_name = 'gt_' + img_name.split('.')[0] + '.txt'
                gt_path = gt_dir + gt_name
                gt_paths.append(gt_path)

            self.img_paths.extend(img_paths)
            self.gt_paths.extend(gt_paths)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        gt_path = self.gt_paths[index]

        img = get_img(img_path)
        bboxes, tags = get_bboxes(img, gt_path)

        if self.is_transform:
            img = random_scale(img, self.img_size[0])

        # gt_text: per-pixel instance labels (0 = background, i + 1 = box i);
        # training_mask: 0 inside "do not care" regions, 1 elsewhere
        gt_text = np.zeros(img.shape[0:2], dtype='uint8')
        training_mask = np.ones(img.shape[0:2], dtype='uint8')
        if bboxes.shape[0] > 0:
            bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4),
                                (bboxes.shape[0], bboxes.shape[1] // 2, 2)).astype('int32')
            for i in range(bboxes.shape[0]):
                cv2.drawContours(gt_text, [bboxes[i]], -1, i + 1, -1)
                if not tags[i]:
                    cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)

        gt_kernels = []
        for i in range(1, self.kernel_num):
            # shrink rate decreases linearly from 1.0 towards min_scale
            rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i
            gt_kernel = np.zeros(img.shape[0:2], dtype='uint8')
            kernel_bboxes = shrink(bboxes, rate)
            for j in range(bboxes.shape[0]):
                cv2.drawContours(gt_kernel, [kernel_bboxes[j]], -1, 1, -1)
            gt_kernels.append(gt_kernel)

        if self.is_transform:
            imgs = [img, gt_text, training_mask]
            imgs.extend(gt_kernels)

            imgs = random_horizontal_flip(imgs)
            imgs = random_rotate(imgs)
            imgs = random_crop(imgs, self.img_size)

            img, gt_text, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:]

        gt_text[gt_text > 0] = 1
        gt_kernels = np.array(gt_kernels)

        if self.is_transform:
            img = Image.fromarray(img)
            img = img.convert('RGB')
            img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img)
        else:
            img = Image.fromarray(img)
            img = img.convert('RGB')

        img = transforms.ToTensor()(img)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        gt_text = torch.from_numpy(gt_text).float()
        gt_kernels = torch.from_numpy(gt_kernels).float()
        training_mask = torch.from_numpy(training_mask).float()

        return img, gt_text, gt_kernels, training_mask
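
A minimal usage sketch for this loader with a standard PyTorch DataLoader; the batch size and worker count below are illustrative assumptions, not values taken from this commit:

import torch
from dataset import IC15Loader

# assumes ./data/ICDAR2015/Challenge4/ is laid out as in the paths above
train_loader = torch.utils.data.DataLoader(
    IC15Loader(is_transform=True, img_size=640, kernel_num=7, min_scale=0.4),
    batch_size=16, shuffle=True, num_workers=3, drop_last=True)

for img, gt_text, gt_kernels, training_mask in train_loader:
    # img: (B, 3, 640, 640) normalized; gt_text / training_mask: (B, 640, 640)
    # gt_kernels: (B, kernel_num - 1, 640, 640), one shrunk text mask per kernel
    break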
@@ -0,0 +1,71 @@
# dataloader: adds the 3.0 random-scale option
# dataloader: adds filtering of ignored ('###') text regions
import numpy as np
from PIL import Image
from torch.utils import data
import util
import cv2
import random
import torchvision.transforms as transforms
import torch

ic15_root_dir = './data/ICDAR2015/Challenge4/'
ic15_test_data_dir = ic15_root_dir + 'ch4_test_images/'
ic15_test_gt_dir = ic15_root_dir + 'ch4_test_localization_transcription_gt/'

random.seed(123456)


def get_img(img_path):
    try:
        img = cv2.imread(img_path)
        img = img[:, :, [2, 1, 0]]  # BGR -> RGB
    except Exception:
        print(img_path)
        raise
    return img


def scale(img, long_size=2240):
    h, w = img.shape[0:2]
    scale = long_size * 1.0 / max(h, w)
    img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
    return img


class IC15TestLoader(data.Dataset):
    def __init__(self, part_id=0, part_num=1, long_size=2240):
        data_dirs = [ic15_test_data_dir]

        self.img_paths = []

        for data_dir in data_dirs:
            img_names = util.io.ls(data_dir, '.jpg')
            img_names.extend(util.io.ls(data_dir, '.png'))

            img_paths = []
            for idx, img_name in enumerate(img_names):
                img_path = data_dir + img_name
                img_paths.append(img_path)

            self.img_paths.extend(img_paths)

        # optionally keep only the part_id-th of part_num equal shards
        part_size = len(self.img_paths) // part_num
        l = part_id * part_size
        r = (part_id + 1) * part_size
        self.img_paths = self.img_paths[l:r]
        self.long_size = long_size

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path = self.img_paths[index]

        img = get_img(img_path)

        scaled_img = scale(img, self.long_size)
        scaled_img = Image.fromarray(scaled_img)
        scaled_img = scaled_img.convert('RGB')
        scaled_img = transforms.ToTensor()(scaled_img)
        scaled_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(scaled_img)

        # return the original image (channels flipped back to BGR for cv2
        # drawing) alongside the normalized tensor fed to the network
        return img[:, :, [2, 1, 0]], scaled_img
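
A minimal test-time sketch, assuming batch_size=1 as is typical for the variable-size ICDAR 2015 test images:

import torch
from dataset import IC15TestLoader

test_loader = torch.utils.data.DataLoader(
    IC15TestLoader(long_size=2240), batch_size=1, shuffle=False, num_workers=2)

for org_img, scaled_img in test_loader:
    # org_img: (1, H, W, 3) uint8, channels flipped back for cv2 drawing
    # scaled_img: (1, 3, H', W') normalized tensor for the network
    break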
@@ -0,0 +1,3 @@
cd ic15
python script.py -g=gt.zip -s=../../outputs/submit_ic15.zip
cd ..
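
script.py compares a zip of per-image detection files (-s) against the ground-truth zip (-g). A sketch of packaging detections into submit_ic15.zip, assuming the test script writes one res_img_*.txt file per image into outputs/ (that file layout is an assumption, not specified by this commit):

import glob
import os
import zipfile

# hypothetical layout: outputs/res_img_1.txt, outputs/res_img_2.txt, ...
with zipfile.ZipFile('outputs/submit_ic15.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
    for path in sorted(glob.glob('outputs/res_img_*.txt')):
        zf.write(path, arcname=os.path.basename(path))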
@@ -0,0 +1,33 @@
import os


def read_dir(root):
    # recursively list all files under root, with '/' separators, sorted
    file_path_list = []
    for file_path, dirs, files in os.walk(root):
        for file in files:
            file_path_list.append(os.path.join(file_path, file).replace('\\', '/'))
    file_path_list.sort()
    return file_path_list


def read_file(file_path):
    with open(file_path, 'r') as file_object:
        return file_object.read()


def write_file(file_path, file_content):
    # create the parent directory if needed, then (over)write the file
    if file_path.find('/') != -1:
        father_dir = '/'.join(file_path.split('/')[0:-1])
        if not os.path.exists(father_dir):
            os.makedirs(father_dir)
    with open(file_path, 'w') as file_object:
        file_object.write(file_content)


def write_file_not_cover(file_path, file_content):
    # like write_file, but appends instead of overwriting; the same guard
    # against paths with no '/' is added here for consistency
    if file_path.find('/') != -1:
        father_dir = '/'.join(file_path.split('/')[0:-1])
        if not os.path.exists(father_dir):
            os.makedirs(father_dir)
    with open(file_path, 'a') as file_object:
        file_object.write(file_content)
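
A quick usage sketch for these helpers, assuming the module is saved as file_util.py (the file name is not shown in this diff); the paths and contents are illustrative:

from file_util import read_dir, read_file, write_file, write_file_not_cover

write_file('outputs/demo/res_img_1.txt', '377,117,463,117,465,130,378,130\n')
write_file_not_cover('outputs/demo/res_img_1.txt', '493,115,519,115,519,131,493,131\n')  # appends
print(read_file('outputs/demo/res_img_1.txt'))
print(read_dir('outputs/demo'))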
Binary file not shown.