support a2m sampling
sigal-raab committed Oct 31, 2022
1 parent b0f9721 commit bb8960a
Showing 18 changed files with 771 additions and 95 deletions.
72 changes: 68 additions & 4 deletions README.md
@@ -72,6 +72,9 @@ bash prepare/download_glove.sh

### 2. Get data

<details>
<summary><b>Text to Motion</b></summary>

There are two paths to get the data:

(a) **Go the easy way if** you just want to generate text-to-motion (excluding editing, which requires motion capture data)
@@ -102,11 +105,25 @@ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
```

**KIT** - Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git) (no processing needed this time) and place the result in `./dataset/KIT-ML`
</details>

<details>
<summary><b>Action to Motion</b></summary>

**UESTC, HumanAct12**:
```bash
bash prepare/download_a2m_datasets.sh
```
</details>

### 3. Download the pretrained models

Download the model(s) you wish to use, then unzip and place them in `./save/`.

<details>
<summary><b>Text to Motion</b></summary>

**You need only the first one.**

**HumanML3D**

@@ -120,7 +137,29 @@ Download the model(s) you wish to use, then unzip and place it in `./save/`. **F

[kit-encoder-512](https://drive.google.com/file/d/1SHCRcE0es31vkJMLGf9dyLe7YsWj7pNL/view?usp=sharing)

</details>

<details>
<summary><b>Action to Motion</b></summary>

**UESTC**

[uestc](https://drive.google.com/file/d/1goB2DJK4B-fLu2QmqGWKAqWGMTAO6wQ6/view?usp=sharing)

[uestc_no_fc](https://drive.google.com/file/d/1fpv3mR-qP9CYCsi9CrQhFqlLavcSQky6/view?usp=sharing)

**HumanAct12**

[humanact12](https://drive.google.com/file/d/154X8_Lgpec6Xj0glEGql7FVKqPYCdBFO/view?usp=sharing)

[humanact12_no_fc](https://drive.google.com/file/d/1frKVMBYNiN5Mlq7zsnhDBzs9vGJvFeiQ/view?usp=sharing)

</details>


## Motion Synthesis
<details>
<summary><b>Text to Motion</b></summary>

### Generate from test set prompts

@@ -139,11 +178,35 @@ python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --i
```shell
python -m sample --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox."
```
</details>

<details>
<summary><b>Action to Motion</b></summary>

### Generate from test set actions

```shell
python -m sample --model_path ./save/humanact12/model000350000.pt --num_samples 10 --num_repetitions 3
```

### Generate from your actions file

```shell
python -m sample --model_path ./save/humanact12/model000350000.pt --action_file ./assets/example_action_names_humanact12.txt
```
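
The actions file is a plain-text list with one action name per line; see `./assets/example_action_names_humanact12.txt`, added in this commit:

```
drink
lift_dumbbell
```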

### Generate a single prompt

```shell
python -m sample --model_path ./save/humanact12/model000350000.pt --text_prompt "drink"
```
</details>


**You may also define:**
* `--device` id.
* `--seed` to sample different prompts.
* `--motion_length` (text-to-motion only) in seconds (maximum is 9.8[sec]).

**Running those will get you:**

@@ -191,6 +254,7 @@ python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset hum
```shell
python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
```
</details>

* Use `--device` to define GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
2 changes: 2 additions & 0 deletions assets/example_action_names_humanact12.txt
@@ -0,0 +1,2 @@
drink
lift_dumbbell
7 changes: 7 additions & 0 deletions assets/example_action_names_uestc.txt
@@ -0,0 +1,7 @@
jumping-jack
left-lunging
left-stretching
raising-hand-and-jumping
rotation-clapping
front-raising
pulling-chest-expanders
237 changes: 237 additions & 0 deletions data_loaders/a2m/dataset.py
@@ -0,0 +1,237 @@
import random

import numpy as np
import torch
# from utils.action_label_to_idx import action_label_to_idx
from data_loaders.tensors import collate
from utils.misc import to_torch
import utils.rotation_conversions as geometry

class Dataset(torch.utils.data.Dataset):
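    # Common base for the action-to-motion datasets (e.g. UESTC, HumanAct12).
    # Subclasses are expected to provide the raw data and bookkeeping attributes
    # used below: _load_rotvec / _load_joints3D / _load_translation, _train, _test,
    # _actions, _num_frames_in_video, _action_classes and the label mappings.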
def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train",
pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs):
self.num_frames = num_frames
self.sampling = sampling
self.sampling_step = sampling_step
self.split = split
self.pose_rep = pose_rep
self.translation = translation
self.glob = glob
self.max_len = max_len
self.min_len = min_len
self.num_seq_max = num_seq_max

self.align_pose_frontview = kwargs.get('align_pose_frontview', False)
self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False)
self.only_60_classes = kwargs.get('only_60_classes', False)
self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False)
self.use_only_15_classes = kwargs.get('use_only_15_classes', False)

if self.split not in ["train", "val", "test"]:
raise ValueError(f"{self.split} is not a valid split")

super().__init__()

# to remove shuffling
self._original_train = None
self._original_test = None

def action_to_label(self, action):
return self._action_to_label[action]

def label_to_action(self, label):
import numbers
if isinstance(label, numbers.Integral):
return self._label_to_action[label]
else: # if it is one hot vector
label = np.argmax(label)
return self._label_to_action[label]

def get_pose_data(self, data_index, frame_ix):
pose = self._load(data_index, frame_ix)
label = self.get_label(data_index)
return pose, label

def get_label(self, ind):
action = self.get_action(ind)
return self.action_to_label(action)

def get_action(self, ind):
return self._actions[ind]

def action_to_action_name(self, action):
return self._action_classes[action]

def action_name_to_action(self, action_name):
# self._action_classes is either a list or a dictionary. If it's a dictionary, we 1st convert it to a list
all_action_names = self._action_classes
if isinstance(all_action_names, dict):
all_action_names = list(all_action_names.values())
assert list(self._action_classes.keys()) == list(range(len(all_action_names))) # the keys should be ordered from 0 to num_actions

sorter = np.argsort(all_action_names)
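        # sorter sorts the names; searchsorted finds action_name in that sorted order,
        # and indexing back through sorter recovers the original action id. Works for
        # a single name or an array of names.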
actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)]
return actions

def __getitem__(self, index):
if self.split == 'train':
data_index = self._train[index]
else:
data_index = self._test[index]

# inp, target = self._get_item_data_index(data_index)
# return inp, target
return self._get_item_data_index(data_index)

def _load(self, ind, frame_ix):
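        # Assemble the pose tensor for the requested frames. pose_rep selects the
        # encoding (xyz joints, or rotvec/rotmat/rotquat/rot6d rotations); the result
        # is permuted to [joints (+1 if translation is appended), features, frames].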
pose_rep = self.pose_rep
if pose_rep == "xyz" or self.translation:
if getattr(self, "_load_joints3D", None) is not None:
# Locate the root joint of initial pose at origin
joints3D = self._load_joints3D(ind, frame_ix)
joints3D = joints3D - joints3D[0, 0, :]
ret = to_torch(joints3D)
if self.translation:
ret_tr = ret[:, 0, :]
else:
if pose_rep == "xyz":
raise ValueError("This representation is not possible.")
                if getattr(self, "_load_translation", None) is None:
raise ValueError("Can't extract translations.")
ret_tr = self._load_translation(ind, frame_ix)
ret_tr = to_torch(ret_tr - ret_tr[0])

if pose_rep != "xyz":
if getattr(self, "_load_rotvec", None) is None:
raise ValueError("This representation is not possible.")
else:
pose = self._load_rotvec(ind, frame_ix)
if not self.glob:
pose = pose[:, 1:, :]
pose = to_torch(pose)
if self.align_pose_frontview:
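                    # Rotate every frame's root orientation (and, below, the translation)
                    # by the inverse of the first frame's root rotation, so all sequences
                    # start facing the same frontal direction.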
first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0])
all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :])
aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1),
all_root_poses_matrix)
pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix)

if self.translation:
ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(),
torch.transpose(ret_tr, 0, 1))
ret_tr = torch.transpose(ret_tr, 0, 1)

if pose_rep == "rotvec":
ret = pose
elif pose_rep == "rotmat":
ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9)
elif pose_rep == "rotquat":
ret = geometry.axis_angle_to_quaternion(pose)
elif pose_rep == "rot6d":
ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose))
if pose_rep != "xyz" and self.translation:
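            # Append the global translation as one extra "joint": zero-pad it to the
            # per-joint feature size and concatenate along the joint axis.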
padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype)
padded_tr[:, :3] = ret_tr
ret = torch.cat((ret, padded_tr[:, None]), 1)
ret = ret.permute(1, 2, 0).contiguous()
return ret.float()

def _get_item_data_index(self, data_index):
nframes = self._num_frames_in_video[data_index]
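        # num_frames semantics: -1 uses the whole clip (up to max_len), -2 draws a
        # random length in [min_len, min(nframes, max_len)], any positive value is a
        # fixed target length. Clips shorter than the target are padded by repeating
        # the last frame; longer clips are subsampled with the chosen strategy
        # ("conseq", "random_conseq" or "random").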

if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len):
frame_ix = np.arange(nframes)
else:
if self.num_frames == -2:
if self.min_len <= 0:
raise ValueError("You should put a min_len > 0 for num_frames == -2 mode")
if self.max_len != -1:
max_frame = min(nframes, self.max_len)
else:
max_frame = nframes

num_frames = random.randint(self.min_len, max(max_frame, self.min_len))
else:
num_frames = self.num_frames if self.num_frames != -1 else self.max_len

if num_frames > nframes:
fair = False # True
if fair:
# distills redundancy everywhere
choices = np.random.choice(range(nframes),
num_frames,
replace=True)
frame_ix = sorted(choices)
else:
# adding the last frame until done
ntoadd = max(0, num_frames - nframes)
lastframe = nframes - 1
padding = lastframe * np.ones(ntoadd, dtype=int)
frame_ix = np.concatenate((np.arange(0, nframes),
padding))

elif self.sampling in ["conseq", "random_conseq"]:
step_max = (nframes - 1) // (num_frames - 1)
if self.sampling == "conseq":
if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes:
step = step_max
else:
step = self.sampling_step
elif self.sampling == "random_conseq":
step = random.randint(1, step_max)

lastone = step * (num_frames - 1)
shift_max = nframes - lastone - 1
shift = random.randint(0, max(0, shift_max - 1))
frame_ix = shift + np.arange(0, lastone + 1, step)

elif self.sampling == "random":
choices = np.random.choice(range(nframes),
num_frames,
replace=False)
frame_ix = sorted(choices)

else:
raise ValueError("Sampling not recognized.")

inp, action = self.get_pose_data(data_index, frame_ix)


output = {'inp': inp, 'action': action}

if hasattr(self, '_actions') and hasattr(self, '_action_classes'):
output['action_text'] = self.action_to_action_name(self.get_action(data_index))

return output


def get_mean_length_label(self, label):
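        # Mean clip length (in frames) over the current split for the given label,
        # capped at max_len when one is set.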
if self.num_frames != -1:
return self.num_frames

if self.split == 'train':
index = self._train
else:
index = self._test

action = self.label_to_action(label)
choices = np.argwhere(self._actions[index] == action).squeeze(1)
lengths = self._num_frames_in_video[np.array(index)[choices]]

if self.max_len == -1:
return np.mean(lengths)
else:
# make the lengths less than max_len
lengths[lengths > self.max_len] = self.max_len
return np.mean(lengths)

def __len__(self):
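        # Size of the current split, optionally capped by num_seq_max (-1 = no cap).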
num_seq_max = getattr(self, "num_seq_max", -1)
if num_seq_max == -1:
from math import inf
num_seq_max = inf

if self.split == 'train':
return min(len(self._train), num_seq_max)
else:
return min(len(self._test), num_seq_max)
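
For orientation, here is a minimal sketch of the contract a concrete dataset is expected to fulfil for this base class. The class name `ToyA2MDataset`, the zero-filled poses, and the two action names are hypothetical, for illustration only; the actual UESTC/HumanAct12 loaders are among the other changed files not shown here.

```python
import numpy as np

from data_loaders.a2m.dataset import Dataset


class ToyA2MDataset(Dataset):
    """Hypothetical subclass showing the attributes the base Dataset expects."""

    def __init__(self, split="train", **kwargs):
        super().__init__(split=split, **kwargs)
        # Two toy sequences of axis-angle poses, shaped [nframes, njoints, 3].
        self._pose = [np.zeros((60, 24, 3), dtype=np.float32),
                      np.zeros((80, 24, 3), dtype=np.float32)]
        self._num_frames_in_video = np.array([60, 80])
        self._actions = np.array([0, 1])                  # one action id per sequence
        self._train = [0, 1]                              # indices used when split == "train"
        self._test = [0, 1]                               # indices used otherwise
        self._action_classes = {0: "drink", 1: "lift_dumbbell"}
        self._action_to_label = {0: 0, 1: 1}
        self._label_to_action = {0: 0, 1: 1}

    def _load_rotvec(self, ind, frame_ix):
        # Called by the base class for any rotation-based pose_rep.
        return self._pose[ind][frame_ix]


if __name__ == "__main__":
    ds = ToyA2MDataset(num_frames=60, pose_rep="rot6d", translation=False, glob=True)
    sample = ds[0]
    # inp is [joints, features, frames]; action is the label id; action_text its name.
    print(sample["inp"].shape, sample["action"], sample["action_text"])
```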