Skip to content

Commit

Permalink
Add custom callbacks to model training
Browse files Browse the repository at this point in the history
Add an optional parameter for calling a list of keras.callbacks to be add to the original list.
  • Loading branch information
Nick authored and waleedka committed Jul 12, 2018
1 parent 3ba867e commit 23c82fd
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,7 +2272,7 @@ def set_log_dir(self, model_path=None):
"*epoch*", "{epoch:04d}")

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
augmentation=None):
augmentation=None, custom_callbacks=[]):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
Expand All @@ -2299,6 +2299,10 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
imgaug.augmenters.Fliplr(0.5),
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
"""
assert self.mode == "training", "Create model in training mode."

Expand Down Expand Up @@ -2330,6 +2334,9 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
keras.callbacks.ModelCheckpoint(self.checkpoint_path,
verbose=0, save_weights_only=True),
]

# Add custom callbacks to the list
callbacks+=custom_callbacks

# Train
log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
Expand Down

0 comments on commit 23c82fd

Please sign in to comment.