38
38
FLAMEOutput ,
39
39
find_joint_kin_chain )
40
40
from .vertex_joint_selector import VertexJointSelector
41
+ from collections import namedtuple
41
42
43
+ TensorOutput = namedtuple ('TensorOutput' ,
44
+ ['vertices' , 'joints' , 'betas' , 'expression' , 'global_orient' , 'body_pose' , 'left_hand_pose' ,
45
+ 'right_hand_pose' , 'jaw_pose' , 'transl' , 'full_pose' ])
42
46
43
47
class SMPL (nn .Module ):
44
48
@@ -48,7 +52,7 @@ class SMPL(nn.Module):
48
52
49
53
def __init__ (
50
54
self , model_path : str ,
51
- kid_template_path : str = '' ,
55
+ kid_template_path : str = '' ,
52
56
data_struct : Optional [Struct ] = None ,
53
57
create_betas : bool = True ,
54
58
betas : Optional [Tensor ] = None ,
@@ -143,7 +147,9 @@ def __init__(
143
147
shapedirs = data_struct .shapedirs
144
148
if (shapedirs .shape [- 1 ] < self .SHAPE_SPACE_DIM ):
145
149
print (f'WARNING: You are using a { self .name ()} model, with only'
146
- f' { shapedirs .shape [- 1 ]} shape coefficients.' )
150
+ f' { shapedirs .shape [- 1 ]} shape coefficients.\n '
151
+ f'num_betas={ num_betas } , shapedirs.shape={ shapedirs .shape } , '
152
+ f'self.SHAPE_SPACE_DIM={ self .SHAPE_SPACE_DIM } ' )
147
153
num_betas = min (num_betas , shapedirs .shape [- 1 ])
148
154
else :
149
155
num_betas = min (num_betas , self .SHAPE_SPACE_DIM )
@@ -901,7 +907,7 @@ class SMPLX(SMPLH):
901
907
902
908
def __init__ (
903
909
self , model_path : str ,
904
- kid_template_path : str = '' ,
910
+ kid_template_path : str = '' ,
905
911
num_expression_coeffs : int = 10 ,
906
912
create_expression : bool = True ,
907
913
expression : Optional [Tensor ] = None ,
@@ -1128,7 +1134,7 @@ def forward(
1128
1134
pose2rot : bool = True ,
1129
1135
return_shaped : bool = True ,
1130
1136
** kwargs
1131
- ) -> SMPLXOutput :
1137
+ ) -> TensorOutput :
1132
1138
'''
1133
1139
Forward pass for the SMPLX model
1134
1140
@@ -1276,7 +1282,9 @@ def forward(
1276
1282
v_shaped = None
1277
1283
if return_shaped :
1278
1284
v_shaped = self .v_template + blend_shapes (betas , self .shapedirs )
1279
- output = SMPLXOutput (vertices = vertices if return_verts else None ,
1285
+ else :
1286
+ v_shaped = Tensor (0 )
1287
+ output = TensorOutput (vertices = vertices if return_verts else None ,
1280
1288
joints = joints ,
1281
1289
betas = betas ,
1282
1290
expression = expression ,
@@ -1324,9 +1332,9 @@ def forward(
1324
1332
leye_pose : Optional [Tensor ] = None ,
1325
1333
reye_pose : Optional [Tensor ] = None ,
1326
1334
return_verts : bool = True ,
1327
- return_full_pose : bool = False ,
1335
+ return_full_pose : bool = True ,
1328
1336
** kwargs
1329
- ) -> SMPLXOutput :
1337
+ ) -> TensorOutput :
1330
1338
'''
1331
1339
Forward pass for the SMPLX model
1332
1340
@@ -1475,7 +1483,7 @@ def forward(
1475
1483
joints += transl .unsqueeze (dim = 1 )
1476
1484
vertices += transl .unsqueeze (dim = 1 )
1477
1485
1478
- output = SMPLXOutput (vertices = vertices if return_verts else None ,
1486
+ output = TensorOutput (vertices = vertices if return_verts else Tensor ( 0 ) ,
1479
1487
joints = joints ,
1480
1488
betas = betas ,
1481
1489
expression = expression ,
@@ -1484,8 +1492,9 @@ def forward(
1484
1492
left_hand_pose = left_hand_pose ,
1485
1493
right_hand_pose = right_hand_pose ,
1486
1494
jaw_pose = jaw_pose ,
1487
- transl = transl ,
1488
- full_pose = full_pose if return_full_pose else None )
1495
+ transl = transl if transl != None else Tensor (0 ),
1496
+ full_pose = full_pose if return_full_pose else Tensor (0 ))
1497
+
1489
1498
return output
1490
1499
1491
1500
0 commit comments