Skip to content

Commit

Permalink
Merge pull request fizyr#246 from fizyr/fix-transforms
Browse files Browse the repository at this point in the history
Fix applying transforms to images.
  • Loading branch information
de-vri-es authored Jan 31, 2018
2 parents ed7a00a + 9859954 commit 8fca81f
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 52 deletions.
52 changes: 31 additions & 21 deletions keras_retinanet/bin/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@
def create_generator(args):
# create random transform generator for augmenting training data
transform_generator = random_transform_generator(
# min_rotation=-0.1,
# max_rotation=0.1,
# min_translation=(-0.1, -0.1),
# max_translation=(0.1, 0.1),
# min_shear=-0.1,
# max_shear=0.1,
# min_scaling=(0.9, 0.9),
# max_scaling=(1.1, 1.1),
min_rotation=-0.1,
max_rotation=0.1,
min_translation=(-0.1, -0.1),
max_translation=(0.1, 0.1),
min_shear=-0.1,
max_shear=0.1,
min_scaling=(0.9, 0.9),
max_scaling=(1.1, 1.1),
flip_x_chance=0.5,
flip_y_chance=0.5,
)
Expand Down Expand Up @@ -93,25 +93,15 @@ def parse_args(args):
csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.')
csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')

parser.add_argument('-l', '--loop', help='Loop forever, even if the dataset is exhausted.', action='store_true')
parser.add_argument('--no-resize', help='Disable image resizing.', dest='resize', action='store_false')
parser.add_argument('--annotations', help='Show annotations on the image.', action='store_true')
parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')

return parser.parse_args(args)


def main(args=None):
# parse arguments
if args is None:
args = sys.argv[1:]
args = parse_args(args)

# create the generator
generator = create_generator(args)

# create the display window
cv2.namedWindow('Image', cv2.WINDOW_NORMAL)

def run(generator, args):
# display images, one at a time
for i in range(generator.size()):
# load the data
Expand All @@ -134,7 +124,27 @@ def main(args=None):

cv2.imshow('Image', image)
if cv2.waitKey() == ord('q'):
break
return False
return True


def main(args=None):
# parse arguments
if args is None:
args = sys.argv[1:]
args = parse_args(args)

# create the generator
generator = create_generator(args)

# create the display window
cv2.namedWindow('Image', cv2.WINDOW_NORMAL)

if args.loop:
while run(generator, args):
pass
else:
run(generator, args)

if __name__ == '__main__':
main()
2 changes: 1 addition & 1 deletion keras_retinanet/preprocessing/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def random_transform_group_entry(self, image, annotations):
# randomly transform both image and annotations
if self.transform_generator:
transform = adjust_transform_for_image(next(self.transform_generator), image, self.transform_parameters.relative_translation)
image = np.swapaxes(apply_transform(transform, np.swapaxes(image, 0, 1), self.transform_parameters), 0, 1)
image = apply_transform(transform, image, self.transform_parameters)

# Transform the bounding boxes in the annotations.
annotations = annotations.copy()
Expand Down
67 changes: 56 additions & 11 deletions keras_retinanet/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

from __future__ import division
import keras
import keras.preprocessing.image
import time
import numpy as np
import scipy.ndimage as ndi
import cv2
import PIL

from .transform import change_transform_origin, transform_aabb, colvec
from .transform import change_transform_origin, transform_aabb


def read_image_bgr(path):
Expand Down Expand Up @@ -75,21 +75,24 @@ class TransformParameters:
""" Struct holding parameters determining how to apply a transformation to an image.
# Arguments
fill_mode: Same as for keras.preprocessing.image.apply_transform
cval: Same as for keras.preprocessing.image.apply_transform
fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap'
interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4'
cval: Fill value to use with fill_mode='constant'
data_format: Same as for keras.preprocessing.image.apply_transform
relative_translation: If true (the default), interpret translation as a factor of the image size.
If false, interpret it as absolute pixels.
"""
def __init__(
self,
fill_mode = 'nearest',
interpolation = 'linear',
cval = 0,
data_format = None,
relative_translation = True,
):
self.fill_mode = fill_mode
self.cval = cval
self.interpolation = interpolation
self.relative_translation = relative_translation

if data_format is None:
Expand All @@ -103,17 +106,59 @@ def __init__(
else:
raise ValueError("invalid data_format, expected 'channels_first' or 'channels_last', got '{}'".format(data_format))

def cvBorderMode(self):
if self.fill_mode == 'constant':
return cv2.BORDER_CONSTANT
if self.fill_mode == 'nearest':
return cv2.BORDER_REPLICATE
if self.fill_mode == 'reflect':
return cv2.BORDER_REFLECT_101
if self.fill_mode == 'wrap':
return cv2.BORDER_WRAP

def cvInterpolation(self):
if self.interpolation == 'nearest':
return cv2.INTER_NEAREST
if self.interpolation == 'linear':
return cv2.INTER_LINEAR
if self.interpolation == 'cubic':
return cv2.INTER_CUBIC
if self.interpolation == 'area':
return cv2.INTER_AREA
if self.interpolation == 'lanczos4':
return cv2.INTER_LANCZOS4


def apply_transform(matrix, image, params):
"""
Apply a transformation to an image.
The origin of transformation is at the top left corner of the image.
The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image.
Mathematically speaking, that means that the matrix is a transformation from the transformed image space to the original image space.
def apply_transform(transform, image, params):
""" Wrapper around keras.preprocessing.image.apply_transform using TransformParameters. """
return keras.preprocessing.image.apply_transform(
Parameters:
matrix: A homogenous 3 by 3 matrix holding representing the transformation to apply.
image: The image to transform.
params: The transform parameters (see TransformParameters)
"""
if params.channel_axis != 2:
image = np.moveaxis(image, params.channel_axis, 2)

output = cv2.warpAffine(
image,
transform,
channel_axis = params.channel_axis,
fill_mode = params.fill_mode,
cval = params.cval
matrix[:2, :],
dsize = (image.shape[1], image.shape[0]),
flags = params.cvInterpolation(),
borderMode = params.cvBorderMode(),
borderValue = params.cval,
)

if params.channel_axis != 2:
output = np.moveaxis(output, 2, params.channel_axis)
return output


def resize_image(img, min_side=600, max_side=1024):
(rows, cols, _) = img.shape
Expand Down
36 changes: 18 additions & 18 deletions keras_retinanet/utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,26 +171,27 @@ def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG):


def change_transform_origin(transform, center):
""" Create a new transform with the origin at a different location.
""" Create a new transform representing the same transformation,
only with the origin of the linear part changed.
# Arguments:
transform: the transformation matrix
center: the new origin of the transformation
# Return:
translate(center) * transform * translate(-center)
"""
center = np.array(center)
return np.dot(np.dot(translation(center), transform), translation(-center))
return np.linalg.multi_dot([translation(center), transform, translation(-center)])


def random_transform(
# min_rotation=0,
# max_rotation=0,
# min_translation=(0, 0),
# max_translation=(0, 0),
# min_shear=0,
# max_shear=0,
# min_scaling=(1, 1),
# max_scaling=(1, 1),
min_rotation=0,
max_rotation=0,
min_translation=(0, 0),
max_translation=(0, 0),
min_shear=0,
max_shear=0,
min_scaling=(1, 1),
max_scaling=(1, 1),
flip_x_chance=0,
flip_y_chance=0,
prng=DEFAULT_PRNG
Expand Down Expand Up @@ -223,14 +224,13 @@ def random_transform(
flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
prng: The pseudo-random number generator to use.
"""
# return np.linalg.multi_dot([
# random_rotation(min_rotation, max_rotation, prng),
# random_translation(min_translation, max_translation, prng),
# random_shear(min_shear, max_shear, prng),
# random_scaling(min_scaling, max_scaling, prng),
# random_flip(flip_x_chance, flip_y_chance, prng)
# ])
return random_flip(flip_x_chance, flip_y_chance, prng)
return np.linalg.multi_dot([
random_rotation(min_rotation, max_rotation, prng),
random_translation(min_translation, max_translation, prng),
random_shear(min_shear, max_shear, prng),
random_scaling(min_scaling, max_scaling, prng),
random_flip(flip_x_chance, flip_y_chance, prng)
])


def random_transform_generator(prng=None, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
maintainer='Hans Gaiser',
maintainer_email='[email protected]',
packages=setuptools.find_packages(),
install_requires=['keras', 'keras-resnet', 'six'],
install_requires=['keras', 'keras-resnet', 'six', 'scipy'],
entry_points = {
'console_scripts': [
'retinanet-train=keras_retinanet.bin.train:main',
Expand Down

0 comments on commit 8fca81f

Please sign in to comment.