Skip to content

Commit

Permalink
fixup 60703fd: more flexible no of workers
Browse files Browse the repository at this point in the history
  • Loading branch information
bertsky committed May 3, 2021
1 parent a02c408 commit 8c01479
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2529,7 +2529,9 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
if os.name is 'nt':
workers = 0
else:
workers = multiprocessing.cpu_count()
#workers = multiprocessing.cpu_count()
# prevent oversubscription on clusters:
workers = max(3, self.config.BATCH_SIZE)

self.keras_model.fit_generator(
train_generator,
Expand All @@ -2540,7 +2542,7 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
validation_data=val_generator,
validation_steps=self.config.VALIDATION_STEPS,
max_queue_size=100,
workers=2, # workers
workers=workers,
use_multiprocessing=False, # True
)
self.epoch = max(self.epoch, epochs)
Expand All @@ -2556,11 +2558,20 @@ def evaluate(self, val_dataset):
self.compile(self.config.LEARNING_RATE, self.config.LEARNING_MOMENTUM)
val_generator = DataGenerator(val_dataset, self.config, shuffle=True,
batch_size=self.config.BATCH_SIZE)
# Work-around for Windows: Keras fails on Windows when using
# multiprocessing workers. See discussion here:
# https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
if os.name is 'nt':
workers = 0
else:
#workers = multiprocessing.cpu_count()
# prevent oversubscription on clusters:
workers = max(3, self.config.BATCH_SIZE)
return self.keras_model.evaluate_generator(
val_generator,
steps=self.config.VALIDATION_STEPS,
max_queue_size=100,
workers=2, # workers
workers=workers,
use_multiprocessing=False, # True
)

Expand Down

0 comments on commit 8c01479

Please sign in to comment.