From 5202a02dfef3c85fca1fa597db46abfb16e12ad8 Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Jun 2018 11:42:49 +0200 Subject: [PATCH] Add no augmentation sources Add the possibility to exclude some sources from augmentation by passing a list of sources. This is useful when you want to retrain a model having few images. --- mrcnn/model.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/mrcnn/model.py b/mrcnn/model.py index cbd490c4e3..dc6153be6a 100644 --- a/mrcnn/model.py +++ b/mrcnn/model.py @@ -1636,7 +1636,8 @@ def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes): def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None, - random_rois=0, batch_size=1, detection_targets=False): + random_rois=0, batch_size=1, detection_targets=False, + no_augmentation_sources=[]): """A generator that returns images and corresponding target class ids, bounding box deltas, and masks. @@ -1673,6 +1674,8 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No outputs list: Usually empty in regular training. But if detection_targets is True then the outputs list contains target class_ids, bbox deltas, and masks. + + no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation """ b = 0 # batch item index image_index = -1 @@ -1698,10 +1701,18 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No # Get GT bounding boxes and masks for image. image_id = image_ids[image_index] - image, image_meta, gt_class_ids, gt_boxes, gt_masks = \ + + # If the image source is not to be augmented pass None as augmentation + if dataset.image_info[image_id]['source'] in no_augmentation_sources: + image, image_meta, gt_class_ids, gt_boxes, gt_masks = \ load_image_gt(dataset, config, image_id, augment=augment, - augmentation=augmentation, + augmentation=None, use_mini_mask=config.USE_MINI_MASK) + else: + image, image_meta, gt_class_ids, gt_boxes, gt_masks = \ + load_image_gt(dataset, config, image_id, augment=augment, + augmentation=augmentation, + use_mini_mask=config.USE_MINI_MASK) # Skip images that have no instances. This can happen in cases # where we train on a subset of classes and the image doesn't @@ -2272,7 +2283,7 @@ def set_log_dir(self, model_path=None): "*epoch*", "{epoch:04d}") def train(self, train_dataset, val_dataset, learning_rate, epochs, layers, - augmentation=None, custom_callbacks=[]): + augmentation=None, custom_callbacks=[], no_augmentation_sources=[]): """Train the model. train_dataset, val_dataset: Training and validation Dataset objects. learning_rate: The learning rate to train with @@ -2299,8 +2310,10 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers, imgaug.augmenters.Fliplr(0.5), imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0)) ]) - custom_callbacks: (list) Optional. Add custom callbacks to be called - with the keras fit_generator method. Must be list of type keras.callbacks. + custom_callbacks: (list) Optional. Add custom callbacks to be called + with the keras fit_generator method. Must be list of type keras.callbacks. + + no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation """ @@ -2323,9 +2336,11 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers, # Data generators train_generator = data_generator(train_dataset, self.config, shuffle=True, augmentation=augmentation, - batch_size=self.config.BATCH_SIZE) + batch_size=self.config.BATCH_SIZE, + no_augmentation_sources=no_augmentation_sources) val_generator = data_generator(val_dataset, self.config, shuffle=True, - batch_size=self.config.BATCH_SIZE) + batch_size=self.config.BATCH_SIZE, + no_augmentation_sources=no_augmentation_sources) # Callbacks callbacks = [