Skip to content

Commit 2be5a9f

Browse files
committed
Fix shape in SMPL+H and batch size in layer modules
1 parent e821375 commit 2be5a9f

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
2828
AUTHOR = 'Vassilis Choutas'
2929
REQUIRES_PYTHON = '>=3.6.0'
30-
VERSION = '0.1.22'
30+
VERSION = '0.1.23'
3131

3232
here = os.path.abspath(os.path.dirname(__file__))
3333

smplx/body_models.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -431,13 +431,16 @@ def forward(
431431
Returns
432432
-------
433433
'''
434+
model_vars = [betas, global_orient, body_pose, transl]
435+
batch_size = 1
436+
for var in model_vars:
437+
if var is None:
438+
continue
439+
batch_size = max(batch_size, len(var))
434440
device, dtype = self.shapedirs.device, self.shapedirs.dtype
435441
if global_orient is None:
436-
batch_size = 1
437442
global_orient = torch.eye(3, device=device, dtype=dtype).view(
438443
1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
439-
else:
440-
batch_size = global_orient.shape[0]
441444
if body_pose is None:
442445
body_pose = torch.eye(3, device=device, dtype=dtype).view(
443446
1, 1, 3, 3).expand(
@@ -675,6 +678,7 @@ def forward(
675678
) -> SMPLHOutput:
676679
'''
677680
'''
681+
678682
# If no shape and pose parameters are passed along, then use the
679683
# ones from the module
680684
global_orient = (global_orient if global_orient is not None else
@@ -702,7 +706,7 @@ def forward(
702706
right_hand_pose], dim=1)
703707
full_pose += self.pose_mean
704708

705-
vertices, joints = lbs(self.betas, full_pose, self.v_template,
709+
vertices, joints = lbs(betas, full_pose, self.v_template,
706710
self.shapedirs, self.posedirs,
707711
self.J_regressor, self.parents,
708712
self.lbs_weights, pose2rot=pose2rot)
@@ -760,13 +764,17 @@ def forward(
760764
) -> SMPLHOutput:
761765
'''
762766
'''
767+
model_vars = [betas, global_orient, body_pose, transl, left_hand_pose,
768+
right_hand_pose]
769+
batch_size = 1
770+
for var in model_vars:
771+
if var is None:
772+
continue
773+
batch_size = max(batch_size, len(var))
763774
device, dtype = self.shapedirs.device, self.shapedirs.dtype
764775
if global_orient is None:
765-
batch_size = 1
766776
global_orient = torch.eye(3, device=device, dtype=dtype).view(
767777
1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
768-
else:
769-
batch_size = global_orient.shape[0]
770778
if body_pose is None:
771779
body_pose = torch.eye(3, device=device, dtype=dtype).view(
772780
1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous()
@@ -1300,12 +1308,17 @@ def forward(
13001308
'''
13011309
device, dtype = self.shapedirs.device, self.shapedirs.dtype
13021310

1311+
model_vars = [betas, global_orient, body_pose, transl,
1312+
expression, left_hand_pose, right_hand_pose, jaw_pose]
1313+
batch_size = 1
1314+
for var in model_vars:
1315+
if var is None:
1316+
continue
1317+
batch_size = max(batch_size, len(var))
1318+
13031319
if global_orient is None:
1304-
batch_size = 1
13051320
global_orient = torch.eye(3, device=device, dtype=dtype).view(
13061321
1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
1307-
else:
1308-
batch_size = global_orient.shape[0]
13091322
if body_pose is None:
13101323
body_pose = torch.eye(3, device=device, dtype=dtype).view(
13111324
1, 1, 3, 3).expand(
@@ -1356,7 +1369,7 @@ def forward(
13561369
lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
13571370
dim=0).expand(batch_size, -1).contiguous()
13581371
lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
1359-
self.batch_size, 1, 1)
1372+
batch_size, 1, 1)
13601373
if self.use_face_contour:
13611374
lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(
13621375
vertices, full_pose,

0 commit comments

Comments
 (0)