# Face alignment demo
# Uses MTCNN as the face detector
# Cunjian Chen ([email protected])
from __future__ import division
import argparse
import glob
import os
import time

import cv2
import numpy as np
import torch
from PIL import Image

from common.utils import BBox, drawLandmark, drawLandmark_multiple
from models.basenet import MobileNet_GDConv
from src import detect_faces
parser = argparse.ArgumentParser(description='PyTorch face landmark')
# Download the checkpoint from https://drive.google.com/file/d/1Le5UdpMkKOTRr1sTp4lwkw8263sbgdSe/view
parser.add_argument('-c', '--checkpoint',
                    default='checkpoint/mobilenet_224_model_best_gdconv_external.pth.tar',
                    type=str, metavar='PATH',
                    help='path to the checkpoint to load')
args = parser.parse_args()

# ImageNet normalization statistics; inputs are standardized channel-wise below
mean = np.asarray([0.485, 0.456, 0.406])
std = np.asarray([0.229, 0.224, 0.225])
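# For example, a mid-gray pixel [128, 128, 128] scaled to 128/255 ~= 0.502
# standardizes to roughly (0.502 - mean) / std ~= [0.074, 0.205, 0.426].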
# Load checkpoint tensors onto the GPU when available, otherwise onto the CPU
if torch.cuda.is_available():
    map_location = lambda storage, loc: storage.cuda()
else:
    map_location = 'cpu'
def load_model():
    # 136 outputs = 68 landmarks x (x, y)
    model = MobileNet_GDConv(136)
    # The checkpoint was saved from a DataParallel-wrapped model, so wrap before loading
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(args.checkpoint, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    return model
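# If you would rather load the weights without DataParallel, a minimal sketch
# (assuming the checkpoint keys only differ by the usual "module." prefix):
#
#   model = MobileNet_GDConv(136)
#   checkpoint = torch.load(args.checkpoint, map_location=map_location)
#   state = {k.replace('module.', '', 1): v for k, v in checkpoint['state_dict'].items()}
#   model.load_state_dict(state)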
if __name__ == '__main__':
    out_size = 224
    model = load_model()
    model = model.eval()
    # Make sure the output directory exists before writing results
    os.makedirs('results', exist_ok=True)
    filenames = glob.glob("samples/12--Group/*.jpg")
    for imgname in filenames:
        print(imgname)
        img = cv2.imread(imgname)
        height, width, _ = img.shape
        # Perform face detection using MTCNN (which expects a PIL image)
        image = Image.open(imgname)
        faces, landmarks = detect_faces(image)
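        # Note: detect_faces is assumed to return one row per face whose first
        # four entries are the box corners (x1, y1, x2, y2) plus a confidence
        # score, together with five MTCNN keypoints that this demo ignores.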
        if len(faces) == 0:
            print('NO face is detected!')
            continue
        for k, face in enumerate(faces):
            x1 = face[0]
            y1 = face[1]
            x2 = face[2]
            y2 = face[3]
            w = x2 - x1 + 1
            h = y2 - y1 + 1
            # Expand the detection into a square crop, 1.2x the shorter side,
            # centered on the original box
            size = int(min([w, h]) * 1.2)
            cx = x1 + w // 2
            cy = y1 + h // 2
            x1 = cx - size // 2
            x2 = x1 + size
            y1 = cy - size // 2
            y2 = y1 + size
            # Clip the crop to the image and record how far it ran past each edge
            dx = max(0, -x1)
            dy = max(0, -y1)
            x1 = max(0, x1)
            y1 = max(0, y1)
            edx = max(0, x2 - width)
            edy = max(0, y2 - height)
            x2 = min(width, x2)
            y2 = min(height, y2)
            # BBox takes [left, right, top, bottom]
            new_bbox = list(map(int, [x1, x2, y1, y2]))
            new_bbox = BBox(new_bbox)
            cropped = img[new_bbox.top:new_bbox.bottom, new_bbox.left:new_bbox.right]
            if dx > 0 or dy > 0 or edx > 0 or edy > 0:
                # Pad with black where the square crop extended past the image border
                cropped = cv2.copyMakeBorder(cropped, int(dy), int(edy), int(dx), int(edx),
                                             cv2.BORDER_CONSTANT, value=0)
            # Skip degenerate crops before resizing (cv2.resize fails on empty input)
            if cropped.shape[0] <= 0 or cropped.shape[1] <= 0:
                continue
            cropped_face = cv2.resize(cropped, (out_size, out_size))
            # Normalize to the training distribution and reshape HWC -> a NCHW batch of 1
            test_face = cropped_face.copy()
            test_face = test_face / 255.0
            test_face = (test_face - mean) / std
            test_face = test_face.transpose((2, 0, 1))
            test_face = test_face.reshape((1,) + test_face.shape)
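            # Since the network takes an NCHW batch, all faces in an image could
            # in principle be preprocessed first and run in one forward pass,
            # e.g. (with face_list holding the per-face arrays built above):
            #   batch = torch.from_numpy(np.concatenate(face_list)).float()
            #   landmarks = model(batch)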
            input = torch.from_numpy(test_face).float()
            start = time.time()
            # Inference only, so disable gradient tracking
            # (torch.autograd.Variable is deprecated and no longer needed)
            with torch.no_grad():
                landmark = model(input).cpu().numpy()
            end = time.time()
            print('Time: {:.6f}s.'.format(end - start))
            # Map the normalized landmark predictions back to original image coordinates
            landmark = landmark.reshape(-1, 2)
            landmark = new_bbox.reprojectLandmark(landmark)
            img = drawLandmark_multiple(img, new_bbox, landmark)
        cv2.imwrite(os.path.join('results', os.path.basename(imgname)), img)
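# Example invocation (assuming the checkpoint has been downloaded to checkpoint/
# and sample images exist under samples/12--Group/):
#   python test_batch_mtcnn.py -c checkpoint/mobilenet_224_model_best_gdconv_external.pth.tar
# Annotated images are written to results/.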