Skip to content

Commit

Permalink
Add no augmentation sources
Browse files Browse the repository at this point in the history
Add the possibility to exclude some sources from augmentation by passing a list of sources. This is useful when you want to retrain a model having few images.
  • Loading branch information
Nick authored and waleedka committed Jul 12, 2018
1 parent 23c82fd commit 5202a02
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,8 @@ def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes):


def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None,
random_rois=0, batch_size=1, detection_targets=False):
random_rois=0, batch_size=1, detection_targets=False,
no_augmentation_sources=[]):
"""A generator that returns images and corresponding target class ids,
bounding box deltas, and masks.
Expand Down Expand Up @@ -1673,6 +1674,8 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
outputs list: Usually empty in regular training. But if detection_targets
is True then the outputs list contains target class_ids, bbox deltas,
and masks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
"""
b = 0 # batch item index
image_index = -1
Expand All @@ -1698,10 +1701,18 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No

# Get GT bounding boxes and masks for image.
image_id = image_ids[image_index]
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \

# If the image source is not to be augmented pass None as augmentation
if dataset.image_info[image_id]['source'] in no_augmentation_sources:
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
load_image_gt(dataset, config, image_id, augment=augment,
augmentation=augmentation,
augmentation=None,
use_mini_mask=config.USE_MINI_MASK)
else:
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
load_image_gt(dataset, config, image_id, augment=augment,
augmentation=augmentation,
use_mini_mask=config.USE_MINI_MASK)

# Skip images that have no instances. This can happen in cases
# where we train on a subset of classes and the image doesn't
Expand Down Expand Up @@ -2272,7 +2283,7 @@ def set_log_dir(self, model_path=None):
"*epoch*", "{epoch:04d}")

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
augmentation=None, custom_callbacks=[]):
augmentation=None, custom_callbacks=[], no_augmentation_sources=[]):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
Expand All @@ -2299,8 +2310,10 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
imgaug.augmenters.Fliplr(0.5),
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
"""
Expand All @@ -2323,9 +2336,11 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
# Data generators
train_generator = data_generator(train_dataset, self.config, shuffle=True,
augmentation=augmentation,
batch_size=self.config.BATCH_SIZE)
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
val_generator = data_generator(val_dataset, self.config, shuffle=True,
batch_size=self.config.BATCH_SIZE)
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)

# Callbacks
callbacks = [
Expand Down

0 comments on commit 5202a02

Please sign in to comment.