forked from WenmuZhou/PytorchOCR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request WenmuZhou#170 from Bourne-M/master
优化后处理,解决训练不收敛问题
- Loading branch information
Showing
1 changed file
with
121 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,121 +1,134 @@ | ||
import cv2 | ||
import numpy as np | ||
import pyclipper | ||
from shapely.geometry import Polygon | ||
|
||
|
||
class DBPostProcess(): | ||
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5): | ||
class DBPostProcess: | ||
def __init__(self, thresh=0.3, unclip_ratio=1.5, box_thresh=0.6): | ||
self.min_size = 3 | ||
self.thresh = thresh | ||
self.box_thresh = box_thresh | ||
self.max_candidates = max_candidates | ||
self.unclip_ratio = unclip_ratio | ||
|
||
def __call__(self, pred, h_w_list, is_output_polygon=False): | ||
''' | ||
batch: (image, polygons, ignore_tags | ||
h_w_list: 包含[h,w]的数组 | ||
pred: | ||
binary: text region segmentation map, with shape (N, 1,H, W) | ||
''' | ||
pred = pred[:, 0, :, :] | ||
segmentation = self.binarize(pred) | ||
boxes_batch = [] | ||
scores_batch = [] | ||
for batch_index in range(pred.shape[0]): | ||
height, width = h_w_list[batch_index] | ||
boxes, scores = self.post_p(pred[batch_index], segmentation[batch_index], width, height, | ||
is_output_polygon=is_output_polygon) | ||
boxes_batch.append(boxes) | ||
scores_batch.append(scores) | ||
return boxes_batch, scores_batch | ||
self.bbox_scale_ratio = unclip_ratio | ||
self.shortest_length = 5 | ||
|
||
def binarize(self, pred): | ||
return pred > self.thresh | ||
|
||
def post_p(self, pred, bitmap, dest_width, dest_height, is_output_polygon=False): | ||
''' | ||
_bitmap: single map with shape (H, W), | ||
whose values are binarized as {0, 1} | ||
''' | ||
height, width = pred.shape | ||
boxes = [] | ||
new_scores = [] | ||
def __call__(self, _predict_score, _ori_img_shape): | ||
instance_score = _predict_score.squeeze() | ||
h, w = instance_score.shape[:2] | ||
height, width = _ori_img_shape[0] | ||
available_region = np.zeros_like(instance_score, dtype=np.float32) | ||
np.putmask(available_region, instance_score > self.thresh, instance_score) | ||
to_return_boxes = [] | ||
to_return_scores = [] | ||
mask_region = (available_region > 0).astype(np.uint8) * 255 | ||
structure_element = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7)) | ||
refined_mask_region = cv2.morphologyEx(mask_region, cv2.MORPH_CLOSE, structure_element) | ||
if cv2.__version__.startswith('3'): | ||
_, contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | ||
if cv2.__version__.startswith('4'): | ||
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | ||
for contour in contours[:self.max_candidates]: | ||
epsilon = 0.005 * cv2.arcLength(contour, True) | ||
approx = cv2.approxPolyDP(contour, epsilon, True) | ||
points = approx.reshape((-1, 2)) | ||
if points.shape[0] < 4: | ||
continue | ||
score = self.box_score_fast(pred, contour.squeeze(1)) | ||
if self.box_thresh > score: | ||
_, contours, _ = cv2.findContours(refined_mask_region, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | ||
elif cv2.__version__.startswith('4'): | ||
contours, _ = cv2.findContours(refined_mask_region, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | ||
else: | ||
raise NotImplementedError(f'opencv {cv2.__version__} not support') | ||
tmp_points = [] | ||
tmp_socre = [] | ||
for m_contour in contours: | ||
if len(m_contour) < 4 and cv2.contourArea(m_contour) < 16: | ||
continue | ||
if points.shape[0] > 2: | ||
box = self.unclip(points, unclip_ratio=self.unclip_ratio) | ||
if len(box) > 1: | ||
continue | ||
else: | ||
m_rotated_box = get_min_area_bbox(refined_mask_region, m_contour, self.bbox_scale_ratio) | ||
if m_rotated_box is None: | ||
continue | ||
four_point_box, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) | ||
if sside < self.min_size + 2: | ||
m_box_width = m_rotated_box['box_width'] | ||
m_box_height = m_rotated_box['box_height'] | ||
if min(m_box_width * w, m_box_height * h) < self.shortest_length: | ||
continue | ||
if not isinstance(dest_width, int): | ||
dest_width = dest_width.item() | ||
dest_height = dest_height.item() | ||
if not is_output_polygon: | ||
box = np.array(four_point_box) | ||
else: | ||
box = box.reshape(-1, 2) | ||
box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) | ||
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) | ||
boxes.append(box) | ||
new_scores.append(score) | ||
return boxes, new_scores | ||
rotated_points = get_coordinates_of_rotated_box(m_rotated_box, height, width) | ||
tmp_points.append(rotated_points) | ||
|
||
def unclip(self, box, unclip_ratio=1.5): | ||
poly = Polygon(box) | ||
distance = poly.area * unclip_ratio / poly.length | ||
offset = pyclipper.PyclipperOffset() | ||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | ||
expanded = np.array(offset.Execute(distance)) | ||
return expanded | ||
m_available_mask = np.zeros_like(available_region, dtype=np.uint8) | ||
cv2.drawContours(m_available_mask, [m_contour, ], 0, 255, thickness=-1) | ||
m_region_mask = cv2.bitwise_and(available_region, available_region, mask=m_available_mask) | ||
m_mask_count = np.count_nonzero(m_available_mask) | ||
tmp_socre.append(float(np.sum(m_region_mask) / m_mask_count)) | ||
|
||
def get_mini_boxes(self, contour): | ||
bounding_box = cv2.minAreaRect(contour) | ||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) | ||
to_return_boxes.append(tmp_points) | ||
to_return_scores.append(tmp_socre) | ||
|
||
return to_return_boxes, to_return_scores | ||
|
||
|
||
def rotate_points(_points, _degree=0, _center=(0, 0)): | ||
""" | ||
逆时针绕着一个点旋转点 | ||
Args: | ||
_points: 需要旋转的点 | ||
_degree: 角度 | ||
_center: 中心点 | ||
Returns: 旋转后的点 | ||
""" | ||
angle = np.deg2rad(_degree) | ||
rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)], | ||
[np.sin(angle), np.cos(angle)]]) | ||
center = np.atleast_2d(_center) | ||
points = np.atleast_2d(_points) | ||
return np.squeeze((rotate_matrix @ (points.T - center.T) + center.T).T) | ||
|
||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3 | ||
if points[1][1] > points[0][1]: | ||
index_1 = 0 | ||
index_4 = 1 | ||
else: | ||
index_1 = 1 | ||
index_4 = 0 | ||
if points[3][1] > points[2][1]: | ||
index_2 = 2 | ||
index_3 = 3 | ||
else: | ||
index_2 = 3 | ||
index_3 = 2 | ||
|
||
box = [points[index_1], points[index_2], points[index_3], points[index_4]] | ||
return box, min(bounding_box[1]) | ||
def get_coordinates_of_rotated_box(_rotated_box, _height, _width): | ||
""" | ||
获取旋转的矩形的对应的四个顶点坐标 | ||
Args: | ||
_image: 对应的图像 | ||
_rotated_box: 旋转的矩形 | ||
Returns: 四个对应在图像中的坐标点 | ||
""" | ||
center_x = _rotated_box['center_x'] | ||
center_y = _rotated_box['center_y'] | ||
half_box_width = _rotated_box['box_width'] / 2 | ||
half_box_height = _rotated_box['box_height'] / 2 | ||
raw_points = np.array([ | ||
[center_x - half_box_width, center_y - half_box_height], | ||
[center_x + half_box_width, center_y - half_box_height], | ||
[center_x + half_box_width, center_y + half_box_height], | ||
[center_x - half_box_width, center_y + half_box_height] | ||
]) * (_width, _height) | ||
rotated_points = rotate_points(raw_points, _rotated_box['degree'], (center_x * _width, center_y * _height)) | ||
rotated_points[:, 0] = np.clip(rotated_points[:, 0], a_min=0, a_max=_width) | ||
rotated_points[:, 1] = np.clip(rotated_points[:, 1], a_min=0, a_max=_height) | ||
return rotated_points.astype(np.int32) | ||
|
||
def box_score_fast(self, bitmap, _box): | ||
h, w = bitmap.shape[:2] | ||
box = _box.copy() | ||
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) | ||
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) | ||
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) | ||
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) | ||
|
||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) | ||
box[:, 0] = box[:, 0] - xmin | ||
box[:, 1] = box[:, 1] - ymin | ||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) | ||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] | ||
def get_min_area_bbox(_image, _contour, _scale_ratio=1.0): | ||
""" | ||
获取一个contour对应的最小面积矩形 | ||
note:主要是解决了旋转角度不合适的问题 | ||
Args: | ||
_image: bbox所在图像 | ||
_contour: 轮廓 | ||
_scale_ratio: 缩放比例 | ||
Returns: 最小面积矩形的相关信息 | ||
""" | ||
h, w = _image.shape[:2] | ||
if _scale_ratio != 1: | ||
reshaped_contour = _contour.reshape(-1, 2) | ||
current_polygon = Polygon(reshaped_contour) | ||
distance = current_polygon.area * _scale_ratio / current_polygon.length | ||
offset = PyclipperOffset() | ||
offset.AddPath(reshaped_contour, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | ||
box = offset.Execute(distance) | ||
if len(box) == 0 or len(box) > 1: | ||
return None | ||
scaled_contour = np.array(box).reshape(-1, 1, 2) | ||
else: | ||
scaled_contour = _contour | ||
try: | ||
rotated_box = cv2.minAreaRect(scaled_contour) | ||
except Exception: | ||
return None | ||
if -90 <= rotated_box[2] <= -45: | ||
to_rotate_degree = rotated_box[2] + 90 | ||
bbox_height, bbox_width = rotated_box[1] | ||
else: | ||
to_rotate_degree = rotated_box[2] | ||
bbox_width, bbox_height = rotated_box[1] | ||
# 几何信息归一化可以方便进行在缩放前的图像上进行操作 | ||
to_return_rotated_box = { | ||
'degree': int(to_rotate_degree), | ||
'center_x': rotated_box[0][0] / w, | ||
'center_y': rotated_box[0][1] / h, | ||
'box_height': bbox_height / h, | ||
'box_width': bbox_width / w, | ||
} | ||
return to_return_rotated_box |