-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsmpl.py
156 lines (129 loc) · 4.39 KB
/
smpl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# From https://github.com/KosukeFukazawa/smpl2bvh/blob/main/smpl2bvh.py
from pathlib import Path
import pickle
import numpy as np
import smplx
# import torch
# from .Animation import Animation
# from .Quaternions import Quaternions
from .smpl_utils.utils import quat
# from .transforms import euler2mat, mat2quat
# Names of the 24 SMPL skeleton joints. Position i in this list labels
# joint i of the 24-joint arrays handled in this module.
SMPL_JOINTS_NAMES = [
    # 0-3: root, hips, lower spine
    "Pelvis", "L_Hip", "R_Hip", "Spine1",
    # 4-6: knees, mid spine
    "L_Knee", "R_Knee", "Spine2",
    # 7-9: ankles, upper spine
    "L_Ankle", "R_Ankle", "Spine3",
    # 10-12: feet, neck
    "L_Foot", "R_Foot", "Neck",
    # 13-15: collars, head
    "L_Collar", "R_Collar", "Head",
    # 16-23: arms, from shoulder to hand
    "L_Shoulder", "R_Shoulder",
    "L_Elbow", "R_Elbow",
    "L_Wrist", "R_Wrist",
    "L_Hand", "R_Hand",
]
def load_smpl(smpl_file):
    """Open an SMPL animation stored in a pickle (.pkl) or NumPy (.npz) file.

    Args:
        smpl_file (str | os.PathLike): Path to the file. The extension selects
            the reader; ``pathlib.Path`` objects are accepted too.

    Raises:
        ValueError: If the filename does not end with ``.pkl`` or ``.npz``.

    Returns:
        dict: ``{"smpl_poses": rots, "smpl_trans": trans, "smpl_scaling": scaling}``.
        ``rots`` holds one 3-vector per joint per frame (``(N, J, 3)``).
        When the file has no ``smpl_scaling`` entry, ``scaling`` defaults to
        ``(100,)`` and a warning is printed.
    """
    # str() lets Path objects through; the extension checks need a string.
    path = str(smpl_file)
    default_scaling = (100,)

    if path.endswith(".npz"):
        # The context manager closes the underlying zip handle; indexing the
        # archive reads each array fully into memory, so results outlive it.
        with np.load(path) as data:
            # Assumes a leading singleton batch axis, e.g. poses (1, N, 24, 3)
            # and trans (1, N, 3); squeeze(axis=0) raises if it is not size 1.
            rots = np.squeeze(data["poses"], axis=0)  # (N, 24, 3)
            trans = np.squeeze(data["trans"], axis=0)  # (N, 3)
            # Previously this branch never set `scaling`, so every .npz load
            # crashed at the return. Mirror the .pkl convention instead.
            # NOTE(review): archives that store translation in metres may need
            # a scaling of 1 rather than 100; confirm against the data source.
            if "smpl_scaling" in data.files:
                scaling = data["smpl_scaling"]
            else:
                scaling = default_scaling
                print("WARNING: No scaling found in the file, defaults to 100.")
    elif path.endswith(".pkl"):
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only open pickles from trusted sources.
        with open(path, "rb") as f:
            smpl_data = pickle.load(f)
        rots = smpl_data["smpl_poses"]  # presumably (N, 72)
        rots = rots.reshape(rots.shape[0], -1, 3)  # (N, 24, 3)
        if "smpl_scaling" in smpl_data:
            scaling = smpl_data["smpl_scaling"]  # (1,)
        else:
            scaling = default_scaling
            print("WARNING: No scaling found in the file, defaults to 100.")
        trans = smpl_data["smpl_trans"]  # (N, 3)
    else:
        raise ValueError("This file type is not supported!")

    return {"smpl_poses": rots, "smpl_trans": trans, "smpl_scaling": scaling}
def smpl_to_bvh_data(smpl_dict, gender="NEUTRAL", frametime=1 / 60):
    """Convert SMPL pose/translation parameters into BVH-ready skeleton data.

    Args:
        smpl_dict (dict): Must contain ``"smpl_poses"`` (reshapeable to
            ``(N, J, 3)`` axis-angle vectors) and ``"smpl_trans"`` (``(N, 3)``).
            ``"smpl_scaling"`` is optional; translations are divided by it
            (default 100). The caller's dict and arrays are left unmodified.
        gender (str): Forwarded to ``smplx.create`` to pick the body model.
        frametime (float): Seconds per frame, stored unchanged in the result.

    Returns:
        dict: ``rotations`` (per-joint Euler angles in degrees, axis order
        ``order``), ``positions`` (``(N, 24, 3)`` joint offsets with the root
        translated per frame), ``offsets`` (``(24, 3)`` rest-pose bone offsets),
        ``parents``, ``names``, ``order`` and ``frametime``. Lengths are in
        metres.
    """
    model = smplx.create(
        model_path=Path(__file__).parent / "smpl_utils/data/smpl/",
        model_type="smpl",
        gender=gender,
        batch_size=1,
    )
    parents = model.parents.detach().cpu().numpy()

    # Evaluate the model with default parameters to get the rest-pose joints;
    # only the first 24 joints are the SMPL skeleton.
    rest_pose = model().joints.detach().cpu().numpy().squeeze()[:24, :]
    # Each bone offset is relative to its parent joint; the root keeps its
    # absolute rest position (row 0 is overwritten after the subtraction).
    offsets = rest_pose - rest_pose[parents]
    offsets[0] = rest_pose[0]
    offsets *= 100  # centimetres internally; divided back by 100 on output

    scaling = smpl_dict.get("smpl_scaling", 100)
    rots = smpl_dict["smpl_poses"]
    rots = rots.reshape(rots.shape[0], -1, 3)  # (N, 24, 3)
    # Divide out-of-place: an in-place `/=` would rewrite the caller's
    # smpl_dict["smpl_trans"] (and fails for integer-typed arrays).
    trans = np.asarray(smpl_dict["smpl_trans"]) / scaling  # (N, 3)

    # axis-angle -> quaternion -> Euler angles (degrees) in BVH channel order
    rots = quat.from_axis_angle(rots)
    order = "yzx"
    # repeat() already returns a fresh array, so no extra copy is needed.
    positions = offsets[None].repeat(len(rots), axis=0)
    positions[:, 0] += trans * 100
    rotations = np.degrees(quat.to_euler(rots, order=order))

    return {
        "rotations": rotations,
        "positions": positions / 100,  # We want the results in meter convention
        "offsets": offsets / 100,
        "parents": parents,
        "names": SMPL_JOINTS_NAMES,
        "order": order,
        "frametime": frametime,
    }
def bvh_data_to_smpl(bvh_data):
    """Convert a BVH-style motion dict back into an SMPL parameter dict.

    Args:
        bvh_data (dict): Needs ``"names"`` (joint names; every entry of
            ``SMPL_JOINTS_NAMES`` must be present, otherwise ``KeyError``),
            ``"rotations"`` (``(N, J, 3)`` Euler angles in degrees),
            ``"positions"`` (``(N, J, 3)``) and ``"order"`` (Euler axis order).

    Returns:
        dict: ``"smpl_poses"`` (axis-angle rotations flattened per frame),
        ``"smpl_trans"`` (the root joint's position, taken as-is) and
        ``"smpl_scaling"`` (always ``np.array([100])``).

    NOTE(review): ``smpl_trans`` keeps the root rest offset and is not
    multiplied by the scaling, whereas ``smpl_to_bvh_data`` adds the root
    offset and divides translation by the scaling. So this is not an exact
    inverse of that function unless the caller compensates; confirm intent.
    """
    # Permute the BVH joint axis into SMPL joint order.
    column_of = {joint: idx for idx, joint in enumerate(bvh_data["names"])}
    smpl_order = [column_of[joint] for joint in SMPL_JOINTS_NAMES]

    # Fancy indexing yields copies, so the outputs never alias bvh_data.
    euler_deg = bvh_data["rotations"][:, smpl_order, :]
    root_position = bvh_data["positions"][:, smpl_order, :][:, 0]

    # Euler degrees -> quaternions -> axis-angle, one 3-vector per joint.
    quats = quat.from_euler(np.radians(euler_deg), order=bvh_data["order"])
    axis_angle = quat.to_axis_angle(quats)
    frame_count = axis_angle.shape[0]

    return {
        "smpl_poses": axis_angle.reshape(frame_count, -1),
        "smpl_trans": root_position,
        "smpl_scaling": np.array([100]),
    }