From bb8960a51a012fb8ada58d98a1546e57fce25b75 Mon Sep 17 00:00:00 2001 From: Sigal Raab Date: Wed, 12 Oct 2022 11:54:05 +0300 Subject: [PATCH] support a2m sampling --- README.md | 72 ++++++- assets/example_action_names_humanact12.txt | 2 + assets/example_action_names_uestc.txt | 7 + data_loaders/a2m/dataset.py | 237 +++++++++++++++++++++ data_loaders/a2m/humanact12poses.py | 57 +++++ data_loaders/a2m/uestc.py | 227 ++++++++++++++++++++ data_loaders/get_data.py | 4 +- data_loaders/humanml/utils/plot_script.py | 11 +- data_loaders/tensors.py | 6 +- diffusion/gaussian_diffusion.py | 6 +- model/cfg_sampler.py | 6 +- model/mdm.py | 53 ++--- prepare/download_a2m_datasets.sh | 22 ++ sample.py | 93 ++++---- train/train_mdm.py | 2 +- utils/misc.py | 40 ++++ utils/model_util.py | 13 +- utils/parser_util.py | 8 + 18 files changed, 771 insertions(+), 95 deletions(-) create mode 100644 assets/example_action_names_humanact12.txt create mode 100644 assets/example_action_names_uestc.txt create mode 100644 data_loaders/a2m/dataset.py create mode 100644 data_loaders/a2m/humanact12poses.py create mode 100644 data_loaders/a2m/uestc.py create mode 100644 prepare/download_a2m_datasets.sh create mode 100644 utils/misc.py diff --git a/README.md b/README.md index 9ed2608e..985587b3 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,9 @@ bash prepare/download_glove.sh ### 2. Get data +
+**Text to Motion**
+
 There are two paths to get the data:

 (a) **Go the easy way if** you just want to generate text-to-motion (excluding editing, which does require motion capture data)
@@ -102,11 +105,25 @@ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
```

**KIT** - Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git) (no processing needed this time) and place the result in `./dataset/KIT-ML`
+
+ +
+**Action to Motion**
+
+**UESTC, HumanAct12**:
+```bash
+bash prepare/download_a2m_datasets.sh
+```
+
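+After the script finishes, the new loaders expect roughly this layout (a sketch based on the paths read in `data_loaders/a2m/`; the exact archive contents may vary):
+
+```
+dataset/
+├── HumanAct12Poses/
+│   └── humanact12poses.pkl
+└── uestc/
+    ├── info/                  # num_frames_min.txt, names.txt, action_classes.txt
+    └── vibe_cache_refined.pkl
+```
+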
### 3. Download the pretrained models -Download the model(s) you wish to use, then unzip and place it in `./save/`. **For text-to-motion, you need only the first one.** +Download the model(s) you wish to use, then unzip and place them in `./save/`. + +
+**Text to Motion**
+
+**You need only the first one.**

**HumanML3D**

@@ -120,7 +137,29 @@ Download the model(s) you wish to use, then unzip and place it in `./save/`. **F

[kit-encoder-512](https://drive.google.com/file/d/1SHCRcE0es31vkJMLGf9dyLe7YsWj7pNL/view?usp=sharing)

-## Generate text-to-motion
+
+ +
+**Action to Motion**
+
+**UESTC**
+
+[uestc](https://drive.google.com/file/d/1goB2DJK4B-fLu2QmqGWKAqWGMTAO6wQ6/view?usp=sharing)
+
+[uestc_no_fc](https://drive.google.com/file/d/1fpv3mR-qP9CYCsi9CrQhFqlLavcSQky6/view?usp=sharing)
+
+**HumanAct12**
+
+[humanact12](https://drive.google.com/file/d/154X8_Lgpec6Xj0glEGql7FVKqPYCdBFO/view?usp=sharing)
+
+[humanact12_no_fc](https://drive.google.com/file/d/1frKVMBYNiN5Mlq7zsnhDBzs9vGJvFeiQ/view?usp=sharing)
+
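+Each model is an archive to unzip into `./save/` (the archive name below is an assumption; keep the extracted folder name, since the sampling commands reference it), e.g.:
+
+```shell
+unzip humanact12.zip -d ./save/
+# expected afterwards: ./save/humanact12/model000350000.pt
+```
+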
+ + +## Motion Synthesis +
+**Text to Motion**

### Generate from test set prompts

```shell
python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --num_samples 10 --num_repetitions 3
```

### Generate from your text file

```shell
python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --input_text ./assets/example_text_prompts.txt
```

### Generate a single prompt

```shell
python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox."
```
+
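
The optional flags listed further below (`--device`, `--seed`, `--motion_length`) can be added to any of these calls; an illustrative example (values are arbitrary):

```shell
python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox." --motion_length 6 --seed 10
```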
+ +
+**Action to Motion**
+
+### Generate from test set actions
+
+```shell
+python -m sample --model_path ./save/humanact12/model000350000.pt --num_samples 10 --num_repetitions 3
+```
+
+### Generate from your actions file
+
+```shell
+python -m sample --model_path ./save/humanact12/model000350000.pt --action_file ./assets/example_action_names_humanact12.txt
+```
+
+### Generate a single action
+
+```shell
+python -m sample --model_path ./save/humanact12/model000350000.pt --action_name "drink"
+```
+
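+The actions file is plain text with one action name per line, and every name must match one of the dataset's action classes (for HumanAct12, the names in `humanact12_coarse_action_enumerator`; for UESTC, `dataset/uestc/info/action_classes.txt`). A minimal HumanAct12 example:
+
+```
+drink
+lift_dumbbell
+jump
+```
+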
+ -**You can also define:** +**You may also define:** * `--device` id. * `--seed` to sample different prompts. -* `--motion_length` in seconds (maximum is 9.8[sec]). +* `--motion_length` (text-to-motion only) in seconds (maximum is 9.8[sec]). **Running those will get you:** @@ -191,6 +254,7 @@ python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset hum ```shell python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit ``` + * Use `--device` to define GPU id. * Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default). diff --git a/assets/example_action_names_humanact12.txt b/assets/example_action_names_humanact12.txt new file mode 100644 index 00000000..31b0728a --- /dev/null +++ b/assets/example_action_names_humanact12.txt @@ -0,0 +1,2 @@ +drink +lift_dumbbell diff --git a/assets/example_action_names_uestc.txt b/assets/example_action_names_uestc.txt new file mode 100644 index 00000000..a3095abd --- /dev/null +++ b/assets/example_action_names_uestc.txt @@ -0,0 +1,7 @@ +jumping-jack +left-lunging +left-stretching +raising-hand-and-jumping +rotation-clapping +front-raising +pulling-chest-expanders diff --git a/data_loaders/a2m/dataset.py b/data_loaders/a2m/dataset.py new file mode 100644 index 00000000..df5bc6a5 --- /dev/null +++ b/data_loaders/a2m/dataset.py @@ -0,0 +1,237 @@ +import random + +import numpy as np +import torch +# from utils.action_label_to_idx import action_label_to_idx +from data_loaders.tensors import collate +from utils.misc import to_torch +import utils.rotation_conversions as geometry + +class Dataset(torch.utils.data.Dataset): + def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train", + pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs): + self.num_frames = num_frames + self.sampling = sampling + self.sampling_step = sampling_step + self.split = split + self.pose_rep = pose_rep + self.translation = translation + self.glob = glob + self.max_len = max_len + self.min_len = min_len + self.num_seq_max = num_seq_max + + self.align_pose_frontview = kwargs.get('align_pose_frontview', False) + self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False) + self.only_60_classes = kwargs.get('only_60_classes', False) + self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False) + self.use_only_15_classes = kwargs.get('use_only_15_classes', False) + + if self.split not in ["train", "val", "test"]: + raise ValueError(f"{self.split} is not a valid split") + + super().__init__() + + # to remove shuffling + self._original_train = None + self._original_test = None + + def action_to_label(self, action): + return self._action_to_label[action] + + def label_to_action(self, label): + import numbers + if isinstance(label, numbers.Integral): + return self._label_to_action[label] + else: # if it is one hot vector + label = np.argmax(label) + return self._label_to_action[label] + + def get_pose_data(self, data_index, frame_ix): + pose = self._load(data_index, frame_ix) + label = self.get_label(data_index) + return pose, label + + def get_label(self, ind): + action = self.get_action(ind) + return self.action_to_label(action) + + def get_action(self, ind): + return self._actions[ind] + + def action_to_action_name(self, action): + return self._action_classes[action] + + def action_name_to_action(self, action_name): + # self._action_classes is either a list or a dictionary. 
If it's a dictionary, we 1st convert it to a list + all_action_names = self._action_classes + if isinstance(all_action_names, dict): + all_action_names = list(all_action_names.values()) + assert list(self._action_classes.keys()) == list(range(len(all_action_names))) # the keys should be ordered from 0 to num_actions + + sorter = np.argsort(all_action_names) + actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)] + return actions + + def __getitem__(self, index): + if self.split == 'train': + data_index = self._train[index] + else: + data_index = self._test[index] + + # inp, target = self._get_item_data_index(data_index) + # return inp, target + return self._get_item_data_index(data_index) + + def _load(self, ind, frame_ix): + pose_rep = self.pose_rep + if pose_rep == "xyz" or self.translation: + if getattr(self, "_load_joints3D", None) is not None: + # Locate the root joint of initial pose at origin + joints3D = self._load_joints3D(ind, frame_ix) + joints3D = joints3D - joints3D[0, 0, :] + ret = to_torch(joints3D) + if self.translation: + ret_tr = ret[:, 0, :] + else: + if pose_rep == "xyz": + raise ValueError("This representation is not possible.") + if getattr(self, "_load_translation") is None: + raise ValueError("Can't extract translations.") + ret_tr = self._load_translation(ind, frame_ix) + ret_tr = to_torch(ret_tr - ret_tr[0]) + + if pose_rep != "xyz": + if getattr(self, "_load_rotvec", None) is None: + raise ValueError("This representation is not possible.") + else: + pose = self._load_rotvec(ind, frame_ix) + if not self.glob: + pose = pose[:, 1:, :] + pose = to_torch(pose) + if self.align_pose_frontview: + first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0]) + all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :]) + aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1), + all_root_poses_matrix) + pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix) + + if self.translation: + ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(), + torch.transpose(ret_tr, 0, 1)) + ret_tr = torch.transpose(ret_tr, 0, 1) + + if pose_rep == "rotvec": + ret = pose + elif pose_rep == "rotmat": + ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9) + elif pose_rep == "rotquat": + ret = geometry.axis_angle_to_quaternion(pose) + elif pose_rep == "rot6d": + ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose)) + if pose_rep != "xyz" and self.translation: + padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype) + padded_tr[:, :3] = ret_tr + ret = torch.cat((ret, padded_tr[:, None]), 1) + ret = ret.permute(1, 2, 0).contiguous() + return ret.float() + + def _get_item_data_index(self, data_index): + nframes = self._num_frames_in_video[data_index] + + if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len): + frame_ix = np.arange(nframes) + else: + if self.num_frames == -2: + if self.min_len <= 0: + raise ValueError("You should put a min_len > 0 for num_frames == -2 mode") + if self.max_len != -1: + max_frame = min(nframes, self.max_len) + else: + max_frame = nframes + + num_frames = random.randint(self.min_len, max(max_frame, self.min_len)) + else: + num_frames = self.num_frames if self.num_frames != -1 else self.max_len + + if num_frames > nframes: + fair = False # True + if fair: + # distills redundancy everywhere + choices = np.random.choice(range(nframes), + num_frames, + replace=True) + 
frame_ix = sorted(choices) + else: + # adding the last frame until done + ntoadd = max(0, num_frames - nframes) + lastframe = nframes - 1 + padding = lastframe * np.ones(ntoadd, dtype=int) + frame_ix = np.concatenate((np.arange(0, nframes), + padding)) + + elif self.sampling in ["conseq", "random_conseq"]: + step_max = (nframes - 1) // (num_frames - 1) + if self.sampling == "conseq": + if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes: + step = step_max + else: + step = self.sampling_step + elif self.sampling == "random_conseq": + step = random.randint(1, step_max) + + lastone = step * (num_frames - 1) + shift_max = nframes - lastone - 1 + shift = random.randint(0, max(0, shift_max - 1)) + frame_ix = shift + np.arange(0, lastone + 1, step) + + elif self.sampling == "random": + choices = np.random.choice(range(nframes), + num_frames, + replace=False) + frame_ix = sorted(choices) + + else: + raise ValueError("Sampling not recognized.") + + inp, action = self.get_pose_data(data_index, frame_ix) + + + output = {'inp': inp, 'action': action} + + if hasattr(self, '_actions') and hasattr(self, '_action_classes'): + output['action_text'] = self.action_to_action_name(self.get_action(data_index)) + + return output + + + def get_mean_length_label(self, label): + if self.num_frames != -1: + return self.num_frames + + if self.split == 'train': + index = self._train + else: + index = self._test + + action = self.label_to_action(label) + choices = np.argwhere(self._actions[index] == action).squeeze(1) + lengths = self._num_frames_in_video[np.array(index)[choices]] + + if self.max_len == -1: + return np.mean(lengths) + else: + # make the lengths less than max_len + lengths[lengths > self.max_len] = self.max_len + return np.mean(lengths) + + def __len__(self): + num_seq_max = getattr(self, "num_seq_max", -1) + if num_seq_max == -1: + from math import inf + num_seq_max = inf + + if self.split == 'train': + return min(len(self._train), num_seq_max) + else: + return min(len(self._test), num_seq_max) diff --git a/data_loaders/a2m/humanact12poses.py b/data_loaders/a2m/humanact12poses.py new file mode 100644 index 00000000..d9b8894a --- /dev/null +++ b/data_loaders/a2m/humanact12poses.py @@ -0,0 +1,57 @@ +import pickle as pkl +import numpy as np +import os +from .dataset import Dataset + + +class HumanAct12Poses(Dataset): + dataname = "humanact12" + + def __init__(self, datapath="dataset/HumanAct12Poses", split="train", **kargs): + self.datapath = datapath + + super().__init__(**kargs) + + pkldatafilepath = os.path.join(datapath, "humanact12poses.pkl") + data = pkl.load(open(pkldatafilepath, "rb")) + + self._pose = [x for x in data["poses"]] + self._num_frames_in_video = [p.shape[0] for p in self._pose] + self._joints = [x for x in data["joints3D"]] + + self._actions = [x for x in data["y"]] + + total_num_actions = 12 + self.num_actions = total_num_actions + + self._train = list(range(len(self._pose))) + + keep_actions = np.arange(0, total_num_actions) + + self._action_to_label = {x: i for i, x in enumerate(keep_actions)} + self._label_to_action = {i: x for i, x in enumerate(keep_actions)} + + self._action_classes = humanact12_coarse_action_enumerator + + def _load_joints3D(self, ind, frame_ix): + return self._joints[ind][frame_ix] + + def _load_rotvec(self, ind, frame_ix): + pose = self._pose[ind][frame_ix].reshape(-1, 24, 3) + return pose + + +humanact12_coarse_action_enumerator = { + 0: "warm_up", + 1: "walk", + 2: "run", + 3: "jump", + 4: "drink", + 5: "lift_dumbbell", + 6: 
"sit", + 7: "eat", + 8: "turn steering wheel", + 9: "phone", + 10: "boxing", + 11: "throw", +} diff --git a/data_loaders/a2m/uestc.py b/data_loaders/a2m/uestc.py new file mode 100644 index 00000000..3d04f882 --- /dev/null +++ b/data_loaders/a2m/uestc.py @@ -0,0 +1,227 @@ +import os +from tqdm import tqdm +import numpy as np +import pickle as pkl +import utils.rotation_conversions as geometry +import torch + +from .dataset import Dataset +# from torch.utils.data import Dataset + +# from .ntu13 import action2motion_joints +action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] + + +def get_z(cam_s, cam_pos, joints, img_size, flength): + """ + Solves for the depth offset of the model to approx. orth with persp camera. + """ + # Translate the model itself: Solve the best z that maps to orth_proj points + joints_orth_target = (cam_s * (joints[:, :2] + cam_pos) + 1) * 0.5 * img_size + height3d = np.linalg.norm(np.max(joints[:, :2], axis=0) - np.min(joints[:, :2], axis=0)) + height2d = np.linalg.norm(np.max(joints_orth_target, axis=0) - np.min(joints_orth_target, axis=0)) + tz = np.array(flength * (height3d / height2d)) + return float(tz) + + +def get_trans_from_vibe(vibe, index, use_z=True): + alltrans = [] + for t in range(vibe["joints3d"][index].shape[0]): + # Convert crop cam to orig cam + # No need! Because `convert_crop_cam_to_orig_img` from demoutils of vibe + # does this already for us :) + # Its format is: [sx, sy, tx, ty] + cam_orig = vibe["orig_cam"][index][t] + x = cam_orig[2] + y = cam_orig[3] + if use_z: + z = get_z(cam_s=cam_orig[0], # TODO: There are two scales instead of 1. + cam_pos=cam_orig[2:4], + joints=vibe['joints3d'][index][t], + img_size=540, + flength=500) + # z = 500 / (0.5 * 480 * cam_orig[0]) + else: + z = 0 + trans = [x, y, z] + alltrans.append(trans) + alltrans = np.array(alltrans) + return alltrans - alltrans[0] + + +class UESTC(Dataset): + dataname = "uestc" + + def __init__(self, datapath="dataset/uestc", method_name="vibe", view="all", **kargs): + + self.datapath = datapath + self.method_name = method_name + self.view = view + super().__init__(**kargs) + + # Load pre-computed #frames data + with open(os.path.join(datapath, 'info', 'num_frames_min.txt'), 'r') as f: + num_frames_video = np.asarray([int(s) for s in f.read().splitlines()]) + + # Out of 118 subjects -> 51 training, 67 in test + all_subjects = np.arange(1, 119) + self._tr_subjects = [ + 1, 2, 6, 12, 13, 16, 21, 24, 28, 29, 30, 31, 33, 35, 39, 41, 42, 45, 47, 50, + 52, 54, 55, 57, 59, 61, 63, 64, 67, 69, 70, 71, 73, 77, 81, 84, 86, 87, 88, + 90, 91, 93, 96, 99, 102, 103, 104, 107, 108, 112, 113] + self._test_subjects = [s for s in all_subjects if s not in self._tr_subjects] + + # Load names of 25600 videos + with open(os.path.join(datapath, 'info', 'names.txt'), 'r') as f: + videos = f.read().splitlines() + + self._videos = videos + + if self.method_name == "vibe": + vibe_data_path = os.path.join(datapath, "vibe_cache_refined.pkl") + vibe_data = pkl.load(open(vibe_data_path, "rb")) + + self._pose = vibe_data["pose"] + num_frames_method = [p.shape[0] for p in self._pose] + globpath = os.path.join(datapath, "globtrans_usez.pkl") + + if os.path.exists(globpath): + self._globtrans = pkl.load(open(globpath, "rb")) + else: + self._globtrans = [] + for index in tqdm(range(len(self._pose))): + self._globtrans.append(get_trans_from_vibe(vibe_data, index, use_z=True)) + pkl.dump(self._globtrans, open("globtrans_usez.pkl", "wb")) + self._joints = vibe_data["joints3d"] + 
self._jointsIx = action2motion_joints + else: + raise ValueError("This method name is not recognized.") + + num_frames_video = np.minimum(num_frames_video, num_frames_method) + num_frames_video = num_frames_video.astype(int) + self._num_frames_in_video = [x for x in num_frames_video] + + N = len(videos) + self._actions = np.zeros(N, dtype=int) + for ind in range(N): + self._actions[ind] = self.parse_action(videos[ind]) + + self._actions = [x for x in self._actions] + + total_num_actions = 40 + self.num_actions = total_num_actions + keep_actions = np.arange(0, total_num_actions) + + self._action_to_label = {x: i for i, x in enumerate(keep_actions)} + self._label_to_action = {i: x for i, x in enumerate(keep_actions)} + self.num_classes = len(keep_actions) + + self._train = [] + self._test = [] + + self.info_actions = [] + + def get_rotation(view): + theta = - view * np.pi/4 + axis = torch.tensor([0, 1, 0], dtype=torch.float) + axisangle = theta*axis + matrix = geometry.axis_angle_to_matrix(axisangle) + return matrix + + # 0 is identity if needed + rotations = {key: get_rotation(key) for key in [0, 1, 2, 3, 4, 5, 6, 7]} + + for index, video in enumerate(tqdm(videos, desc='Preparing UESTC data..')): + act, view, subject, side = self._get_action_view_subject_side(video) + self.info_actions.append({"action": act, + "view": view, + "subject": subject, + "side": side}) + if self.view == "frontview": + if side != 1: + continue + # rotate to front view + if side != 1: + # don't take the view 8 in side 2 + if view == 8: + continue + rotation = rotations[view] + global_matrix = geometry.axis_angle_to_matrix(torch.from_numpy(self._pose[index][:, :3])) + # rotate the global pose + self._pose[index][:, :3] = geometry.matrix_to_axis_angle(rotation @ global_matrix).numpy() + # rotate the joints + self._joints[index] = self._joints[index] @ rotation.T.numpy() + self._globtrans[index] = (self._globtrans[index] @ rotation.T.numpy()) + + # add the global translation to the joints + self._joints[index] = self._joints[index] + self._globtrans[index][:, None] + + if subject in self._tr_subjects: + self._train.append(index) + elif subject in self._test_subjects: + self._test.append(index) + else: + raise ValueError("This subject doesn't belong to any set.") + + # if index > 200: + # break + + # Select only sequences which have a minimum number of frames + if self.num_frames > 0: + threshold = self.num_frames*3/4 + else: + threshold = 0 + + method_extracted_ix = np.where(num_frames_video >= threshold)[0].tolist() + self._train = list(set(self._train) & set(method_extracted_ix)) + # keep the test set without modification + self._test = list(set(self._test)) + + action_classes_file = os.path.join(datapath, "info/action_classes.txt") + with open(action_classes_file, 'r') as f: + self._action_classes = np.array(f.read().splitlines()) + + # with open(processd_path, 'wb') as file: + # pkl.dump(xxx, file) + + def _load_joints3D(self, ind, frame_ix): + if len(self._joints[ind]) == 0: + raise ValueError( + f"Cannot load index {ind} in _load_joints3D function.") + if self._jointsIx is not None: + joints3D = self._joints[ind][frame_ix][:, self._jointsIx] + else: + joints3D = self._joints[ind][frame_ix] + + return joints3D + + def _load_rotvec(self, ind, frame_ix): + # 72 dim smpl + pose = self._pose[ind][frame_ix, :].reshape(-1, 24, 3) + return pose + + def _get_action_view_subject_side(self, videopath): + # TODO: Can be moved to tools.py + spl = videopath.split('_') + action = int(spl[0][1:]) + view = int(spl[1][1:]) + 
subject = int(spl[2][1:]) + side = int(spl[3][1:]) + return action, view, subject, side + + def _get_videopath(self, action, view, subject, side): + # Unused function + return 'a{:d}_d{:d}_p{:03d}_c{:d}_color.avi'.format( + action, view, subject, side) + + def parse_action(self, path, return_int=True): + # Override parent method + info, _, _, _ = self._get_action_view_subject_side(path) + if return_int: + return int(info) + else: + return info + + +if __name__ == "__main__": + dataset = UESTC() diff --git a/data_loaders/get_data.py b/data_loaders/get_data.py index 270f386b..a2ed501d 100644 --- a/data_loaders/get_data.py +++ b/data_loaders/get_data.py @@ -7,10 +7,10 @@ def get_dataset_class(name): from .amass import AMASS return AMASS elif name == "uestc": - from .uestc import UESTC + from .a2m.uestc import UESTC return UESTC elif name == "humanact12": - from .humanact12poses import HumanAct12Poses + from .a2m.humanact12poses import HumanAct12Poses return HumanAct12Poses elif name == "humanml": from data_loaders.humanml.data.dataset import HumanML3D diff --git a/data_loaders/humanml/utils/plot_script.py b/data_loaders/humanml/utils/plot_script.py index d5349283..dcb16017 100644 --- a/data_loaders/humanml/utils/plot_script.py +++ b/data_loaders/humanml/utils/plot_script.py @@ -24,7 +24,7 @@ def list_cut_average(ll, intervals): return ll_new -def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(3, 3), fps=120, radius=3): +def plot_3d_motion(save_path, kinematic_tree, joints, title, dataset, figsize=(3, 3), fps=120, radius=3): matplotlib.use('Agg') # title_sp = title.split(' ') @@ -58,6 +58,15 @@ def plot_xzPlane(minx, maxx, miny, minz, maxz): # (seq_len, joints_num, 3) data = joints.copy().reshape(len(joints), -1, 3) + + # preparation related to specific datasets + if dataset == 'kit': + data *= 0.003 # scale for visualization + elif dataset == 'humanml': + data *= 1.3 # scale for visualization + elif dataset in ['humanact12', 'uestc']: + data *= -1.5 # reverse axes, scale for visualization + fig = plt.figure(figsize=figsize) plt.tight_layout() ax = p3.Axes3D(fig) diff --git a/data_loaders/tensors.py b/data_loaders/tensors.py index 30dba2bb..1b7ee69c 100644 --- a/data_loaders/tensors.py +++ b/data_loaders/tensors.py @@ -22,7 +22,6 @@ def collate_tensors(batch): def collate(batch): notnone_batches = [b for b in batch if b is not None] databatch = [b['inp'] for b in notnone_batches] - labelbatch = [b['target'] for b in notnone_batches] if 'lengths' in notnone_batches[0]: lenbatch = [b['lengths'] for b in notnone_batches] else: @@ -30,7 +29,6 @@ def collate(batch): databatchTensor = collate_tensors(databatch) - labelbatchTensor = torch.as_tensor(labelbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) # unqueeze for broadcasting @@ -45,6 +43,10 @@ def collate(batch): textbatch = [b['tokens'] for b in notnone_batches] cond['y'].update({'tokens': textbatch}) + if 'action' in notnone_batches[0]: + actionbatch = [b['action'] for b in notnone_batches] + cond['y'].update({'action': torch.as_tensor(actionbatch).unsqueeze(1)}) + # collate action textual names if 'action_text' in notnone_batches[0]: action_text = [b['action_text']for b in notnone_batches] diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index 42e178c5..fbfb3da0 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -130,7 +130,7 @@ def __init__( lambda_pose=1., 
lambda_orient=1., lambda_loc=1., - data_rep='rot', + data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., @@ -1310,7 +1310,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, data terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: - if self.data_rep == 'rot' and dataset.dataname in ['humanact12', 'uestc']: + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) @@ -1319,7 +1319,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, data if self.lambda_fc > 0.: torch.autograd.set_detect_anomaly(True) - if self.data_rep == 'rot' and dataset.dataname in ['humanact12', 'uestc']: + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 diff --git a/model/cfg_sampler.py b/model/cfg_sampler.py index 5161c823..88e12322 100644 --- a/model/cfg_sampler.py +++ b/model/cfg_sampler.py @@ -10,6 +10,9 @@ class ClassifierFreeSampleModel(nn.Module): def __init__(self, model): super().__init__() self.model = model # model is the actual model to run + + assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions' + # pointers to inner model self.rot2xyz = self.model.rot2xyz self.translation = self.model.translation @@ -25,4 +28,5 @@ def forward(self, x, timesteps, y=None): y_uncond['uncond'] = True out = self.model(x, timesteps, y) out_uncond = self.model(x, timesteps, y_uncond) - return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) \ No newline at end of file + return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) + diff --git a/model/mdm.py b/model/mdm.py index b6736fae..4739a8a7 100644 --- a/model/mdm.py +++ b/model/mdm.py @@ -10,7 +10,7 @@ class MDM(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, - ablation=None, activation="gelu", legacy=False, data_rep='rot', dataset='amass', clip_dim=512, + ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs): super().__init__() @@ -87,13 +87,7 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re self.clip_version = clip_version self.clip_model = self.load_and_freeze_clip(clip_version) if 'action' in self.cond_mode: - if self.action_emb == 'scalar': - self.embed_action = EmbedActionScalar(in_features=1, out_features=self.latent_dim, - activation=self.activation) - elif self.action_emb == 'tensor': - self.embed_action = EmbedActionTensor(self.num_actions, self.latent_dim) - else: - raise Exception(f'Unknown action embedding {self.action_emb}.') + self.embed_action = EmbedAction(self.num_actions, self.latent_dim) print('EMBED ACTION') self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints, @@ 
-154,14 +148,13 @@ def forward(self, x, timesteps, y=None): bs, njoints, nfeats, nframes = x.shape emb = self.embed_timestep(timesteps) # [1, bs, d] + force_mask = y.get('uncond', False) if 'text' in self.cond_mode: enc_text = self.encode_text(y['text']) - force_mask = y.get('uncond', False) emb += self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)) if 'action' in self.cond_mode: - if not (y['action'] == -1).any(): # FIXME - a hack so we can use already trained models - action_emb = self.embed_action(y['action']) - emb += self.mask_cond(action_emb) + action_emb = self.embed_action(y['action']) + emb += self.mask_cond(action_emb, force_mask=force_mask) if self.arch == 'gru': x_reshaped = x.reshape(bs, njoints*nfeats, 1, nframes) @@ -197,15 +190,14 @@ def forward(self, x, timesteps, y=None): return output - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.rot2xyz.smpl_model.to(*args, **kwargs) + def _apply(self, fn): + super()._apply(fn) + self.rot2xyz.smpl_model._apply(fn) - def eval(self, *args, **kwargs): - super().eval(*args, **kwargs) - self.rot2xyz.smpl_model.eval(*args, **kwargs) - + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + self.rot2xyz.smpl_model.train(*args, **kwargs) class PositionalEncoding(nn.Module): @@ -259,7 +251,7 @@ def forward(self, x): bs, njoints, nfeats, nframes = x.shape x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) - if self.data_rep in ['rot', 'xyz', 'hml_vec']: + if self.data_rep in ['rot6d', 'xyz', 'hml_vec']: x = self.poseEmbedding(x) # [seqlen, bs, d] return x elif self.data_rep == 'rot_vel': @@ -286,7 +278,7 @@ def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats): def forward(self, output): nframes, bs, d = output.shape - if self.data_rep in ['rot', 'xyz', 'hml_vec']: + if self.data_rep in ['rot6d', 'xyz', 'hml_vec']: output = self.poseFinal(output) # [seqlen, bs, 150] elif self.data_rep == 'rot_vel': first_pose = output[[0]] # [1, bs, d] @@ -301,24 +293,7 @@ def forward(self, output): return output -class EmbedActionScalar(nn.Module): - def __init__(self, in_features, out_features, activation): - super().__init__() - self.in_features = in_features - self.out_features = out_features - mid_features = int(out_features/2) - self.lin1 = nn.Linear(in_features, mid_features) - self.activation = eval(f'F.{activation}') - self.lin2 = nn.Linear(mid_features, out_features) - - def forward(self, input): - output = self.lin1(input) - output = self.activation(output) - output = self.lin2(output) - return output - - -class EmbedActionTensor(nn.Module): +class EmbedAction(nn.Module): def __init__(self, num_actions, latent_dim): super().__init__() self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim)) diff --git a/prepare/download_a2m_datasets.sh b/prepare/download_a2m_datasets.sh new file mode 100644 index 00000000..fdfc7a8e --- /dev/null +++ b/prepare/download_a2m_datasets.sh @@ -0,0 +1,22 @@ +mkdir -p dataset/ +cd dataset/ + +echo "The datasets will be stored in the 'dataset' folder\n" + +# HumanAct12 poses +echo "Downloading the HumanAct12 poses dataset" +gdown "https://drive.google.com/uc?id=1130gHSvNyJmii7f6pv5aY5IyQIWc3t7R" +echo "Extracting the HumanAct12 poses dataset" +tar xfzv HumanAct12Poses.tar.gz +echo "Cleaning\n" +rm HumanAct12Poses.tar.gz + +# Donwload UESTC poses estimated with VIBE +echo "Downloading the UESTC poses estimated with VIBE" +gdown "https://drive.google.com/uc?id=1LE-EmYNzECU8o7A2DmqDKtqDMucnSJsy" +echo "Extracting the UESTC poses 
estimated with VIBE" +tar xjvf uestc.tar.bz2 +echo "Cleaning\n" +rm uestc.tar.bz2 + +echo "Downloading done!" diff --git a/sample.py b/sample.py index 45e14893..02b16fdd 100644 --- a/sample.py +++ b/sample.py @@ -28,7 +28,7 @@ def main(): max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60 fps = 12.5 if args.dataset == 'kit' else 20 n_frames = min(max_frames, int(args.motion_length*fps)) - is_using_data = args.input_text == '' and args.text_prompt == '' + is_using_data = not any([args.input_text, args.text_prompt, args.action_file, args.action_name]) dist_util.setup_dist(args.device) if out_path == '': out_path = os.path.join(os.path.dirname(args.model_path), @@ -38,23 +38,7 @@ def main(): elif args.input_text != '': out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '') - print("Creating model and diffusion...") - model, diffusion = create_model_and_diffusion(args) - - print(f"Loading checkpoints from [{args.model_path}]...") - state_dict = torch.load(args.model_path, map_location='cpu') - load_model_wo_clip(model, state_dict) - - if args.guidance_param != 1: - model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler - model.to(dist_util.dev()) - model.eval() # disable random masking - - - - - print('Loading dataset...') - texts = [] + # this block must be called BEFORE the dataset is loaded if args.text_prompt != '': texts = [args.text_prompt] args.num_samples = 1 @@ -64,6 +48,15 @@ def main(): texts = fr.readlines() texts = [s.replace('\n', '') for s in texts] args.num_samples = len(texts) + elif args.action_name: + action_text = [args.action_name] + args.num_samples = 1 + elif args.action_file != '': + assert os.path.exists(args.action_file) + with open(args.action_file, 'r') as fr: + action_text = fr.readlines() + action_text = [s.replace('\n', '') for s in action_text] + args.num_samples = len(action_text) assert args.num_samples <= args.batch_size, \ f'Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})' @@ -72,30 +65,45 @@ def main(): # If it doesn't, and you still want to sample more prompts, run this script with different seeds # (specify through the --seed flag) args.batch_size = args.num_samples # Sampling a single batch from the testset, with exactly args.num_samples - data = get_dataset_loader(name=args.dataset, - batch_size=args.batch_size, - num_frames=max_frames, - split='test', - hml_mode='text_only') - data.fixed_length = n_frames + + print('Loading dataset...') + data = load_dataset(args, max_frames, n_frames) total_num_samples = args.num_samples * args.num_repetitions + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + if args.guidance_param != 1: + model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + model.to(dist_util.dev()) + model.eval() # disable random masking if is_using_data: iterator = iter(data) _, model_kwargs = next(iterator) else: - _, model_kwargs = collate( - [{'inp': torch.tensor([[0.]]), 'target': 0, 'text': txt, 'tokens': None, 'lengths': n_frames} - for txt in texts] - ) + collate_args = [{'inp': torch.zeros(n_frames), 'target': 0, 'tokens': None, 'lengths': n_frames}] * args.num_samples + is_t2m = any([args.input_text, args.text_prompt]) + if is_t2m: + # 
t2m + collate_args = [dict(arg, text=txt) for arg, txt in zip(collate_args, texts)] + else: + # a2m + action = data.dataset.action_name_to_action(action_text) + collate_args = [dict(arg, action=one_action, action_text=one_action_text) for + arg, one_action, one_action_text in zip(collate_args, action, action_text)] + _, model_kwargs = collate(collate_args) all_motions = [] all_lengths = [] all_text = [] for rep_i in range(args.num_repetitions): - print(f'### Start sampling [repetitions #{rep_i}]') + print(f'### Sampling [repetitions #{rep_i}]') # add CFG scale to batch if args.guidance_param != 1: @@ -116,7 +124,6 @@ def main(): const_noise=False, ) - # Recover XYZ *positions* from HumanML3D vector representation if model.data_rep == 'hml_vec': n_joints = 22 if sample.shape[1] == 263 else 21 @@ -124,7 +131,15 @@ def main(): sample = recover_from_ric(sample, n_joints) sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) - all_text += model_kwargs['y']['text'] + rot2xyz_pose_rep = 'xyz' if model.data_rep in ['xyz', 'hml_vec'] else model.data_rep + rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool() + sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True, + jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None, + get_rotations_back=False) + + text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' + all_text += model_kwargs['y'][text_key] + all_motions.append(sample.cpu().numpy()) all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) @@ -164,11 +179,7 @@ def main(): save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i) animation_save_path = os.path.join(out_path, save_file) print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]') - if args.dataset == 'kit': - motion *= 0.003 # scale for visualization - elif args.dataset == 'humanml': - motion *= 1.3 # scale for visualization - plot_3d_motion(animation_save_path, skeleton, motion, title=caption, fps=fps) + plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps) # Credit for visualization: https://github.com/EricGuo5513/text-to-motion rep_files.append(animation_save_path) @@ -194,5 +205,15 @@ def main(): print(f'[Done] Results are at [{abs_path}]') +def load_dataset(args, max_frames, n_frames): + data = get_dataset_loader(name=args.dataset, + batch_size=args.batch_size, + num_frames=max_frames, + split='test', + hml_mode='text_only') + data.fixed_length = n_frames + return data + + if __name__ == "__main__": main() diff --git a/train/train_mdm.py b/train/train_mdm.py index 3374cada..adeabe7f 100644 --- a/train/train_mdm.py +++ b/train/train_mdm.py @@ -36,7 +36,7 @@ def main(): data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames) print("creating model and diffusion...") - model, diffusion = create_model_and_diffusion(args) + model, diffusion = create_model_and_diffusion(args, data) model.to(dist_util.dev()) model.rot2xyz.smpl_model.eval() diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 00000000..abe0cdc3 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,40 @@ +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array".format( + type(tensor))) + return tensor + + +def to_torch(ndarray): + if 
type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format( + type(ndarray))) + return ndarray + + +def cleanexit(): + import sys + import os + try: + sys.exit(0) + except SystemExit: + os._exit(0) + +def load_model_wo_clip(model, state_dict): + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + +def freeze_joints(x, joints_to_freeze): + # Freezes selected joint *rotations* as they appear in the first frame + # x [bs, [root+n_joints], joint_dim(6), seqlen] + frozen = x.detach().clone() + frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] + return frozen diff --git a/utils/model_util.py b/utils/model_util.py index 1e508196..f4ef3516 100644 --- a/utils/model_util.py +++ b/utils/model_util.py @@ -9,19 +9,22 @@ def load_model_wo_clip(model, state_dict): assert all([k.startswith('clip_model.') for k in missing_keys]) -def create_model_and_diffusion(args): - model = MDM(**get_model_args(args)) +def create_model_and_diffusion(args, data): + model = MDM(**get_model_args(args, data)) diffusion = create_gaussian_diffusion(args) return model, diffusion -def get_model_args(args): +def get_model_args(args, data): # default args clip_version = 'ViT-B/32' action_emb = 'tensor' cond_mode = 'text' if args.dataset in ['kit', 'humanml'] else 'action' - num_actions = 1 + if hasattr(data.dataset, 'num_actions'): + num_actions = data.dataset.num_actions + else: + num_actions = 1 # SMPL defaults data_rep = 'rot6d' @@ -37,8 +40,6 @@ def get_model_args(args): njoints = 251 nfeats = 1 - # TODO - set num_actions for action-to-motion datasets - return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions, 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True, 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4, diff --git a/utils/parser_util.py b/utils/parser_util.py index dc7a6bc3..0e31fd61 100644 --- a/utils/parser_util.py +++ b/utils/parser_util.py @@ -26,6 +26,8 @@ def parse_and_load_from_model(parser): args.__dict__[a] = model_args[a] else: print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a])) + if args.cond_mask_prob == 0: + args.guidance_param = 1 return args @@ -135,8 +137,14 @@ def add_sampling_options(parser): "If empty, will create dir in parallel to checkpoint.") group.add_argument("--input_text", default='', type=str, help="Path to csv/txt file that specifies generation. If empty, will take text prompts from dataset.") + group.add_argument("--action_file", default='', type=str, + help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, " + "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. " + "If empty, will take text prompts from dataset.") group.add_argument("--text_prompt", default='', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.") + group.add_argument("--action_name", default='', type=str, + help="An action name to be generated. 
If empty, will take action names from the dataset.")
    group.add_argument("--num_samples", default=10, type=int,
                       help="Maximal number of prompts to sample, "
                            "if loading dataset from file, this field will be ignored.")