|
30 | 30 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
|
31 | 31 | dynamic_lmk_b_coords,
|
32 | 32 | neck_kin_chain, dtype=torch.float32):
|
| 33 | + ''' Compute the faces, barycentric coordinates for the dynamic landmarks |
| 34 | +
|
| 35 | +
|
| 36 | + To do so, we first compute the rotation of the neck around the y-axis |
| 37 | + and then use a pre-computed look-up table to find the faces and the |
| 38 | + barycentric coordinates that will be used. |
| 39 | +
|
| 40 | + Special thanks to Soubhik Sanyal ([email protected]) |
| 41 | + for providing the original TensorFlow implementation and for the LUT. |
| 42 | +
|
| 43 | + Parameters |
| 44 | + ---------- |
| 45 | + vertices: torch.tensor BxVx3, dtype = torch.float32 |
| 46 | + The tensor of input vertices |
| 47 | + pose: torch.tensor Bx(Jx3), dtype = torch.float32 |
| 48 | + The current pose of the body model |
| 49 | + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long |
| 50 | + The look-up table from neck rotation to faces |
| 51 | + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 |
| 52 | + The look-up table from neck rotation to barycentric coordinates |
| 53 | + neck_kin_chain: list |
| 54 | + A python list that contains the indices of the joints that form the |
| 55 | + kinematic chain of the neck. |
| 56 | + dtype: torch.dtype, optional |
| 57 | +
|
| 58 | + Returns |
| 59 | + ------- |
| 60 | + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long |
| 61 | + A tensor of size BxL that contains the indices of the faces that |
| 62 | + will be used to compute the current dynamic landmarks. |
| 63 | + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 |
| 64 | + A tensor of size BxL that contains the indices of the faces that |
| 65 | + will be used to compute the current dynamic landmarks. |
| 66 | + ''' |
33 | 67 |
|
34 | 68 | batch_size = vertices.shape[0]
|
35 | 69 |
|
@@ -84,12 +118,17 @@ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
|
84 | 118 | # Extract the indices of the vertices for each face
|
85 | 119 | # BxLx3
|
86 | 120 | batch_size, num_verts = vertices.shape[:2]
|
| 121 | + device = vertices.device |
| 122 | + |
87 | 123 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
|
88 |
| - 1, -1, 3).repeat([batch_size, 1, 1]) |
89 |
| - lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( |
90 |
| - device=vertices.device) * num_verts |
| 124 | + batch_size, -1, 3) |
| 125 | + |
| 126 | + lmk_faces += torch.arange( |
| 127 | + batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts |
| 128 | + |
| 129 | + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( |
| 130 | + batch_size, -1, 3, 3) |
91 | 131 |
|
92 |
| - lmk_vertices = vertices.view(-1, 3)[lmk_faces] |
93 | 132 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
|
94 | 133 | return landmarks
|
95 | 134 |
|
|
0 commit comments