Skip to content

Commit

Permalink
Updatre tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeySrus committed May 27, 2022
1 parent afa8b26 commit 97ddf28
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions inference_utils/tiled_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def tiling_intersected(
img: np.ndarray, tile_size: int, step: float = 3/4) -> List[List[Tuple[int, int]]]:
img: np.ndarray,
tile_size: int,
step: float = 1/2) -> List[List[Tuple[int, int]]]:
stride = int(tile_size * step)

x0_vec = []
Expand Down Expand Up @@ -37,17 +39,15 @@ def tiling_intersected(


class PolyDetection(object):
def __init__(self, _mask: np.ndarray, class_num: int):
def __init__(self, _mask: np.ndarray, class_num: int, position: Tuple[int, int]):
self.cls = class_num

wrapped_mask = Mask(_mask)
segmentation_data = wrapped_mask.polygons().segmentation

assert len(segmentation_data) > 0, 'Empty polygon of class {}'.format(self.cls)

print(segmentation_data.shape)

segmentation_data = segmentation_data[0]
segmentation_data = max(segmentation_data, key=lambda _x: len(_x))

if len(segmentation_data) >= 7*5:
segmentation_array = np.array(
Expand All @@ -56,11 +56,11 @@ def __init__(self, _mask: np.ndarray, class_num: int):
segmentation_array = np.array(
segmentation_data).reshape((-1, 2)).astype(np.float32)

self.poly = Polygon(segmentation_array)
self.poly = Polygon(segmentation_array + position)

def estimate_iou(self, other) -> float:
intersection = self.poly.intersection(other)
union = self.poly.union(other)
intersection = self.poly.intersection(other.poly)
union = self.poly.union(other.poly)

return intersection.area / (union.area + 1E-5)

Expand All @@ -72,21 +72,21 @@ class DetectionsCarrier(object):
def merge_carriers(
scope_src: DetectionsCarrier,
scope_to_add: DetectionsCarrier,
match_threshold: float = 0.75):
pairwise_intersections = [
match_threshold: float = 0.45):
pairwise_intersections = np.array([
[
scope_to_add.detections[si].estimate_iou(scope_src.detections[sj])
for sj in range(len(scope_src.detections))
]
for si in range(len(scope_to_add.detections))
]
])

for si in range(len(scope_to_add.detections)):
search_idx = -1
best_iou = 0
diff_poly = Polygon(scope_to_add.detections[si].poly)

for sj in range(len(scope_src.detections)):
for sj in range(pairwise_intersections.shape[1]):
p_iou = pairwise_intersections[si][sj]
diff_poly = diff_poly.difference(scope_src.detections[sj].poly)

Expand All @@ -98,7 +98,7 @@ def merge_carriers(
if search_idx == -1:
scope_src.detections.append(scope_to_add.detections[si])
else:
scope_src.detections[search_idx].poly = scope_src.detections[search_idx].poly.union(scope_to_add.detections[si])
scope_src.detections[search_idx].poly = scope_src.detections[search_idx].poly.union(scope_to_add.detections[si].poly)


class ImageSegment(DetectionsCarrier):
Expand All @@ -114,7 +114,7 @@ def __init__(self,
masks, _, pred_classes = inference_function(self.get_crop())

self.detections = [
PolyDetection(masks[_mi], pred_classes[_mi])
PolyDetection(masks[_mi], pred_classes[_mi], position)
for _mi in range(len(masks))
]

Expand Down

0 comments on commit 97ddf28

Please sign in to comment.