Skip to content

Commit

Permalink
Add --logs argument to coco.py
Browse files Browse the repository at this point in the history
  • Loading branch information
waleedka committed Nov 14, 2017
1 parent 3361838 commit 059ce9d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
# Path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
# Directory to save logs and model checkpoints, if not provided
# through the command line argument --logs
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")


############################################################
Expand Down Expand Up @@ -321,10 +322,15 @@ def evaluate_coco(model, dataset, coco, eval_type="bbox", limit=0):
parser.add_argument('--model', required=True,
metavar="/path/to/weights.h5",
help="Path to weights .h5 file or 'coco'")
parser.add_argument('--logs', required=False,
default=DEFAULT_LOGS_DIR,
metavar="/path/to/logs/",
help='Directory to save logs and checkpoints. Defaults to logs/')
args = parser.parse_args()
print("Command: ", args.command)
print("Model: ", args.model)
print("Dataset: ", args.dataset)
print("Logs: ", args.logs)

# Configurations
if args.command == "train":
Expand All @@ -341,10 +347,10 @@ class InferenceConfig(CocoConfig):
# Create model
if args.command == "train":
model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)
model_dir=args.logs)
else:
model = modellib.MaskRCNN(mode="inference", config=config,
model_dir=MODEL_DIR)
model_dir=args.logs)

# Select weights file to load
if args.model.lower() == "coco":
Expand Down

0 comments on commit 059ce9d

Please sign in to comment.