Skip to content

Commit

Permalink
detect hand
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzzone committed Nov 14, 2018
1 parent 05d4b92 commit 6e3952c
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 27 deletions.
34 changes: 34 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
sys.path.insert(0, 'python')
import cv2
import model
import util
from hand import Hand
from body import Body
import matplotlib.pyplot as plt
import copy

body_estimation = Body('model/body_pose_model.pth')
hand_estimation = Hand('model/hand_pose_model.pth')

test_image = 'images/demo.jpg'
oriImg = cv2.imread(test_image) # B,G,R order
candidate, subset = body_estimation(oriImg)
# canvas = util.draw_bodypose(oriImg, candidate, subset)
canvas = copy.deepcopy(oriImg)
# detect hand
hands_list = util.handDetect(candidate, subset, oriImg)

for x, y, w, is_left in hands_list:
cv2.rectangle(canvas, (x, y), (x+w, y+w), (0, 255, 0), 2, lineType=cv2.LINE_AA)
cv2.putText(canvas, 'left' if is_left else 'right', (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

if is_left:
plt.imshow(oriImg[y:y+w, x:x+w, :][:, :, [2, 1, 0]])
plt.show()
# peaks = hand_estimation(oriImg[y:y+w, x:x+w, :])
# canvas = util.draw_handpose(canvas, peaks, True)

plt.imshow(canvas[:, :, [2, 1, 0]])
plt.show()
# cv2.imwrite('t.jpg', canvas)
Binary file added images/demo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/detect_hand_preview.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
224 changes: 224 additions & 0 deletions notebooks/detectHand.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions python/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ def __call__(self, oriImg):
return candidate, subset

if __name__ == "__main__":
body_estimation = Body('model/body_pose_model.pth')
body_estimation = Body('../model/body_pose_model.pth')

test_image = 'images/ski.jpg'
test_image = '../images/ski.jpg'
oriImg = cv2.imread(test_image) # B,G,R order
candidate, subset = body_estimation(oriImg)
util.draw_bodypose(oriImg, candidate, subset)
canvas = util.draw_bodypose(oriImg, candidate, subset)
plt.imshow(canvas[:, :, [2, 1, 0]])
plt.show()
15 changes: 9 additions & 6 deletions python/hand.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,24 @@ def __call__(self, oriImg):
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
# 全部小于阈值
if np.sum(binary) == 0:
all_peaks.append(-1)
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0

y, x = util.npmax(map_ori)
all_peaks.append((x, y))
return all_peaks
all_peaks.append([x, y])
return np.array(all_peaks)

if __name__ == "__main__":
hand_estimation = Hand('../model/body_pose_model.pth')
hand_estimation = Hand('../model/hand_pose_model.pth')

test_image = '../images/hand.jpg'
# test_image = '../images/hand.jpg'
test_image = '/Users/hzzone/Desktop/1.png'
oriImg = cv2.imread(test_image) # B,G,R order
peaks = hand_estimation(oriImg)
util.draw_handpose(oriImg, peaks)
canvas = util.draw_handpose(oriImg, peaks, True)
cv2.imshow('', canvas)
cv2.waitKey(0)
4 changes: 2 additions & 2 deletions python/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class bodypose_model(nn.Module):
def __init__(self):
super(bodypose_model, self).__init__()

# these layers has no relu layer
# these layers have no relu layer
no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
Expand Down Expand Up @@ -136,7 +136,7 @@ class handpose_model(nn.Module):
def __init__(self):
super(handpose_model, self).__init__()

# these layers has no relu layer
# these layers have no relu layer
no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
# stage 1
Expand Down
96 changes: 80 additions & 16 deletions python/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import math
import cv2
import matplotlib.pyplot as plt
import matplotlib


def padRightDownCorner(img, stride, padValue):
Expand Down Expand Up @@ -49,7 +47,6 @@ def draw_bodypose(canvas, candidate, subset):
index = int(subset[n][i])
if index == -1:
continue
print(candidate[index][0:2])
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
for i in range(17):
Expand All @@ -68,24 +65,91 @@ def draw_bodypose(canvas, candidate, subset):
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
# plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
plt.imshow(canvas[:, :, [2, 1, 0]])
# plt.imshow(canvas[:, :, [2, 1, 0]])
return canvas

def draw_handpose(canvas, peaks):
def draw_handpose(canvas, peaks, show_number=False, initial_x=0, initial_y=0):
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
colors = [[100, 100, 100], [100, 0, 0], [150, 0, 0], \
[200, 0, 0], [255, 0, 0], [100, 100, 0], [150, 150, 0], [200, 200, 0], \
[255, 255, 0], [0, 100, 50], [0, 150, 75], [0, 200, 100], [0, 255, 125], \
[0, 50, 100], [0, 75, 150], [0, 100, 200], \
[0, 125, 255], [100, 0, 100], [150, 0, 150], \
[200, 0, 200], [255, 0, 255]]

plt.imshow(canvas[:, :, [2, 1, 0]])

for i, (x, y) in enumerate(peaks):
plt.plot(x, y, 'r.')
plt.text(x, y, str(i))
for ie, e in enumerate(edges):
rgb = matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
x1, y1 = peaks[e[0]]
x2, y2 = peaks[e[1]]
plt.plot([x1, x2], [y1, y2], color=rgb)
plt.axis('off')
plt.show()
if np.sum(np.all(peaks[e], axis=1)==0)==0:
cv2.line(canvas, tuple(peaks[e[0]]), tuple(peaks[e[1]]), colors[ie], thickness=1)
for i, keyponit in enumerate(peaks):
cv2.circle(canvas, tuple(keyponit), 2, (0, 0, 255), thickness=-1)
if show_number:
cv2.putText(canvas, str(i), tuple(keyponit), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
return canvas

# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
for person in subset.astype(int):
# if any of three not detected
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
if not (has_left or has_right):
continue
hands = []
#left hand
if has_left:
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
x1, y1 = candidate[left_shoulder_index][:2]
x2, y2 = candidate[left_elbow_index][:2]
x3, y3 = candidate[left_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, True])
# right hand
if has_right:
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
x1, y1 = candidate[right_shoulder_index][:2]
x2, y2 = candidate[right_elbow_index][:2]
x3, y3 = candidate[right_wrist_index][:2]
hands.append([x1, y1, x2, y2, x3, y3, False])

for x1, y1, x2, y2, x3, y3, is_left in hands:
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
# handRectangle.x -= handRectangle.width / 2.f;
# handRectangle.y -= handRectangle.height / 2.f;
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0: x = 0
if y < 0: y = 0
width1 = width
width2 = width
if x + width > image_width: width1 = image_width - x
if y + width > image_height: width2 = image_height - y
width = min(width1, width2)
detect_result.append([int(x), int(y), int(width), is_left])

'''
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left
'''
return detect_result

# get max index of 2d array
def npmax(array):
Expand Down

0 comments on commit 6e3952c

Please sign in to comment.