forked from GuyTevet/motion-diffusion-model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vis_utils.py
66 lines (59 loc) · 3.13 KB
/
vis_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from model.rotation2xyz import Rotation2xyz
import numpy as np
from trimesh import Trimesh
import os
import torch
from visualize.simplify_loc2rot import joints2smpl
class npy2obj:
def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True):
self.npy_path = npy_path
self.motions = np.load(self.npy_path, allow_pickle=True)
if self.npy_path.endswith('.npz'):
self.motions = self.motions['arr_0']
self.motions = self.motions[None][0]
self.rot2xyz = Rotation2xyz(device='cpu')
self.faces = self.rot2xyz.smpl_model.faces
self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape
self.opt_cache = {}
self.sample_idx = sample_idx
self.total_num_samples = self.motions['num_samples']
self.rep_idx = rep_idx
self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx
self.num_frames = self.motions['motion'][self.absl_idx].shape[-1]
self.j2s = joints2smpl(num_frames=self.num_frames, device_id=device, cuda=cuda)
if self.nfeats == 3:
print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.')
motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3]
self.motions['motion'] = motion_tensor.cpu().numpy()
elif self.nfeats == 6:
self.motions['motion'] = self.motions['motion'][[self.absl_idx]]
self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape
self.real_num_frames = self.motions['lengths'][self.absl_idx]
self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None,
pose_rep='rot6d', translation=True, glob=True,
jointstype='vertices',
# jointstype='smpl', # for joint locations
vertstrans=True)
self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1)
# self.vertices += self.root_loc
def get_vertices(self, sample_i, frame_i):
return self.vertices[sample_i, :, :, frame_i].squeeze().tolist()
def get_trimesh(self, sample_i, frame_i):
return Trimesh(vertices=self.get_vertices(sample_i, frame_i),
faces=self.faces)
def save_obj(self, save_path, frame_i):
mesh = self.get_trimesh(0, frame_i)
with open(save_path, 'w') as fw:
mesh.export(fw, 'obj')
return save_path
def save_npy(self, save_path):
data_dict = {
'motion': self.motions['motion'][0, :, :, :self.real_num_frames],
'thetas': self.motions['motion'][0, :-1, :, :self.real_num_frames],
'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames],
'faces': self.faces,
'vertices': self.vertices[0, :, :, :self.real_num_frames],
'text': self.motions['text'][0],
'length': self.real_num_frames,
}
np.save(save_path, data_dict)