forked from GuyTevet/motion-diffusion-model
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b0f9721
commit bb8960a
Showing
18 changed files
with
771 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
drink | ||
lift_dumbbell |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
jumping-jack | ||
left-lunging | ||
left-stretching | ||
raising-hand-and-jumping | ||
rotation-clapping | ||
front-raising | ||
pulling-chest-expanders |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
# from utils.action_label_to_idx import action_label_to_idx | ||
from data_loaders.tensors import collate | ||
from utils.misc import to_torch | ||
import utils.rotation_conversions as geometry | ||
|
||
class Dataset(torch.utils.data.Dataset): | ||
def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train", | ||
pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs): | ||
self.num_frames = num_frames | ||
self.sampling = sampling | ||
self.sampling_step = sampling_step | ||
self.split = split | ||
self.pose_rep = pose_rep | ||
self.translation = translation | ||
self.glob = glob | ||
self.max_len = max_len | ||
self.min_len = min_len | ||
self.num_seq_max = num_seq_max | ||
|
||
self.align_pose_frontview = kwargs.get('align_pose_frontview', False) | ||
self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False) | ||
self.only_60_classes = kwargs.get('only_60_classes', False) | ||
self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False) | ||
self.use_only_15_classes = kwargs.get('use_only_15_classes', False) | ||
|
||
if self.split not in ["train", "val", "test"]: | ||
raise ValueError(f"{self.split} is not a valid split") | ||
|
||
super().__init__() | ||
|
||
# to remove shuffling | ||
self._original_train = None | ||
self._original_test = None | ||
|
||
def action_to_label(self, action): | ||
return self._action_to_label[action] | ||
|
||
def label_to_action(self, label): | ||
import numbers | ||
if isinstance(label, numbers.Integral): | ||
return self._label_to_action[label] | ||
else: # if it is one hot vector | ||
label = np.argmax(label) | ||
return self._label_to_action[label] | ||
|
||
def get_pose_data(self, data_index, frame_ix): | ||
pose = self._load(data_index, frame_ix) | ||
label = self.get_label(data_index) | ||
return pose, label | ||
|
||
def get_label(self, ind): | ||
action = self.get_action(ind) | ||
return self.action_to_label(action) | ||
|
||
def get_action(self, ind): | ||
return self._actions[ind] | ||
|
||
def action_to_action_name(self, action): | ||
return self._action_classes[action] | ||
|
||
def action_name_to_action(self, action_name): | ||
# self._action_classes is either a list or a dictionary. If it's a dictionary, we 1st convert it to a list | ||
all_action_names = self._action_classes | ||
if isinstance(all_action_names, dict): | ||
all_action_names = list(all_action_names.values()) | ||
assert list(self._action_classes.keys()) == list(range(len(all_action_names))) # the keys should be ordered from 0 to num_actions | ||
|
||
sorter = np.argsort(all_action_names) | ||
actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)] | ||
return actions | ||
|
||
def __getitem__(self, index): | ||
if self.split == 'train': | ||
data_index = self._train[index] | ||
else: | ||
data_index = self._test[index] | ||
|
||
# inp, target = self._get_item_data_index(data_index) | ||
# return inp, target | ||
return self._get_item_data_index(data_index) | ||
|
||
def _load(self, ind, frame_ix): | ||
pose_rep = self.pose_rep | ||
if pose_rep == "xyz" or self.translation: | ||
if getattr(self, "_load_joints3D", None) is not None: | ||
# Locate the root joint of initial pose at origin | ||
joints3D = self._load_joints3D(ind, frame_ix) | ||
joints3D = joints3D - joints3D[0, 0, :] | ||
ret = to_torch(joints3D) | ||
if self.translation: | ||
ret_tr = ret[:, 0, :] | ||
else: | ||
if pose_rep == "xyz": | ||
raise ValueError("This representation is not possible.") | ||
if getattr(self, "_load_translation") is None: | ||
raise ValueError("Can't extract translations.") | ||
ret_tr = self._load_translation(ind, frame_ix) | ||
ret_tr = to_torch(ret_tr - ret_tr[0]) | ||
|
||
if pose_rep != "xyz": | ||
if getattr(self, "_load_rotvec", None) is None: | ||
raise ValueError("This representation is not possible.") | ||
else: | ||
pose = self._load_rotvec(ind, frame_ix) | ||
if not self.glob: | ||
pose = pose[:, 1:, :] | ||
pose = to_torch(pose) | ||
if self.align_pose_frontview: | ||
first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0]) | ||
all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :]) | ||
aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1), | ||
all_root_poses_matrix) | ||
pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix) | ||
|
||
if self.translation: | ||
ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(), | ||
torch.transpose(ret_tr, 0, 1)) | ||
ret_tr = torch.transpose(ret_tr, 0, 1) | ||
|
||
if pose_rep == "rotvec": | ||
ret = pose | ||
elif pose_rep == "rotmat": | ||
ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9) | ||
elif pose_rep == "rotquat": | ||
ret = geometry.axis_angle_to_quaternion(pose) | ||
elif pose_rep == "rot6d": | ||
ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose)) | ||
if pose_rep != "xyz" and self.translation: | ||
padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype) | ||
padded_tr[:, :3] = ret_tr | ||
ret = torch.cat((ret, padded_tr[:, None]), 1) | ||
ret = ret.permute(1, 2, 0).contiguous() | ||
return ret.float() | ||
|
||
def _get_item_data_index(self, data_index): | ||
nframes = self._num_frames_in_video[data_index] | ||
|
||
if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len): | ||
frame_ix = np.arange(nframes) | ||
else: | ||
if self.num_frames == -2: | ||
if self.min_len <= 0: | ||
raise ValueError("You should put a min_len > 0 for num_frames == -2 mode") | ||
if self.max_len != -1: | ||
max_frame = min(nframes, self.max_len) | ||
else: | ||
max_frame = nframes | ||
|
||
num_frames = random.randint(self.min_len, max(max_frame, self.min_len)) | ||
else: | ||
num_frames = self.num_frames if self.num_frames != -1 else self.max_len | ||
|
||
if num_frames > nframes: | ||
fair = False # True | ||
if fair: | ||
# distills redundancy everywhere | ||
choices = np.random.choice(range(nframes), | ||
num_frames, | ||
replace=True) | ||
frame_ix = sorted(choices) | ||
else: | ||
# adding the last frame until done | ||
ntoadd = max(0, num_frames - nframes) | ||
lastframe = nframes - 1 | ||
padding = lastframe * np.ones(ntoadd, dtype=int) | ||
frame_ix = np.concatenate((np.arange(0, nframes), | ||
padding)) | ||
|
||
elif self.sampling in ["conseq", "random_conseq"]: | ||
step_max = (nframes - 1) // (num_frames - 1) | ||
if self.sampling == "conseq": | ||
if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes: | ||
step = step_max | ||
else: | ||
step = self.sampling_step | ||
elif self.sampling == "random_conseq": | ||
step = random.randint(1, step_max) | ||
|
||
lastone = step * (num_frames - 1) | ||
shift_max = nframes - lastone - 1 | ||
shift = random.randint(0, max(0, shift_max - 1)) | ||
frame_ix = shift + np.arange(0, lastone + 1, step) | ||
|
||
elif self.sampling == "random": | ||
choices = np.random.choice(range(nframes), | ||
num_frames, | ||
replace=False) | ||
frame_ix = sorted(choices) | ||
|
||
else: | ||
raise ValueError("Sampling not recognized.") | ||
|
||
inp, action = self.get_pose_data(data_index, frame_ix) | ||
|
||
|
||
output = {'inp': inp, 'action': action} | ||
|
||
if hasattr(self, '_actions') and hasattr(self, '_action_classes'): | ||
output['action_text'] = self.action_to_action_name(self.get_action(data_index)) | ||
|
||
return output | ||
|
||
|
||
def get_mean_length_label(self, label): | ||
if self.num_frames != -1: | ||
return self.num_frames | ||
|
||
if self.split == 'train': | ||
index = self._train | ||
else: | ||
index = self._test | ||
|
||
action = self.label_to_action(label) | ||
choices = np.argwhere(self._actions[index] == action).squeeze(1) | ||
lengths = self._num_frames_in_video[np.array(index)[choices]] | ||
|
||
if self.max_len == -1: | ||
return np.mean(lengths) | ||
else: | ||
# make the lengths less than max_len | ||
lengths[lengths > self.max_len] = self.max_len | ||
return np.mean(lengths) | ||
|
||
def __len__(self): | ||
num_seq_max = getattr(self, "num_seq_max", -1) | ||
if num_seq_max == -1: | ||
from math import inf | ||
num_seq_max = inf | ||
|
||
if self.split == 'train': | ||
return min(len(self._train), num_seq_max) | ||
else: | ||
return min(len(self._test), num_seq_max) |
Oops, something went wrong.