Skip to content

Commit

Permalink
Raise clear error if last training weights are not foundIf using the …
Browse files Browse the repository at this point in the history
…--weights=last (or --model=last) to resume trainingbut the weights are not found now it raises a clear error message.
  • Loading branch information
waleedka committed Jun 6, 2018
1 parent a688a66 commit cbff80f
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 16 deletions.
14 changes: 9 additions & 5 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,26 +2062,30 @@ def find_last(self):
"""Finds the last checkpoint file of the last trained model in the
model directory.
Returns:
log_dir: The directory where events and weights are saved
checkpoint_path: the path to the last checkpoint file
The path of the last checkpoint file
"""
# Get directory names. Each directory corresponds to a model
dir_names = next(os.walk(self.model_dir))[1]
key = self.config.NAME.lower()
dir_names = filter(lambda f: f.startswith(key), dir_names)
dir_names = sorted(dir_names)
if not dir_names:
return None, None
import errno
raise FileNotFoundError(
errno.ENOENT,
"Could not find model directory under {}".format(self.model_dir))
# Pick last directory
dir_name = os.path.join(self.model_dir, dir_names[-1])
# Find the last checkpoint
checkpoints = next(os.walk(dir_name))[2]
checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints)
checkpoints = sorted(checkpoints)
if not checkpoints:
return dir_name, None
import errno
raise FileNotFoundError(
errno.ENOENT, "Could not find weight files in {}".format(dir_name))
checkpoint = os.path.join(dir_name, checkpoints[-1])
return dir_name, checkpoint
return checkpoint

def load_weights(self, filepath, by_name=False, exclude=None):
"""Modified version of the correspoding Keras function with
Expand Down
2 changes: 1 addition & 1 deletion samples/balloon/balloon.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ class InferenceConfig(BalloonConfig):
utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last":
# Find last trained weights
weights_path = model.find_last()[1]
weights_path = model.find_last()
elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights
weights_path = model.get_imagenet_weights()
Expand Down
2 changes: 1 addition & 1 deletion samples/balloon/inspect_balloon_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@
"# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n",
"\n",
"# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n",
"weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
Expand Down
2 changes: 1 addition & 1 deletion samples/coco/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ class InferenceConfig(CocoConfig):
model_path = COCO_MODEL_PATH
elif args.model.lower() == "last":
# Find last trained weights
model_path = model.find_last()[1]
model_path = model.find_last()
elif args.model.lower() == "imagenet":
# Start from ImageNet trained weights
model_path = model.get_imagenet_weights()
Expand Down
2 changes: 1 addition & 1 deletion samples/coco/inspect_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@
"elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n",
"# weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
Expand Down
2 changes: 1 addition & 1 deletion samples/coco/inspect_weights.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n",
"# weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
Expand Down
2 changes: 1 addition & 1 deletion samples/nucleus/inspect_nucleus_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@
"# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n",
"\n",
"# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n",
"weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
Expand Down
2 changes: 1 addition & 1 deletion samples/nucleus/nucleus.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def detect(model, dataset_dir, subset):
utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last":
# Find last trained weights
weights_path = model.find_last()[1]
weights_path = model.find_last()
elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights
weights_path = model.get_imagenet_weights()
Expand Down
7 changes: 3 additions & 4 deletions samples/shapes/train_shapes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@
" \"mrcnn_bbox\", \"mrcnn_mask\"])\n",
"elif init_with == \"last\":\n",
" # Load the last model you trained and continue training\n",
" model.load_weights(model.find_last()[1], by_name=True)"
" model.load_weights(model.find_last(), by_name=True)"
]
},
{
Expand Down Expand Up @@ -875,10 +875,9 @@
"# Get path to saved weights\n",
"# Either set a specific path or find last trained weights\n",
"# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n",
"model_path = model.find_last()[1]\n",
"model_path = model.find_last()\n",
"\n",
"# Load trained weights (fill in path to trained weights here)\n",
"assert model_path != \"\", \"Provide path to trained weights\"\n",
"# Load trained weights\n",
"print(\"Loading weights from \", model_path)\n",
"model.load_weights(model_path, by_name=True)"
]
Expand Down

0 comments on commit cbff80f

Please sign in to comment.