Skip to content

Commit

Permalink
DataGenerator: shuffle on first 'epoch', too
Browse files Browse the repository at this point in the history
  • Loading branch information
bertsky committed May 17, 2021
1 parent 22d5a1f commit 90e3f01
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,9 +1333,14 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
# Active classes
# Different datasets have different classes, so track the
# classes supported in the dataset of this image.
source = dataset.image_info[image_id]["source"]
source_class_ids = dataset.source_class_ids[source]
active_class_ids = np.zeros([config.NUM_CLASSES], dtype=np.int32)
source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
active_class_ids[source_class_ids] = 1
assert np.all(active_class_ids[class_ids]), \
"Image {} ('{}') has annotations for classes not known for source '{}': {}".format(
image_id, dataset.image_reference(image_id), source,
set(class_ids).difference(source_class_ids))

# Resize masks to smaller size to reduce memory usage
if use_mini_mask:
Expand Down Expand Up @@ -1746,7 +1751,7 @@ def __init__(self, dataset, config, shuffle=True, augment=False, augmentation=No
random_rois=0, batch_size=1, detection_targets=False,
no_augmentation_sources=None):

self.image_ids = np.copy(dataset.image_ids)
self.image_ids = np.copy(dataset.image_ids) # shuffled
self.dataset = dataset
self.config = config
self.error_count = 0
Expand All @@ -1767,8 +1772,7 @@ def __init__(self, dataset, config, shuffle=True, augment=False, augmentation=No
self.batch_size = batch_size
self.detection_targets = detection_targets
self.no_augmentation_sources = no_augmentation_sources or []


self.on_epoch_end()

def __len__(self):
return int(np.ceil(len(self.image_ids) / float(self.batch_size)))
Expand Down Expand Up @@ -1873,7 +1877,7 @@ def data_generator(self,image_ids):
if self.error_count > 5:
raise

# Batch full?
# Batch not empty?
if b > 0:
inputs = [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox,
batch_gt_class_ids, batch_gt_boxes, batch_gt_masks]
Expand Down Expand Up @@ -1974,7 +1978,7 @@ def data_generator(self,image_ids):
if self.error_count > 5:
raise

# Batch full?
# Batch not empty?
if b > 0:
return [batch_images, batch_image_meta, batch_anchors]

Expand Down Expand Up @@ -2583,11 +2587,11 @@ def detect(self, images, verbose=0, active_class_ids=None, callbacks=None):
Boolean matrix [images, classes].
Returns an in-order list of dicts, one dict per image. The dict contains:
image_id: Integer image identifier as provided by the generator's image_metas output.
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
class_ids: [N] int class IDs
scores: [N] float probability scores for the class IDs
masks: [H, W, N] instance binary masks
- 'image_id': Integer image identifier as provided by the generator's image_metas output.
- 'rois': [N, (y1, x1, y2, x2)] detection bounding boxes
- 'class_ids': [N] int class IDs
- 'scores': [N] float probability scores for the class IDs
- 'masks': [H, W, N] instance binary masks
"""
assert self.mode == "inference", "Create model in inference mode."
if verbose:
Expand Down Expand Up @@ -2639,11 +2643,11 @@ def detect_molded(self, molded_images, image_metas, verbose=0, callbacks=None):
image_metas: image meta data, also returned by load_image_gt()
Returns an in-order list of dicts, one dict per image. The dict contains:
image_id: Integer image identifier as provided by the generator's image_metas output.
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
class_ids: [N] int class IDs
scores: [N] float probability scores for the class IDs
masks: [H, W, N] instance binary masks
- 'image_id': Integer image identifier as provided by the generator's image_metas output.
- 'rois': [N, (y1, x1, y2, x2)] detection bounding boxes
- 'class_ids': [N] int class IDs
- 'scores': [N] float probability scores for the class IDs
- 'masks': [H, W, N] instance binary masks
"""
assert self.mode == "inference", "Create model in inference mode."
if verbose:
Expand Down Expand Up @@ -2690,11 +2694,11 @@ def detect_generator(self, batches, callbacks=None, max_queue_size=10, workers=1
(See keras.Model.predict_generator() for other arguments.)
Returns an in-order list of dicts, one dict per image. The dict contains:
image_id: Integer image identifier as provided by the generator's image_metas output.
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
class_ids: [N] int class IDs
scores: [N] float probability scores for the class IDs
masks: [H, W, N] instance binary masks
- 'image_id': Integer image identifier as provided by the generator's image_metas output.
- 'rois': [N, (y1, x1, y2, x2)] detection bounding boxes
- 'class_ids': [N] int class IDs
- 'scores': [N] float probability scores for the class IDs
- 'masks': [H, W, N] instance binary masks
"""
assert self.mode == "inference", "Create model in inference mode."
assert KE.training_utils.is_sequence(batches), \
Expand Down

0 comments on commit 90e3f01

Please sign in to comment.