forked from MVIG-SJTU/AlphaPose
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
4,019 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
from collections import namedtuple | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
from .lbs import lbs, hybrik, rotmat_to_quat, quat_to_rotmat | ||
|
||
try: | ||
import cPickle as pk | ||
except ImportError: | ||
import pickle as pk | ||
|
||
|
||
ModelOutput = namedtuple('ModelOutput', | ||
['vertices', 'joints', 'joints_from_verts', | ||
'rot_mats']) | ||
ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields) | ||
|
||
|
||
def to_tensor(array, dtype=torch.float32): | ||
if 'torch.tensor' not in str(type(array)): | ||
return torch.tensor(array, dtype=dtype) | ||
|
||
|
||
class Struct(object): | ||
def __init__(self, **kwargs): | ||
for key, val in kwargs.items(): | ||
setattr(self, key, val) | ||
|
||
|
||
def to_np(array, dtype=np.float32): | ||
if 'scipy.sparse' in str(type(array)): | ||
array = array.todense() | ||
return np.array(array, dtype=dtype) | ||
|
||
|
||
class SMPL_layer(nn.Module): | ||
NUM_JOINTS = 23 | ||
NUM_BODY_JOINTS = 23 | ||
NUM_BETAS = 10 | ||
JOINT_NAMES = [ | ||
'pelvis', 'left_hip', 'right_hip', # 2 | ||
'spine1', 'left_knee', 'right_knee', # 5 | ||
'spine2', 'left_ankle', 'right_ankle', # 8 | ||
'spine3', 'left_foot', 'right_foot', # 11 | ||
'neck', 'left_collar', 'right_collar', # 14 | ||
'jaw', # 15 | ||
'left_shoulder', 'right_shoulder', # 17 | ||
'left_elbow', 'right_elbow', # 19 | ||
'left_wrist', 'right_wrist', # 21 | ||
'left_thumb', 'right_thumb', # 23 | ||
'head', 'left_middle', 'right_middle', # 26 | ||
'left_bigtoe', 'right_bigtoe' # 28 | ||
] | ||
LEAF_NAMES = [ | ||
'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe' | ||
] | ||
root_idx_17 = 0 | ||
root_idx_smpl = 0 | ||
|
||
def __init__(self, | ||
model_path, | ||
h36m_jregressor, | ||
gender='neutral', | ||
dtype=torch.float32, | ||
num_joints=29): | ||
''' SMPL model layers | ||
Parameters: | ||
---------- | ||
model_path: str | ||
The path to the folder or to the file where the model | ||
parameters are stored | ||
gender: str, optional | ||
Which gender to load | ||
''' | ||
super(SMPL_layer, self).__init__() | ||
|
||
self.ROOT_IDX = self.JOINT_NAMES.index('pelvis') | ||
self.LEAF_IDX = [self.JOINT_NAMES.index(name) for name in self.LEAF_NAMES] | ||
self.SPINE3_IDX = 9 | ||
|
||
with open(model_path, 'rb') as smpl_file: | ||
self.smpl_data = Struct(**pk.load(smpl_file, encoding='latin1')) | ||
|
||
self.gender = gender | ||
|
||
self.dtype = dtype | ||
|
||
self.faces = self.smpl_data.f | ||
|
||
''' Register Buffer ''' | ||
# Faces | ||
self.register_buffer('faces_tensor', | ||
to_tensor(to_np(self.smpl_data.f, dtype=np.int64), dtype=torch.long)) | ||
|
||
# The vertices of the template model, (6890, 3) | ||
self.register_buffer('v_template', | ||
to_tensor(to_np(self.smpl_data.v_template), dtype=dtype)) | ||
|
||
# The shape components | ||
# Shape blend shapes basis, (6890, 3, 10) | ||
self.register_buffer( | ||
'shapedirs', | ||
to_tensor(to_np(self.smpl_data.shapedirs), dtype=dtype)) | ||
|
||
# Pose blend shape basis: 6890 x 3 x 23*9, reshaped to 6890*3 x 23*9 | ||
num_pose_basis = self.smpl_data.posedirs.shape[-1] | ||
# 23*9 x 6890*3 | ||
posedirs = np.reshape(self.smpl_data.posedirs, [-1, num_pose_basis]).T | ||
self.register_buffer('posedirs', | ||
to_tensor(to_np(posedirs), dtype=dtype)) | ||
|
||
# Vertices to Joints location (23 + 1, 6890) | ||
self.register_buffer( | ||
'J_regressor', | ||
to_tensor(to_np(self.smpl_data.J_regressor), dtype=dtype)) | ||
# Vertices to Human3.6M Joints location (17, 6890) | ||
self.register_buffer( | ||
'J_regressor_h36m', | ||
to_tensor(to_np(h36m_jregressor), dtype=dtype)) | ||
|
||
self.num_joints = num_joints | ||
|
||
# indices of parents for each joints | ||
parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long) | ||
parents[:(self.NUM_JOINTS + 1)] = to_tensor(to_np(self.smpl_data.kintree_table[0])).long() | ||
parents[0] = -1 | ||
# extend kinematic tree | ||
parents[24] = 15 | ||
parents[25] = 22 | ||
parents[26] = 23 | ||
parents[27] = 10 | ||
parents[28] = 11 | ||
if parents.shape[0] > self.num_joints: | ||
parents = parents[:24] | ||
|
||
self.register_buffer( | ||
'children_map', | ||
self._parents_to_children(parents)) | ||
# (24,) | ||
self.register_buffer('parents', parents) | ||
|
||
# (6890, 23 + 1) | ||
self.register_buffer('lbs_weights', | ||
to_tensor(to_np(self.smpl_data.weights), dtype=dtype)) | ||
|
||
def _parents_to_children(self, parents): | ||
children = torch.ones_like(parents) * -1 | ||
for i in range(self.num_joints): | ||
if children[parents[i]] < 0: | ||
children[parents[i]] = i | ||
for i in self.LEAF_IDX: | ||
if i < children.shape[0]: | ||
children[i] = -1 | ||
|
||
children[self.SPINE3_IDX] = -3 | ||
children[0] = 3 | ||
children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck') | ||
|
||
return children | ||
|
||
def forward(self, | ||
pose_axis_angle, | ||
betas, | ||
global_orient, | ||
transl=None, | ||
return_verts=True): | ||
''' Forward pass for the SMPL model | ||
Parameters | ||
---------- | ||
pose_axis_angle: torch.tensor, optional, shape Bx(J*3) | ||
It should be a tensor that contains joint rotations in | ||
axis-angle format. (default=None) | ||
betas: torch.tensor, optional, shape Bx10 | ||
It can used if shape parameters | ||
`betas` are predicted from some external model. | ||
(default=None) | ||
global_orient: torch.tensor, optional, shape Bx3 | ||
Global Orientations. | ||
transl: torch.tensor, optional, shape Bx3 | ||
Global Translations. | ||
return_verts: bool, optional | ||
Return the vertices. (default=True) | ||
Returns | ||
------- | ||
''' | ||
# batch_size = pose_axis_angle.shape[0] | ||
|
||
# concate root orientation with thetas | ||
if global_orient is not None: | ||
full_pose = torch.cat([global_orient, pose_axis_angle], dim=1) | ||
else: | ||
full_pose = pose_axis_angle | ||
|
||
# Translate thetas to rotation matrics | ||
pose2rot = True | ||
# vertices: (B, N, 3), joints: (B, K, 3) | ||
vertices, joints, rot_mats, joints_from_verts_h36m = lbs(betas, full_pose, self.v_template, | ||
self.shapedirs, self.posedirs, | ||
self.J_regressor, self.J_regressor_h36m, self.parents, | ||
self.lbs_weights, pose2rot=pose2rot, dtype=self.dtype) | ||
|
||
if transl is not None: | ||
# apply translations | ||
joints += transl.unsqueeze(dim=1) | ||
vertices += transl.unsqueeze(dim=1) | ||
joints_from_verts_h36m += transl.unsqueeze(dim=1) | ||
else: | ||
vertices = vertices - joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze(1).detach() | ||
joints = joints - joints[:, self.root_idx_smpl, :].unsqueeze(1).detach() | ||
joints_from_verts_h36m = joints_from_verts_h36m - joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze(1).detach() | ||
|
||
output = ModelOutput( | ||
vertices=vertices, joints=joints, rot_mats=rot_mats, joints_from_verts=joints_from_verts_h36m) | ||
return output | ||
|
||
def hybrik(self, | ||
pose_skeleton, | ||
betas, | ||
phis, | ||
global_orient, | ||
transl=None, | ||
return_verts=True, | ||
leaf_thetas=None): | ||
''' Inverse pass for the SMPL model | ||
Parameters | ||
---------- | ||
pose_skeleton: torch.tensor, optional, shape Bx(J*3) | ||
It should be a tensor that contains joint locations in | ||
(X, Y, Z) format. (default=None) | ||
betas: torch.tensor, optional, shape Bx10 | ||
It can used if shape parameters | ||
`betas` are predicted from some external model. | ||
(default=None) | ||
global_orient: torch.tensor, optional, shape Bx3 | ||
Global Orientations. | ||
transl: torch.tensor, optional, shape Bx3 | ||
Global Translations. | ||
return_verts: bool, optional | ||
Return the vertices. (default=True) | ||
Returns | ||
------- | ||
''' | ||
batch_size = pose_skeleton.shape[0] | ||
|
||
if leaf_thetas is not None: | ||
leaf_thetas = leaf_thetas.reshape(batch_size * 5, 4) | ||
leaf_thetas = quat_to_rotmat(leaf_thetas) | ||
|
||
vertices, new_joints, rot_mats, joints_from_verts = hybrik( | ||
betas, global_orient, pose_skeleton, phis, | ||
self.v_template, self.shapedirs, self.posedirs, | ||
self.J_regressor, self.J_regressor_h36m, self.parents, self.children_map, | ||
self.lbs_weights, dtype=self.dtype, train=self.training, | ||
leaf_thetas=leaf_thetas) | ||
|
||
rot_mats = rot_mats.reshape(batch_size * 24, 3, 3) | ||
rot_mats = rotmat_to_quat(rot_mats).reshape(batch_size, 24 * 4) | ||
|
||
if transl is not None: | ||
new_joints += transl.unsqueeze(dim=1) | ||
vertices += transl.unsqueeze(dim=1) | ||
joints_from_verts += transl.unsqueeze(dim=1) | ||
else: | ||
vertices = vertices - joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach() | ||
new_joints = new_joints - new_joints[:, self.root_idx_smpl, :].unsqueeze(1).detach() | ||
joints_from_verts = joints_from_verts - joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach() | ||
|
||
output = ModelOutput( | ||
vertices=vertices, joints=new_joints, rot_mats=rot_mats, joints_from_verts=joints_from_verts) | ||
return output |
Oops, something went wrong.