support a2m training and evaluation
sigal-raab committed Oct 31, 2022
1 parent bb8960a commit ce26c8e
Showing 26 changed files with 1,474 additions and 22 deletions.
31 changes: 30 additions & 1 deletion README.md
@@ -24,6 +24,8 @@ If you find this code useful in your research, please cite:

## News

📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks.

📢 **9/Oct/22** - Added training and evaluation scripts.
Note the slight environment changes needed to adapt to the new code. If you already have an installed environment, run `bash prepare/download_glove.sh; pip install clearml` to adapt it.

@@ -32,7 +34,6 @@ If you find this code useful in your research, please cite:
## ETAs

* Editing: Nov 22
* Action to Motion: Nov 22
* Unconstrained Motion: Nov 22


@@ -245,6 +246,9 @@ python -m visualize.render_mesh --input_path /path/to/mp4/stick/figure/file

## Train your own MDM

<details>
<summary><b>Text to Motion</b></summary>

**HumanML3D**
```shell
python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset humanml
@@ -255,6 +259,13 @@ python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset hum
python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
```
</details>
<details>
<summary><b>Action to Motion</b></summary>

```shell
python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1
```
</details>

* Use `--device` to define GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
@@ -263,6 +274,10 @@ python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
This will slow down training but will give you better monitoring.
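
For example, an action-to-motion training run that combines the flags above might look like this (the save directory and GPU id are illustrative):
```shell
python -m train.train_mdm --save_dir save/my_humanact12 --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --device 0 --arch trans_enc
```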

## Evaluate

<details>
<summary><b>Text to Motion</b></summary>

* Takes about 20 hours (on a single GPU)
* The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.

@@ -275,6 +290,20 @@ python -m eval.eval_humanml --model_path ./save/humanml_trans_enc_512/model00047
```shell
python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000.pt
```
</details>

<details>
<summary><b>Action to Motion</b></summary>

* Takes about 7 hours for UESTC and 2 hours for HumanAct12 (on a single GPU)
* The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.

```shell
--model <path-to-model-ckpt> --eval_mode full
```
where `<path-to-model-ckpt>` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user.
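
For instance, assuming the action-to-motion evaluation entry point added in this commit is exposed as `eval.eval_humanact12_uestc`, a concrete call might look like this (the checkpoint path is illustrative):
```shell
python -m eval.eval_humanact12_uestc --model ./save/humanact12/model000350000.pt --eval_mode full
```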

</details>


## Acknowledgments
18 changes: 18 additions & 0 deletions data_loaders/a2m/dataset.py
@@ -235,3 +235,21 @@ def __len__(self):
return min(len(self._train), num_seq_max)
else:
return min(len(self._test), num_seq_max)

def shuffle(self):
if self.split == 'train':
random.shuffle(self._train)
else:
random.shuffle(self._test)

def reset_shuffle(self):
if self.split == 'train':
if self._original_train is None:
self._original_train = self._train
else:
self._train = self._original_train
else:
if self._original_test is None:
self._original_test = self._test
else:
self._test = self._original_test
1 change: 0 additions & 1 deletion data_loaders/a2m/uestc.py
@@ -8,7 +8,6 @@
from .dataset import Dataset
# from torch.utils.data import Dataset

# from .ntu13 import action2motion_joints
action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]


2 changes: 1 addition & 1 deletion data_loaders/get_data.py
@@ -45,7 +45,7 @@ def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='tr
collate = get_collate_fn(name, hml_mode)

loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, #(split == 'train'),
dataset, batch_size=batch_size, shuffle=True,
num_workers=8, drop_last=True, collate_fn=collate
)

1 change: 0 additions & 1 deletion data_loaders/tensors.py
@@ -59,7 +59,6 @@ def t2m_collate(batch):
# batch.sort(key=lambda x: x[3], reverse=True)
adapted_batch = [{
'inp': torch.tensor(b[4].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen]
'target': 0,
'text': b[2], #b[0]['caption']
'tokens': b[6],
'lengths': b[5],
Empty file added eval/a2m/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions eval/a2m/action2motion/accuracy.py
@@ -0,0 +1,14 @@
import torch


def calculate_accuracy(model, motion_loader, num_labels, classifier, device):
confusion = torch.zeros(num_labels, num_labels, dtype=torch.long)
with torch.no_grad():
for batch in motion_loader:
batch_prob = classifier(batch["output_xyz"], lengths=batch["lengths"])
batch_pred = batch_prob.max(dim=1).indices
for label, pred in zip(batch["y"], batch_pred):
confusion[label][pred] += 1

accuracy = torch.trace(confusion)/torch.sum(confusion)
return accuracy.item(), confusion
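
For reference, a minimal sketch (not part of this commit) of driving `calculate_accuracy` in isolation with a dummy classifier and an in-memory loader. Everything below except the import path is made up for illustration, and `model`/`device` are passed as `None` because this helper does not use them:
```python
import torch
from eval.a2m.action2motion.accuracy import calculate_accuracy

NUM_LABELS = 12

class DummyClassifier:
    """Mimics the classifier interface used above: (motions, lengths=...) -> logits."""
    def __call__(self, motions, lengths=None):
        return torch.randn(motions.shape[0], NUM_LABELS)

# Two fake batches; the tensor layout of "output_xyz" is irrelevant to the dummy classifier.
loader = [{"output_xyz": torch.randn(8, 24, 3, 60),
           "lengths": torch.full((8,), 60, dtype=torch.long),
           "y": torch.randint(0, NUM_LABELS, (8,))} for _ in range(2)]

acc, confusion = calculate_accuracy(model=None, motion_loader=loader,
                                    num_labels=NUM_LABELS,
                                    classifier=DummyClassifier(), device=None)
print(acc, confusion.sum().item())  # accuracy around 1/12 for random logits; 16 motions counted
```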
51 changes: 51 additions & 0 deletions eval/a2m/action2motion/diversity.py
@@ -0,0 +1,51 @@
import torch
import numpy as np


# from action2motion
def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False):
diversity_times = 200
multimodality_times = 20
if not unconstrained:
labels = labels.long()
num_motions = activations.shape[0] # len(labels)

diversity = 0

first_indices = np.random.randint(0, num_motions, diversity_times)
second_indices = np.random.randint(0, num_motions, diversity_times)
for first_idx, second_idx in zip(first_indices, second_indices):
diversity += torch.dist(activations[first_idx, :],
activations[second_idx, :])
diversity /= diversity_times

if not unconstrained:
multimodality = 0
label_quotas = np.zeros(num_labels)
label_quotas[labels.unique()] = multimodality_times # if a label does not appear in batch, its quota remains zero
while np.any(label_quotas > 0):
# print(label_quotas)
first_idx = np.random.randint(0, num_motions)
first_label = labels[first_idx]
if not label_quotas[first_label]:
continue

second_idx = np.random.randint(0, num_motions)
second_label = labels[second_idx]
while first_label != second_label:
second_idx = np.random.randint(0, num_motions)
second_label = labels[second_idx]

label_quotas[first_label] -= 1

first_activation = activations[first_idx, :]
second_activation = activations[second_idx, :]
multimodality += torch.dist(first_activation,
second_activation)

multimodality /= (multimodality_times * num_labels)
else:
multimodality = torch.tensor(np.nan)

return diversity.item(), multimodality.item()
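
Similarly, a quick sketch (not part of this commit) of feeding random activations through `calculate_diversity_multimodality`; the sample count and feature dimension are arbitrary:
```python
import torch
from eval.a2m.action2motion.diversity import calculate_diversity_multimodality

activations = torch.randn(300, 64)        # 300 motion features, 64-dim each (made-up sizes)
labels = torch.randint(0, 12, (300,))     # one of 12 action labels per motion

diversity, multimodality = calculate_diversity_multimodality(activations, labels, num_labels=12)
print(diversity, multimodality)
```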

84 changes: 84 additions & 0 deletions eval/a2m/action2motion/evaluate.py
@@ -0,0 +1,84 @@
import torch
import numpy as np
from .models import load_classifier, load_classifier_for_fid
from .accuracy import calculate_accuracy
from .fid import calculate_fid
from .diversity import calculate_diversity_multimodality


class A2MEvaluation:
def __init__(self, device):
dataset_opt = {"input_size_raw": 72, "joints_num": 24, "num_classes": 12}

self.input_size_raw = dataset_opt["input_size_raw"]
self.num_classes = dataset_opt["num_classes"]
self.device = device

self.gru_classifier_for_fid = load_classifier_for_fid(self.input_size_raw, self.num_classes, device).eval()
self.gru_classifier = load_classifier(self.input_size_raw, self.num_classes, device).eval()

def compute_features(self, model, motionloader):
# calculate_activations_labels function from action2motion
activations = []
labels = []
with torch.no_grad():
for idx, batch in enumerate(motionloader):
activations.append(self.gru_classifier_for_fid(batch["output_xyz"], lengths=batch["lengths"]))
if model.cond_mode != 'no_cond':
labels.append(batch["y"])
activations = torch.cat(activations, dim=0)
if model.cond_mode != 'no_cond':
labels = torch.cat(labels, dim=0)
return activations, labels

@staticmethod
def calculate_activation_statistics(activations):
activations = activations.cpu().numpy()
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return mu, sigma

def evaluate(self, model, loaders):

def print_logs(metric, key):
print(f"Computing action2motion {metric} on the {key} loader ...")

metrics = {}

computedfeats = {}
for key, loader in loaders.items():
metric = "accuracy"
print_logs(metric, key)
mkey = f"{metric}_{key}"
if model.cond_mode != 'no_cond':
metrics[mkey], _ = calculate_accuracy(model, loader,
self.num_classes,
self.gru_classifier, self.device)
else:
metrics[mkey] = np.nan

# features for diversity
print_logs("features", key)
feats, labels = self.compute_features(model, loader)
print_logs("stats", key)
stats = self.calculate_activation_statistics(feats)

computedfeats[key] = {"feats": feats,
"labels": labels,
"stats": stats}

print_logs("diversity", key)
ret = calculate_diversity_multimodality(feats, labels, self.num_classes, unconstrained=(model.cond_mode=='no_cond'))
metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret

        # use the ground-truth stats as the FID reference (compared against every loader, including gt itself)
gtstats = computedfeats["gt"]["stats"]
# computing fid
for key, loader in computedfeats.items():
metric = "fid"
mkey = f"{metric}_{key}"

stats = computedfeats[key]["stats"]
metrics[mkey] = float(calculate_fid(gtstats, stats))

return metrics
61 changes: 61 additions & 0 deletions eval/a2m/action2motion/fid.py
@@ -0,0 +1,61 @@
import numpy as np
from scipy import linalg


# from action2motion
def calculate_fid(statistics_1, statistics_2):
return calculate_frechet_distance(statistics_1[0], statistics_1[1],
statistics_2[0], statistics_2[1])


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
-- mu1 : Numpy array containing the activations of a layer of the
inception net (like returned by the function 'get_predictions')
for generated samples.
-- mu2 : The sample mean over activations, precalculated on an
representative data set.
-- sigma1: The covariance matrix over activations for generated samples.
-- sigma2: The covariance matrix over activations, precalculated on an
representative data set.
Returns:
-- : The Frechet Distance.
"""

mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)

sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)

assert mu1.shape == mu2.shape, \
'Training and test mean vectors have different lengths'
assert sigma1.shape == sigma2.shape, \
'Training and test covariances have different dimensions'

diff = mu1 - mu2

# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real

tr_covmean = np.trace(covmean)

return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
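
Finally, a short sketch (not part of this commit) of computing FID between two synthetic activation sets, using the same per-dimension mean and full covariance that `A2MEvaluation.calculate_activation_statistics` produces (assumes the repo root is on `PYTHONPATH`):
```python
import numpy as np
from eval.a2m.action2motion.fid import calculate_fid

def activation_stats(feats):
    # (mu, sigma) in the form calculate_fid expects
    return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

gt_feats = np.random.randn(500, 30)          # stand-in for ground-truth classifier features
gen_feats = np.random.randn(500, 30) + 0.25  # slightly shifted "generated" features

print(calculate_fid(activation_stats(gt_feats), activation_stats(gen_feats)))
```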
