Skip to content

Commit

Permalink
A wrapper for skimage resize() to avoid warnings
Browse files Browse the repository at this point in the history
skimage generates different warnings depending on the version. This wrapper function calls skimage.tranform.resize() with the right parameter for each version.
  • Loading branch information
waleedka committed Sep 28, 2018
1 parent a5d80c7 commit 64020bd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
9 changes: 3 additions & 6 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from collections import OrderedDict
import multiprocessing
import numpy as np
import skimage.transform
import tensorflow as tf
import keras
import keras.backend as K
Expand Down Expand Up @@ -1071,8 +1070,7 @@ def rpn_bbox_loss_graph(config, target_bbox, rpn_match, rpn_bbox):
target_bbox = batch_pack_graph(target_bbox, batch_counts,
config.IMAGES_PER_GPU)

#use smooth_l1_loss() rather than reimplementing here to reduce code duplication
loss = smooth_l1_loss(target_bbox,rpn_bbox)
loss = smooth_l1_loss(target_bbox, rpn_bbox)

loss = K.switch(tf.size(loss) > 0, K.mean(loss), tf.constant(0.0))
return loss
Expand Down Expand Up @@ -1434,15 +1432,14 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
gt_h = gt_y2 - gt_y1
# Resize mini mask to size of GT box
placeholder[gt_y1:gt_y2, gt_x1:gt_x2] = \
np.round(skimage.transform.resize(
class_mask, (gt_h, gt_w), order=1, mode="constant")).astype(bool)
np.round(utils.resize(class_mask, (gt_h, gt_w))).astype(bool)
# Place the mini batch in the placeholder
class_mask = placeholder

# Pick part of the mask and resize it
y1, x1, y2, x2 = rois[i].astype(np.int32)
m = class_mask[y1:y2, x1:x2]
mask = skimage.transform.resize(m, config.MASK_SHAPE, order=1, mode="constant")
mask = utils.resize(m, config.MASK_SHAPE)
masks[i, :, :, class_id] = mask

return rois, roi_gt_class_ids, bboxes, masks
Expand Down
36 changes: 30 additions & 6 deletions mrcnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import urllib.request
import shutil
import warnings
from distutils.version import LooseVersion

# URL from which to download the latest COCO trained weights
COCO_MODEL_URL = "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
Expand Down Expand Up @@ -452,9 +453,8 @@ def resize_image(image, min_dim=None, max_dim=None, min_scale=None, mode="square

# Resize image using bilinear interpolation
if scale != 1:
image = skimage.transform.resize(
image, (round(h * scale), round(w * scale)),
order=1, mode="constant", preserve_range=True)
image = resize(image, (round(h * scale), round(w * scale)),
preserve_range=True)

# Need padding or cropping?
if mode == "square":
Expand Down Expand Up @@ -538,7 +538,7 @@ def minimize_mask(bbox, mask, mini_shape):
if m.size == 0:
raise Exception("Invalid bounding box with area of zero")
# Resize with bilinear interpolation
m = skimage.transform.resize(m, mini_shape, order=1, mode="constant")
m = resize(m, mini_shape)
mini_mask[:, :, i] = np.around(m).astype(np.bool)
return mini_mask

Expand All @@ -556,7 +556,7 @@ def expand_mask(bbox, mini_mask, image_shape):
h = y2 - y1
w = x2 - x1
# Resize with bilinear interpolation
m = skimage.transform.resize(m, (h, w), order=1, mode="constant")
m = resize(m, (h, w))
mask[y1:y2, x1:x2, i] = np.around(m).astype(np.bool)
return mask

Expand All @@ -576,7 +576,7 @@ def unmold_mask(mask, bbox, image_shape):
"""
threshold = 0.5
y1, x1, y2, x2 = bbox
mask = skimage.transform.resize(mask, (y2 - y1, x2 - x1), order=1, mode="constant")
mask = resize(mask, (y2 - y1, x2 - x1))
mask = np.where(mask >= threshold, 1, 0).astype(np.bool)

# Put the mask in the right location.
Expand Down Expand Up @@ -891,3 +891,27 @@ def denorm_boxes(boxes, shape):
scale = np.array([h - 1, w - 1, h - 1, w - 1])
shift = np.array([0, 0, 1, 1])
return np.around(np.multiply(boxes, scale) + shift).astype(np.int32)


def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,
preserve_range=False, anti_aliasing=False, anti_aliasing_sigma=None):
"""A wrapper for Scikit-Image resize().
Scikit-Image generates warnings on every call to resize() if it doesn't
receive the right parameters. The right parameters depend on the version
of skimage. This solves the problem by using different parameters per
version. And it provides a central place to control resizing defaults.
"""
if LooseVersion(skimage.__version__) >= LooseVersion("0.14"):
# New in 0.14: anti_aliasing. Default it to False for backward
# compatibility with skimage 0.13.
return skimage.transform.resize(
image, output_shape,
order=order, mode=mode, cval=cval, clip=clip,
preserve_range=preserve_range, anti_aliasing=anti_aliasing,
anti_aliasing_sigma=anti_aliasing_sigma)
else:
return skimage.transform.resize(
image, output_shape,
order=order, mode=mode, cval=cval, clip=clip,
preserve_range=preserve_range)

0 comments on commit 64020bd

Please sign in to comment.