Skip to content

Commit

Permalink
Update predict_bbox.py
Browse files Browse the repository at this point in the history
1, added function reverse_resized_rect that resize rect to fit the original image size.
2, switched order of width and height in line 46/53
  • Loading branch information
Bernardo1998 authored Sep 20, 2020
1 parent 64d012b commit d1d3eb5
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions predict_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def rect_to_bb(rect):
# return a tuple of (x, y, w, h)
return (x, y, w, h)

def reverse_resized_rect(rect,resize_ratio):
l = int(rect.left() / resize_ratio)
t = int(rect.top() / resize_ratio)
r = int(rect.right() / resize_ratio)
b = int(rect.bottom() / resize_ratio)
new_rect = dlib.rectangle(l,t,r,b)

return [l,t,r,b] , new_rect

def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800, size = 300, padding = 0.25):
cnn_face_detector = dlib.cnn_face_detection_model_v1('dlib_models/mmod_human_face_detector.dat')
Expand All @@ -34,12 +42,16 @@ def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800, size = 30
print('---%d/%d---' %(index, len(image_paths)))

img = dlib.load_rgb_image(image_path)
old_width, old_height, _ = img.shape

old_height, old_width, _ = img.shape
if old_width > old_height:
new_width, new_height = default_max_size, int(default_max_size * old_height / old_width)
resize_ratio = default_max_size / old_width
new_width, new_height = default_max_size, int(old_height * resize_ratio)
else:
new_width, new_height = int(default_max_size * old_height / old_width), default_max_size
img = dlib.resize_image(img, new_width, new_height )
resize_ratio = default_max_size / old_height
new_width, new_height = int(old_width * resize_ratio), default_max_size
img = dlib.resize_image(img, cols=new_width, rows=new_height)

dets = cnn_face_detector(img, 1)
num_faces = len(dets)
if num_faces == 0:
Expand All @@ -51,13 +63,15 @@ def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800, size = 30
for detection in dets:
rect = detection.rect
faces.append(sp(img, rect))
rects.append(rect)
rect_tpl ,rect_in_origin = reverse_resized_rect(rect,resize_ratio)
rects.append(rect_in_origin)
images = dlib.get_face_chips(img, faces, size=size, padding = padding)
for idx, image in enumerate(images):
img_name = image_path.split("/")[-1]
path_sp = img_name.split(".")
face_name = os.path.join(SAVE_DETECTED_AT, path_sp[0] + "_" + "face" + str(idx) + "." + path_sp[-1])
dlib.save_image(image, face_name)
dlib.save_image(image, face_name)

return rects

def predidct_age_gender_race(save_prediction_at, bboxes, imgs_path = 'cropped_faces/'):
Expand All @@ -66,7 +80,8 @@ def predidct_age_gender_race(save_prediction_at, bboxes, imgs_path = 'cropped_fa

model_fair_7 = torchvision.models.resnet34(pretrained=True)
model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
model_fair_7.load_state_dict(torch.load('fair_face_models/res34_fair_align_multi_7_20190809.pt'))
model_fair_7.load_state_dict(torch.load('fair_face_models/fairface_alldata_20191111.pt'))
#model_fair_7.load_state_dict(torch.load('fair_face_models/res34_fair_align_multi_7_20190809.pt'))
model_fair_7 = model_fair_7.to(device)
model_fair_7.eval()

Expand Down Expand Up @@ -220,6 +235,5 @@ def ensure_dir(directory):
ensure_dir(SAVE_DETECTED_AT)
imgs = pd.read_csv(args.input_csv)['img_path']
bboxes = detect_face(imgs, SAVE_DETECTED_AT)
print(len(bboxes))
print("detected faces are saved at ", SAVE_DETECTED_AT)
predidct_age_gender_race("test_outputs.csv", bboxes, SAVE_DETECTED_AT)
predidct_age_gender_race("test_outputs.csv", bboxes, SAVE_DETECTED_AT)

0 comments on commit d1d3eb5

Please sign in to comment.