Skip to content

Commit b715c4f

Browse files
committed
Add type hints to lbs
1 parent bf6e84a commit b715c4f

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

smplx/lbs.py

+42-12
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,23 @@
1818
from __future__ import print_function
1919
from __future__ import division
2020

21+
from typing import Tuple, List
2122
import numpy as np
2223

2324
import torch
2425
import torch.nn.functional as F
2526

26-
from .utils import rot_mat_to_euler
27+
from .utils import rot_mat_to_euler, Tensor
2728

2829

29-
def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
30-
dynamic_lmk_b_coords,
31-
neck_kin_chain, dtype=torch.float32):
30+
def find_dynamic_lmk_idx_and_bcoords(
31+
vertices: Tensor,
32+
pose: Tensor,
33+
dynamic_lmk_faces_idx: Tensor,
34+
dynamic_lmk_b_coords: Tensor,
35+
neck_kin_chain: List[int],
36+
dtype=torch.float32
37+
) -> Tuple[Tensor, Tensor]:
3238
''' Compute the faces, barycentric coordinates for the dynamic landmarks
3339
3440
@@ -94,7 +100,12 @@ def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
94100
return dyn_lmk_faces_idx, dyn_lmk_b_coords
95101

96102

97-
def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
103+
def vertices2landmarks(
104+
vertices: Tensor,
105+
faces: Tensor,
106+
lmk_faces_idx: Tensor,
107+
lmk_bary_coords: Tensor
108+
) -> Tensor:
98109
''' Calculates landmarks by barycentric interpolation
99110
100111
Parameters
@@ -133,8 +144,18 @@ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
133144
return landmarks
134145

135146

136-
def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
137-
lbs_weights, pose2rot=True, dtype=torch.float32):
147+
def lbs(
148+
betas: Tensor,
149+
pose: Tensor,
150+
v_template: Tensor,
151+
shapedirs: Tensor,
152+
posedirs: Tensor,
153+
J_regressor: Tensor,
154+
parents: Tensor,
155+
lbs_weights: Tensor,
156+
pose2rot: bool = True,
157+
dtype=torch.float32
158+
) -> Tuple[Tensor, Tensor]:
138159
''' Performs Linear Blend Skinning with the given shape and pose parameters
139160
140161
Parameters
@@ -223,7 +244,7 @@ def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
223244
return verts, J_transformed
224245

225246

226-
def vertices2joints(J_regressor, vertices):
247+
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
227248
''' Calculates the 3D joint locations from the vertices
228249
229250
Parameters
@@ -243,7 +264,7 @@ def vertices2joints(J_regressor, vertices):
243264
return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
244265

245266

246-
def blend_shapes(betas, shape_disps):
267+
def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:
247268
''' Calculates the per vertex displacement due to the blend shapes
248269
249270
@@ -267,7 +288,11 @@ def blend_shapes(betas, shape_disps):
267288
return blend_shape
268289

269290

270-
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
291+
def batch_rodrigues(
292+
rot_vecs: Tensor,
293+
epsilon: float = 1e-8,
294+
dtype=torch.float32
295+
) -> Tensor:
271296
''' Calculates the rotation matrices for a batch of rotation vectors
272297
Parameters
273298
----------
@@ -301,7 +326,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
301326
return rot_mat
302327

303328

304-
def transform_mat(R, t):
329+
def transform_mat(R: Tensor, t: Tensor) -> Tensor:
305330
''' Creates a batch of transformation matrices
306331
Args:
307332
- R: Bx3x3 array of a batch of rotation matrices
@@ -314,7 +339,12 @@ def transform_mat(R, t):
314339
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
315340

316341

317-
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
342+
def batch_rigid_transform(
343+
rot_mats: Tensor,
344+
joints: Tensor,
345+
parents: Tensor,
346+
dtype=torch.float32
347+
) -> Tensor:
318348
"""
319349
Applies a batch of rigid transformations to the joints
320350

0 commit comments

Comments
 (0)