@@ -431,13 +431,16 @@ def forward(
431
431
Returns
432
432
-------
433
433
'''
434
+ model_vars = [betas , global_orient , body_pose , transl ]
435
+ batch_size = 1
436
+ for var in model_vars :
437
+ if var is None :
438
+ continue
439
+ batch_size = max (batch_size , len (var ))
434
440
device , dtype = self .shapedirs .device , self .shapedirs .dtype
435
441
if global_orient is None :
436
- batch_size = 1
437
442
global_orient = torch .eye (3 , device = device , dtype = dtype ).view (
438
443
1 , 1 , 3 , 3 ).expand (batch_size , - 1 , - 1 , - 1 ).contiguous ()
439
- else :
440
- batch_size = global_orient .shape [0 ]
441
444
if body_pose is None :
442
445
body_pose = torch .eye (3 , device = device , dtype = dtype ).view (
443
446
1 , 1 , 3 , 3 ).expand (
@@ -675,6 +678,7 @@ def forward(
675
678
) -> SMPLHOutput :
676
679
'''
677
680
'''
681
+
678
682
# If no shape and pose parameters are passed along, then use the
679
683
# ones from the module
680
684
global_orient = (global_orient if global_orient is not None else
@@ -702,7 +706,7 @@ def forward(
702
706
right_hand_pose ], dim = 1 )
703
707
full_pose += self .pose_mean
704
708
705
- vertices , joints = lbs (self . betas , full_pose , self .v_template ,
709
+ vertices , joints = lbs (betas , full_pose , self .v_template ,
706
710
self .shapedirs , self .posedirs ,
707
711
self .J_regressor , self .parents ,
708
712
self .lbs_weights , pose2rot = pose2rot )
@@ -760,13 +764,17 @@ def forward(
760
764
) -> SMPLHOutput :
761
765
'''
762
766
'''
767
+ model_vars = [betas , global_orient , body_pose , transl , left_hand_pose ,
768
+ right_hand_pose ]
769
+ batch_size = 1
770
+ for var in model_vars :
771
+ if var is None :
772
+ continue
773
+ batch_size = max (batch_size , len (var ))
763
774
device , dtype = self .shapedirs .device , self .shapedirs .dtype
764
775
if global_orient is None :
765
- batch_size = 1
766
776
global_orient = torch .eye (3 , device = device , dtype = dtype ).view (
767
777
1 , 1 , 3 , 3 ).expand (batch_size , - 1 , - 1 , - 1 ).contiguous ()
768
- else :
769
- batch_size = global_orient .shape [0 ]
770
778
if body_pose is None :
771
779
body_pose = torch .eye (3 , device = device , dtype = dtype ).view (
772
780
1 , 1 , 3 , 3 ).expand (batch_size , 21 , - 1 , - 1 ).contiguous ()
@@ -1300,12 +1308,17 @@ def forward(
1300
1308
'''
1301
1309
device , dtype = self .shapedirs .device , self .shapedirs .dtype
1302
1310
1311
+ model_vars = [betas , global_orient , body_pose , transl ,
1312
+ expression , left_hand_pose , right_hand_pose , jaw_pose ]
1313
+ batch_size = 1
1314
+ for var in model_vars :
1315
+ if var is None :
1316
+ continue
1317
+ batch_size = max (batch_size , len (var ))
1318
+
1303
1319
if global_orient is None :
1304
- batch_size = 1
1305
1320
global_orient = torch .eye (3 , device = device , dtype = dtype ).view (
1306
1321
1 , 1 , 3 , 3 ).expand (batch_size , - 1 , - 1 , - 1 ).contiguous ()
1307
- else :
1308
- batch_size = global_orient .shape [0 ]
1309
1322
if body_pose is None :
1310
1323
body_pose = torch .eye (3 , device = device , dtype = dtype ).view (
1311
1324
1 , 1 , 3 , 3 ).expand (
@@ -1356,7 +1369,7 @@ def forward(
1356
1369
lmk_faces_idx = self .lmk_faces_idx .unsqueeze (
1357
1370
dim = 0 ).expand (batch_size , - 1 ).contiguous ()
1358
1371
lmk_bary_coords = self .lmk_bary_coords .unsqueeze (dim = 0 ).repeat (
1359
- self . batch_size , 1 , 1 )
1372
+ batch_size , 1 , 1 )
1360
1373
if self .use_face_contour :
1361
1374
lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords (
1362
1375
vertices , full_pose ,
0 commit comments