Skip to content

Commit

Permalink
add options to demo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Jul 21, 2017
1 parent 0d8faa2 commit cfa865a
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@
import cv2
import dlib
import numpy as np
import argparse
from wide_resnet import WideResNet


def get_args():
parser = argparse.ArgumentParser(description="This script trains the CNN model for age and gender estimation.")
parser.add_argument("--weight_file", type=str, default=None,
help="path to weight file (e.g. weights.18-4.06.hdf5)")
parser.add_argument("--depth", type=int, default=16,
help="depth of network")
parser.add_argument("--width", type=int, default=8,
help="width of network")
args = parser.parse_args()
return args


def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,
font_scale=1, thickness=2):
size = cv2.getTextSize(label, font, font_scale, thickness)[0]
Expand All @@ -14,13 +27,21 @@ def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,


def main():
args = get_args()
depth = args.depth
k = args.width
weight_file = args.weight_file

if not weight_file:
weight_file = os.path.join("pretrained_models", "weights.18-4.06.hdf5")

# for face detection
detector = dlib.get_frontal_face_detector()

# load model and weights
img_size = 64
model = WideResNet(img_size, depth=16, k=8)()
model.load_weights(os.path.join("pretrained_models", "weights.18-4.06.hdf5"))
model = WideResNet(img_size, depth=depth, k=k)()
model.load_weights(weight_file)

# capture video
cap = cv2.VideoCapture(0)
Expand Down

0 comments on commit cfa865a

Please sign in to comment.