diff --git a/README.md b/README.md
index 329f2ff3..9b6fe304 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,10 @@ If you find this code useful in your research, please cite:
## News
+📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks.
+ Note the slight environment changes required by the new code. If you already have an installed environment,
+ run `bash prepare/download_unconstrained_assets.sh; conda install -y -c anaconda scikit-learn` to adapt it.
+
📢 **3/Nov/22** - Added in-between and upper-body editing.
📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks.
@@ -33,10 +37,6 @@ If you find this code useful in your research, please cite:
📢 **6/Oct/22** - First release - sampling and rendering using pre-trained models.
-## ETAs
-
-* Unconstrained Motion: Nov 22
-
## Getting started
@@ -76,15 +76,23 @@ bash prepare/download_glove.sh
- Text to Motion, Unconstrained
+ Action to Motion
```bash
bash prepare/download_smpl_files.sh
-bash prepare/download_a2m_datasets.sh
bash prepare/download_recognition_models.sh
```
+
+ Unconstrained
+
+```bash
+bash prepare/download_smpl_files.sh
+bash prepare/download_recognition_unconstrained_models.sh
+```
+
+
### 2. Get data
@@ -125,12 +133,21 @@ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
Action to Motion
-**UESTC, HumanAct12** :
+**UESTC, HumanAct12**
```bash
bash prepare/download_a2m_datasets.sh
```
+
+ Unconstrained
+
+**HumanAct12**
+```bash
+bash prepare/download_unconstrained_datasets.sh
+```
+
+
### 3. Download the pretrained models
Download the model(s) you wish to use, then unzip and place them in `./save/`.
@@ -171,6 +188,15 @@ Download the model(s) you wish to use, then unzip and place them in `./save/`.
+
+ Unconstrained
+
+**HumanAct12**
+
+[humanact12_unconstrained](https://drive.google.com/file/d/1uG68m200pZK3pD-zTmPXu5XkgNpx_mEx/view?usp=share_link)
+
+
+
## Motion Synthesis
@@ -217,6 +243,16 @@ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --tex
```
+
+ Unconstrained
+
+```shell
+python -m sample.generate --model_path ./save/unconstrained/model000450000.pt --num_samples 10 --num_repetitions 3
+```
+
+Since the model is unconditioned, the distinction between samples and repetitions is only nominal: (num_samples * num_repetitions) independent motions are generated and visually arranged in a grid of num_samples rows and num_repetitions columns (a 10x3 grid of 30 motions for the command above).
+
+
**You may also define:**
* `--device` id.
@@ -317,6 +353,14 @@ python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} -
```
+
+ Unconstrained
+
+```shell
+python -m train.train_mdm --save_dir save/my_name --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --unconstrained
+```
+
+
* Use `--device` to define GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
* Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard).
@@ -349,19 +393,32 @@ python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000
* The output of this script for the pre-trained models (as was reported in the paper) is provided in the checkpoints zip file.
```shell
---model <path-to-model-ckpt> --eval_mode full
+python -m eval.eval_humanact12_uestc --model <path-to-model-ckpt> --eval_mode full
```
where `path-to-model-ckpt` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user.
+
+ Unconstrained
+
+* Takes about 3 hours (on a single GPU)
+
+```shell
+python -m eval.eval_humanact12_uestc --model ./save/unconstrained/model000450000.pt --eval_mode full
+```
+
+Precision and recall are skipped by default to save computation time. If you wish to compute them, edit `eval/a2m/gru_eval.py` and change `fast=True` to `fast=False`.
+
+
+
## Acknowledgments
This code is standing on the shoulders of giants. We want to thank the following contributors
that our code is based on:
-[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl).
+[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl), [MoDi](https://github.com/sigal-raab/MoDi).
## License
This code is distributed under an [MIT LICENSE](LICENSE).
diff --git a/environment.yml b/environment.yml
index 11092210..dceaae7a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,6 +1,7 @@
name: mdm
channels:
- pytorch
+ - anaconda
- conda-forge
- defaults
dependencies:
@@ -9,9 +10,9 @@ dependencies:
- beautifulsoup4=4.11.1=pyha770c72_0
- blas=1.0=mkl
- brotlipy=0.7.0=py37h540881e_1004
- - ca-certificates=2022.9.24=ha878542_0
+ - ca-certificates=2022.07.19=h06a4308_0
- catalogue=2.0.8=py37h89c1867_0
- - certifi=2022.9.24=pyhd8ed1ab_0
+ - certifi=2022.6.15=py37h06a4308_0
- cffi=1.15.1=py37h74dc2b5_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- colorama=0.4.5=pyhd8ed1ab_0
@@ -37,6 +38,7 @@ dependencies:
- idna=3.4=pyhd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- jinja2=3.1.2=pyhd8ed1ab_1
+ - joblib=1.1.0=pyhd3eb1b0_0
- jpeg=9b=h024ee3a_2
- kiwisolver=1.4.2=py37h295c915_0
- langcodes=3.3.0=pyhd8ed1ab_0
@@ -87,6 +89,7 @@ dependencies:
- qt=5.9.7=h5867ecd_1
- readline=8.1.2=h7f8727e_1
- requests=2.28.1=pyhd8ed1ab_1
+ - scikit-learn=1.0.2=py37h51133e4_1
- scipy=1.7.3=py37h6c91a56_2
- setuptools=63.4.1=py37h06a4308_0
- shellingham=1.5.0=pyhd8ed1ab_0
@@ -98,6 +101,7 @@ dependencies:
- spacy-legacy=3.0.10=pyhd8ed1ab_0
- spacy-loggers=1.0.3=pyhd8ed1ab_0
- sqlite=3.39.3=h5082296_0
+ - threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.7.2=py37
- torchvision=0.8.2=py37_cu110
@@ -113,10 +117,8 @@ dependencies:
- pip:
- blis==0.7.8
- chumpy==0.70
- - clearml==1.7.1
- click==8.1.3
- confection==0.0.2
- - filelock==3.8.0
- ftfy==6.1.1
- importlib-metadata==5.0.0
- lxml==4.9.1
diff --git a/eval/a2m/action2motion/diversity.py b/eval/a2m/action2motion/diversity.py
index 5ebe8a45..c2011080 100644
--- a/eval/a2m/action2motion/diversity.py
+++ b/eval/a2m/action2motion/diversity.py
@@ -2,6 +2,21 @@
import numpy as np
+#adapted from action2motion
+def calculate_diversity(activations):
+ diversity_times = 200
+ num_motions = len(activations)
+
+ diversity = 0
+
+ first_indices = np.random.randint(0, num_motions, diversity_times)
+ second_indices = np.random.randint(0, num_motions, diversity_times)
+ for first_idx, second_idx in zip(first_indices, second_indices):
+ diversity += torch.dist(activations[first_idx, :],
+ activations[second_idx, :])
+ diversity /= diversity_times
+ return diversity
+
# from action2motion
def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False):
diversity_times = 200
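For context, a minimal usage sketch of the new `calculate_diversity` helper (illustrative only; the 1000x256 feature matrix is a hypothetical stand-in for the STGCN activations computed in `eval/unconstrained/evaluate.py`):

```python
import torch
from eval.a2m.action2motion.diversity import calculate_diversity

# Hypothetical per-motion feature matrix [num_motions, feature_dim];
# in this patch these are STGCN activations of generated or dataset motions.
features = torch.randn(1000, 256)

# Mean pairwise L2 distance over 200 randomly sampled pairs of motions.
diversity = calculate_diversity(features)
print(float(diversity))
```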
diff --git a/eval/a2m/gru_eval.py b/eval/a2m/gru_eval.py
index 1595f0b9..32267399 100644
--- a/eval/a2m/gru_eval.py
+++ b/eval/a2m/gru_eval.py
@@ -1,5 +1,7 @@
import copy
import os
+
+import numpy as np
from tqdm import tqdm
import torch
import functools
@@ -8,13 +10,14 @@
from utils.fixseed import fixseed
from data_loaders.tensors import collate
from eval.a2m.action2motion.evaluate import A2MEvaluation
+from eval.unconstrained.evaluate import evaluate_unconstrained_metrics
from .tools import save_metrics, format_metrics
-from data_loaders.get_data import get_dataset
from utils import dist_util
+num_samples_unconstrained = 1000
class NewDataloader:
- def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, dataset, num_samples: int=-1):
+ def __init__(self, mode, model, diffusion, dataiterator, device, unconstrained, num_samples: int=-1):
assert mode in ["gen", "gt"]
self.batches = []
sample_fn = diffusion.p_sample_loop
@@ -37,7 +40,7 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data
translation=True, jointstype='smpl', vertstrans=True, betas=None,
beta=0, glob_rot=None, get_rotations_back=False)
batch["lengths"] = model_kwargs['y']['lengths'].to(device)
- if cond_mode != 'no_cond': # proceed only if not running unconstrained
+        if not unconstrained: # proceed only if not running unconstrained
batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices
self.batches.append(batch)
@@ -49,7 +52,6 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data
def __iter__(self):
return iter(self.batches)
-
def evaluate(args, model, diffusion, data):
num_frames = 60
@@ -87,25 +89,43 @@ def evaluate(args, model, diffusion, data):
shuffle=False, num_workers=8, collate_fn=collate)
new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device,
- cond_mode=args.cond_mode, dataset=args.dataset,
- num_samples=args.num_samples)
+ unconstrained=args.unconstrained, num_samples=args.num_samples)
motionloader = new_data_loader(mode="gen", dataiterator=dataiterator)
gt_motionloader = new_data_loader("gt", dataiterator=dataiterator)
gt_motionloader2 = new_data_loader("gt", dataiterator=dataiterator2)
# Action2motionEvaluation
loaders = {"gen": motionloader,
- # "recons": reconstructedloader,
"gt": gt_motionloader,
"gt2": gt_motionloader2}
a2mmetrics[seed] = a2mevaluation.evaluate(model, loaders)
+
del loaders
+ if args.unconstrained: # unconstrained
+ dataset_unconstrained = copy.deepcopy(data)
+ dataset_unconstrained.reset_shuffle()
+ dataset_unconstrained.shuffle()
+ dataiterator_unconstrained = DataLoader(dataset_unconstrained, batch_size=args.batch_size,
+ shuffle=False, num_workers=8, collate_fn=collate)
+ motionloader_unconstrained = new_data_loader(mode="gen", dataiterator=dataiterator_unconstrained, num_samples=num_samples_unconstrained)
+
+ generated_motions = []
+ for motion in motionloader_unconstrained:
+ idx = [15, 12, 16, 18, 20, 17, 19, 21, 0, 1, 4, 7, 2, 5, 8]
+ motion = motion['output_xyz'][:, idx, :, :]
+ generated_motions.append(motion.cpu().numpy())
+ generated_motions = np.concatenate(generated_motions)
+ unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True)
+ unconstrained_metrics = {k+'_unconstrained': v for k, v in unconstrained_metrics.items()}
+
except KeyboardInterrupt:
string = "Saving the evaluation before exiting.."
print(string)
metrics = {"feats": {key: [format_metrics(a2mmetrics[seed])[key] for seed in a2mmetrics.keys()] for key in a2mmetrics[allseeds[0]]}}
+ if args.unconstrained:
+ metrics["feats"] = {**metrics["feats"], **unconstrained_metrics}
return metrics
diff --git a/eval/eval_humanact12_uestc.py b/eval/eval_humanact12_uestc.py
index d77370d1..a6fa2b21 100644
--- a/eval/eval_humanact12_uestc.py
+++ b/eval/eval_humanact12_uestc.py
@@ -61,8 +61,6 @@ def main():
else:
args.num_samples = 1000
args.num_seeds = 20
- args.cond_mode = 'action' # temporary code till 'unconstrained' is implemented
-
data_loader = get_dataset_loader(name=args.dataset, num_frames=60, batch_size=args.batch_size,)
diff --git a/eval/unconstrained/evaluate.py b/eval/unconstrained/evaluate.py
new file mode 100644
index 00000000..0559628e
--- /dev/null
+++ b/eval/unconstrained/evaluate.py
@@ -0,0 +1,111 @@
+from eval.unconstrained.models.stgcn import STGCN
+import pandas as pd
+import os.path as osp
+import os
+import datetime
+
+import torch
+
+from torch.utils.data import DataLoader
+import numpy as np
+import sys as _sys
+from eval.a2m.action2motion.fid import calculate_fid
+from eval.a2m.action2motion.diversity import calculate_diversity
+from eval.unconstrained.metrics.kid import calculate_kid
+from eval.unconstrained.metrics.precision_recall import precision_and_recall
+from matplotlib import pyplot as plt
+
+TEST = False
+
+
+def initialize_model(device, modelpath):
+ num_classes = 12
+ model = STGCN(in_channels=3,
+ num_class=num_classes,
+ graph_args={"layout": 'openpose', "strategy": "spatial"},
+ edge_importance_weighting=True,
+ device=device)
+ model = model.to(device)
+ state_dict = torch.load(modelpath, map_location=device)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+def calculate_activation_statistics(activations):
+ activations = activations.cpu().detach().numpy()
+ mu = np.mean(activations, axis=0)
+ sigma = np.cov(activations, rowvar=False)
+ return mu, sigma
+
+
+def compute_features(model, iterator, device):
+ activations = []
+ predictions = []
+ with torch.no_grad():
+ for i, batch in enumerate(iterator):
+ batch_for_model = {}
+ batch_for_model['x'] = batch.to(device).float()
+ model(batch_for_model)
+ activations.append(batch_for_model['features'])
+ predictions.append(batch_for_model['yhat'])
+ # labels.append(batch_for_model['y'])
+ activations = torch.cat(activations, dim=0)
+ predictions = torch.cat(predictions, dim=0)
+ return activations, predictions
+
+
+def evaluate_unconstrained_metrics(generated_motions, device, fast):
+
+ act_rec_model_path = './assets/actionrecognition/humanact12_gru_modi_struct.pth.tar'
+ dataset_path = './dataset/HumanAct12Poses/humanact12_modi_struct.npy'
+
+ # initialize model
+ act_rec_model = initialize_model(device, act_rec_model_path)
+
+ generated_motions -= generated_motions[:, 8:9, :, :] # locate root joint of all frames at origin
+
+ iterator_generated = DataLoader(generated_motions, batch_size=64, shuffle=False, num_workers=8)
+
+ # compute features of generated motions
+ generated_features, generated_predictions = compute_features(act_rec_model, iterator_generated, device=device)
+ generated_stats = calculate_activation_statistics(generated_features)
+
+
+ # dataset motions
+ motion_data_raw = np.load(dataset_path, allow_pickle=True)
+    motion_data = motion_data_raw[:, :15]  # data has 16 joints for backward compatibility with older formats
+ motion_data -= motion_data[:, 8:9, :, :] # locate root joint of all frames at origin
+ iterator_dataset = DataLoader(motion_data, batch_size=64, shuffle=False, num_workers=8)
+
+ # compute features of dataset motions
+ dataset_features, dataset_predictions = compute_features(act_rec_model, iterator_dataset, device=device)
+ real_stats = calculate_activation_statistics(dataset_features)
+
+ print("evaluation resutls:\n")
+
+ fid = calculate_fid(generated_stats, real_stats)
+ print(f"FID score: {fid}\n")
+
+ print("calculating KID...")
+ kid = calculate_kid(dataset_features.cpu(), generated_features.cpu())
+ (m, s) = kid
+ print('KID : %.3f (%.3f)\n' % (m, s))
+
+ dataset_diversity = calculate_diversity(dataset_features)
+ generated_diversity = calculate_diversity(generated_features)
+ print(f"Diversity of generated motions: {generated_diversity}")
+ print(f"Diversity of dataset motions: {dataset_diversity}\n")
+
+ if fast:
+ print("Skipping precision-recall calculation\n")
+ precision = recall = None
+ else:
+ print("calculating precision recall...")
+ precision, recall = precision_and_recall(generated_features, dataset_features)
+ print(f"precision: {precision}")
+ print(f"recall: {recall}\n")
+
+ metrics = {'fid': fid, 'kid': kid[0], 'diversity_gen': generated_diversity.cpu().item(), 'diversity_gt': dataset_diversity.cpu().item(),
+ 'precision': precision, 'recall':recall}
+ return metrics
+
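A minimal usage sketch of `evaluate_unconstrained_metrics` (illustrative only: the random array is a hypothetical stand-in for generated motions, which the caller in `eval/a2m/gru_eval.py` supplies as xyz joint positions shaped `[num_motions, 15, 3, num_frames]`; the hard-coded recognition-model and dataset paths above must exist):

```python
import numpy as np
import torch
from eval.unconstrained.evaluate import evaluate_unconstrained_metrics

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for generated motions: [num_motions, 15 joints, xyz, frames]
generated = np.random.randn(1000, 15, 3, 60).astype(np.float32)

# fast=True skips the (slow) precision/recall computation, as noted in the README.
metrics = evaluate_unconstrained_metrics(generated, device, fast=True)
print(metrics)  # fid, kid, diversity_gen, diversity_gt, precision=None, recall=None
```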
diff --git a/eval/unconstrained/metrics/kid.py b/eval/unconstrained/metrics/kid.py
new file mode 100644
index 00000000..f56c63f3
--- /dev/null
+++ b/eval/unconstrained/metrics/kid.py
@@ -0,0 +1,136 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from sklearn.metrics.pairwise import polynomial_kernel
+import sys
+
+# from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py
+def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
+ ret_var=True, output=sys.stdout, **kernel_args):
+ m = min(codes_g.shape[0], codes_r.shape[0])
+ mmds = np.zeros(n_subsets)
+ if ret_var:
+ vars = np.zeros(n_subsets)
+ choice = np.random.choice
+
+ replace = subset_size < len(codes_g)
+ with tqdm(range(n_subsets), desc='MMD', file=output, disable=True) as bar:
+ for i in bar:
+ g = codes_g[choice(len(codes_g), subset_size, replace=replace)]
+ r = codes_r[choice(len(codes_r), subset_size, replace=replace)]
+ o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
+ if ret_var:
+ mmds[i], vars[i] = o
+ else:
+ mmds[i] = o
+ bar.set_postfix({'mean': mmds[:i+1].mean()})
+ return (mmds, vars) if ret_var else mmds
+
+def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
+ var_at_m=None, ret_var=True):
+ # use k(x, y) = (gamma + coef0)^degree
+ # default gamma is 1 / dim
+ X = codes_g
+ Y = codes_r
+
+ K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
+ K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
+ K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
+
+ return _mmd2_and_variance(K_XX, K_XY, K_YY,
+ var_at_m=var_at_m, ret_var=ret_var)
+
+def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
+ mmd_est='unbiased', block_size=1024,
+ var_at_m=None, ret_var=True):
+ # based on
+ # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
+ # but changed to not compute the full kernel matrix at once
+ m = K_XX.shape[0]
+ assert K_XX.shape == (m, m)
+ assert K_XY.shape == (m, m)
+ assert K_YY.shape == (m, m)
+ if var_at_m is None:
+ var_at_m = m
+
+ # Get the various sums of kernels that we'll use
+ # Kts drop the diagonal, but we don't need to compute them explicitly
+ if unit_diagonal:
+ diag_X = diag_Y = 1
+ sum_diag_X = sum_diag_Y = m
+ sum_diag2_X = sum_diag2_Y = m
+ else:
+ diag_X = np.diagonal(K_XX)
+ diag_Y = np.diagonal(K_YY)
+
+ sum_diag_X = diag_X.sum()
+ sum_diag_Y = diag_Y.sum()
+
+ sum_diag2_X = _sqn(diag_X)
+ sum_diag2_Y = _sqn(diag_Y)
+
+ Kt_XX_sums = K_XX.sum(axis=1) - diag_X
+ Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
+ K_XY_sums_0 = K_XY.sum(axis=0)
+ K_XY_sums_1 = K_XY.sum(axis=1)
+
+ Kt_XX_sum = Kt_XX_sums.sum()
+ Kt_YY_sum = Kt_YY_sums.sum()
+ K_XY_sum = K_XY_sums_0.sum()
+
+ if mmd_est == 'biased':
+ mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
+ + (Kt_YY_sum + sum_diag_Y) / (m * m)
+ - 2 * K_XY_sum / (m * m))
+ else:
+ assert mmd_est in {'unbiased', 'u-statistic'}
+ mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
+ if mmd_est == 'unbiased':
+ mmd2 -= 2 * K_XY_sum / (m * m)
+ else:
+ mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))
+
+ if not ret_var:
+ return mmd2
+
+ Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
+ Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
+ K_XY_2_sum = _sqn(K_XY)
+
+ dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
+ dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)
+
+ m1 = m - 1
+ m2 = m - 2
+ zeta1_est = (
+ 1 / (m * m1 * m2) * (
+ _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
+ - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
+ + 1 / (m * m * m1) * (
+ _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
+ - 2 / m**4 * K_XY_sum**2
+ - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+ + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+ )
+ zeta2_est = (
+ 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
+ - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
+ + 2 / (m * m) * K_XY_2_sum
+ - 2 / m**4 * K_XY_sum**2
+ - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+ + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+ )
+ var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
+ + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)
+
+ return mmd2, var_est
+
+
+def _sqn(arr):
+ flat = np.ravel(arr)
+ return flat.dot(flat)
+
+def calculate_kid(real_activations, generated_activations):
+ kid_values = polynomial_mmd_averages(real_activations, generated_activations, n_subsets=100)
+ results = (kid_values[0].mean(), kid_values[0].std())
+ return results
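For reference (not part of the patch): with the defaults of `polynomial_mmd` (degree $d=3$, $\gamma = 1/\mathrm{dim}$, $c_0 = 1$), `_mmd2_and_variance` in the `'unbiased'` setting computes

$$
k(x, y) = \left(\gamma\, x^\top y + c_0\right)^{d}, \qquad
\widehat{\mathrm{MMD}}^2_u = \frac{1}{m(m-1)} \sum_{i \neq j} k(x_i, x_j)
 + \frac{1}{m(m-1)} \sum_{i \neq j} k(y_i, y_j)
 - \frac{2}{m^2} \sum_{i,j} k(x_i, y_j),
$$

and `calculate_kid` reports the mean and standard deviation of this estimate over 100 random subsets.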
diff --git a/eval/unconstrained/metrics/precision_recall.py b/eval/unconstrained/metrics/precision_recall.py
new file mode 100644
index 00000000..2edd4a91
--- /dev/null
+++ b/eval/unconstrained/metrics/precision_recall.py
@@ -0,0 +1,55 @@
+# based on https://github.com/blandocs/improved-precision-and-recall-metric-pytorch/blob/master/functions.py
+import os, torch
+import numpy as np
+import torch.nn as nn
+import torch.optim as optim
+from tqdm import tqdm
+
+# self.batch_size = args.batch_size
+# self.cpu = args.cpu
+# self.data_size = args.data_size
+
+def precision_and_recall(generated_features, real_features):
+ k = 3
+
+ data_num = min(len(generated_features), len(real_features))
+ print(f'data num: {data_num}')
+
+ if data_num <= 0:
+ print("there is no data")
+ return
+ generated_features = generated_features[:data_num]
+ real_features = real_features[:data_num]
+
+ # get precision and recall
+ precision = manifold_estimate(real_features, generated_features, k)
+ recall = manifold_estimate(generated_features, real_features, k)
+
+ return precision, recall
+
+def manifold_estimate(A_features, B_features, k):
+ A_features = list(A_features)
+ B_features = list(B_features)
+ KNN_list_in_A = {}
+ for A in tqdm(A_features, ncols=80):
+ pairwise_distances = np.zeros(shape=(len(A_features)))
+
+ for i, A_prime in enumerate(A_features):
+ d = torch.norm((A - A_prime), 2)
+ pairwise_distances[i] = d
+
+ v = np.partition(pairwise_distances, k)[k]
+ KNN_list_in_A[A] = v
+
+ n = 0
+
+ for B in tqdm(B_features, ncols=80):
+ for A_prime in A_features:
+ d = torch.norm((B - A_prime), 2)
+ if d <= KNN_list_in_A[A_prime]:
+ n += 1
+ break
+
+ return n / len(B_features)
+
+
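A minimal sketch of how the new precision/recall helper is intended to be called (illustrative only; the random matrices are hypothetical stand-ins for the STGCN features of real and generated motions). `manifold_estimate(A, B, k)` returns the fraction of `B` that falls inside the k-NN (k = 3) neighborhoods of `A`, so precision treats the real features as the manifold and recall swaps the roles:

```python
import torch
from eval.unconstrained.metrics.precision_recall import precision_and_recall

real_features = torch.randn(200, 256)       # hypothetical real-motion features
generated_features = torch.randn(200, 256)  # hypothetical generated-motion features

# precision: generated samples covered by the real manifold;
# recall: real samples covered by the generated manifold.
precision, recall = precision_and_recall(generated_features, real_features)
print(precision, recall)
```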
diff --git a/eval/unconstrained/models/stgcn.py b/eval/unconstrained/models/stgcn.py
new file mode 100644
index 00000000..641d8148
--- /dev/null
+++ b/eval/unconstrained/models/stgcn.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from eval.a2m.recognition.models.stgcnutils.tgcn import ConvTemporalGraphical
+from eval.unconstrained.models.stgcnutils.graph import Graph
+
+__all__ = ["STGCN"]
+
+
+class STGCN(nn.Module):
+ r"""Spatial temporal graph convolutional networks.
+ Args:
+ in_channels (int): Number of channels in the input data
+ num_class (int): Number of classes for the classification task
+ graph_args (dict): The arguments for building the graph
+ edge_importance_weighting (bool): If ``True``, adds a learnable
+ importance weighting to the edges of the graph
+ **kwargs (optional): Other parameters for graph convolution units
+ Shape:
+ - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
+ - Output: :math:`(N, num_class)` where
+ :math:`N` is a batch size,
+ :math:`T_{in}` is a length of input sequence,
+ :math:`V_{in}` is the number of graph nodes,
+            :math:`M_{in}` is the number of instances in a frame.
+ """
+
+ def __init__(self, in_channels, num_class, graph_args,
+ edge_importance_weighting, device, **kwargs):
+ super().__init__()
+
+ self.device = device
+ self.num_class = num_class
+
+ self.losses = ["accuracy", "cross_entropy", "mixed"]
+ self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+
+ # load graph
+ self.graph = Graph(**graph_args)
+ A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
+ self.register_buffer('A', A)
+
+ # build networks
+ spatial_kernel_size = A.size(0)
+ temporal_kernel_size = 9
+ kernel_size = (temporal_kernel_size, spatial_kernel_size)
+ self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
+ # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1), track_running_stats=False)
+ # self.data_bn = nn.InstanceNorm1d(in_channels * A.size(1))
+ kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
+ self.st_gcn_networks = nn.ModuleList((
+ st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
+ st_gcn(64, 64, kernel_size, 1, **kwargs),
+ st_gcn(64, 64, kernel_size, 1, **kwargs),
+ # st_gcn(64, 64, kernel_size, 1, **kwargs),
+ st_gcn(64, 128, kernel_size, 2, **kwargs),
+ st_gcn(128, 128, kernel_size, 1, **kwargs),
+ # st_gcn(128, 128, kernel_size, 1, **kwargs),
+ st_gcn(128, 256, kernel_size, 2, **kwargs),
+ # st_gcn(256, 256, kernel_size, 1, **kwargs),
+ # st_gcn(256, 256, kernel_size, 1, **kwargs),
+ ))
+
+ # initialize parameters for edge importance weighting
+ if edge_importance_weighting:
+ self.edge_importance = nn.ParameterList([
+ nn.Parameter(torch.ones(self.A.size()))
+ for i in self.st_gcn_networks
+ ])
+ else:
+ self.edge_importance = [1] * len(self.st_gcn_networks)
+
+ # fcn for prediction
+ self.fcn = nn.Conv2d(256, num_class, kernel_size=1)
+
+ def forward(self, batch):
+ # TODO: use mask
+ # Received batch["x"] as
+        # Batch(48), Joints(23), Quat(4), Time(157)
+ # Expecting:
+ # Batch, Quat:4, Time, Joints, 1
+ x = batch["x"].permute(0, 2, 3, 1).unsqueeze(4).contiguous()
+
+ # data normalization
+ N, C, T, V, M = x.size()
+ x = x.permute(0, 4, 3, 1, 2).contiguous()
+ x = x.view(N * M, V * C, T)
+ x = self.data_bn(x)
+ x = x.view(N, M, V, C, T)
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
+ x = x.view(N * M, C, T, V)
+
+ # forward
+ for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
+ x, _ = gcn(x, self.A * importance)
+
+ # compute feature
+ # _, c, t, v = x.size()
+ # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)
+ # batch["features"] = features
+
+ # global pooling
+ x = F.avg_pool2d(x, x.size()[2:])
+ x = x.view(N, M, -1, 1, 1).mean(dim=1)
+
+ # features
+ batch["features"] = x.squeeze()
+
+ # prediction
+ x = self.fcn(x)
+ x = x.view(x.size(0), -1)
+ batch["yhat"] = x
+ return batch
+
+ def compute_accuracy(self, batch):
+ confusion = torch.zeros(self.num_class, self.num_class, dtype=int)
+ yhat = batch["yhat"].max(dim=1).indices
+ ygt = batch["y"]
+ for label, pred in zip(ygt, yhat):
+ confusion[label][pred] += 1
+ accuracy = torch.trace(confusion)/torch.sum(confusion)
+ return accuracy
+
+ def compute_loss(self, batch):
+ cross_entropy = self.criterion(batch["yhat"], batch["y"])
+ mixed_loss = cross_entropy
+
+ acc = self.compute_accuracy(batch)
+ losses = {"cross_entropy": cross_entropy.item(),
+ "mixed": mixed_loss.item(),
+ "accuracy": acc.item()}
+ return mixed_loss, losses
+
+
+class st_gcn(nn.Module):
+ r"""Applies a spatial temporal graph convolution over an input graph sequence.
+ Args:
+ in_channels (int): Number of channels in the input sequence data
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
+ stride (int, optional): Stride of the temporal convolution. Default: 1
+ dropout (int, optional): Dropout rate of the final output. Default: 0
+ residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
+ Shape:
+ - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
+ - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
+        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
+ - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
+ where
+ :math:`N` is a batch size,
+ :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
+ :math:`T_{in}/T_{out}` is a length of input/output sequence,
+ :math:`V` is the number of graph nodes.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ dropout=0,
+ residual=True):
+ super().__init__()
+
+ assert len(kernel_size) == 2
+ assert kernel_size[0] % 2 == 1
+ padding = ((kernel_size[0] - 1) // 2, 0)
+
+ self.gcn = ConvTemporalGraphical(in_channels, out_channels,
+ kernel_size[1])
+
+ self.tcn = nn.Sequential(
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ (kernel_size[0], 1),
+ (stride, 1),
+ padding,
+ ),
+ nn.BatchNorm2d(out_channels),
+ nn.Dropout(dropout, inplace=True),
+ )
+
+ if not residual:
+ self.residual = lambda x: 0
+
+ elif (in_channels == out_channels) and (stride == 1):
+ self.residual = lambda x: x
+
+ else:
+ self.residual = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=(stride, 1)),
+ nn.BatchNorm2d(out_channels),
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x, A):
+
+ res = self.residual(x)
+ x, A = self.gcn(x, A)
+ x = self.tcn(x) + res
+
+ return self.relu(x), A
+
+
+if __name__ == "__main__":
+ model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True, graph_args={"layout": "smpl_noglobal", "strategy": "spatial"})
+ # Batch, in_channels, time, vertices, M
+ inp = torch.rand(10, 3, 16, 23, 1)
+ out = model(inp)
+ print(out.shape)
+ import pdb
+ pdb.set_trace()
diff --git a/eval/unconstrained/models/stgcnutils/graph.py b/eval/unconstrained/models/stgcnutils/graph.py
new file mode 100644
index 00000000..8ad28036
--- /dev/null
+++ b/eval/unconstrained/models/stgcnutils/graph.py
@@ -0,0 +1,185 @@
+import numpy as np
+import pickle as pkl
+
+from utils.config import SMPL_KINTREE_PATH
+
+
+class Graph:
+ """ The Graph to model the skeletons extracted by the openpose
+ Args:
+ strategy (string): must be one of the follow candidates
+ - uniform: Uniform Labeling
+ - distance: Distance Partitioning
+ - spatial: Spatial Configuration
+ For more information, please refer to the section 'Partition Strategies'
+ in our paper (https://arxiv.org/abs/1801.07455).
+ layout (string): must be one of the follow candidates
+        - openpose: Consists of 18 joints. For more information, please
+            refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
+        - ntu-rgb+d: Consists of 25 joints. For more information, please
+            refer to https://github.com/shahroudy/NTURGB-D
+        - smpl: Consists of 24/23 joints, with or without the global rotation.
+ max_hop (int): the maximal distance between two connected nodes
+ dilation (int): controls the spacing between the kernel points
+ """
+
+ def __init__(self,
+ layout='openpose',
+ strategy='uniform',
+ kintree_path=SMPL_KINTREE_PATH,
+ max_hop=1,
+ dilation=1):
+ self.max_hop = max_hop
+ self.dilation = dilation
+
+ self.kintree_path = kintree_path
+
+ self.get_edge(layout)
+ self.hop_dis = get_hop_distance(
+ self.num_node, self.edge, max_hop=max_hop)
+ self.get_adjacency(strategy)
+
+ def __str__(self):
+ return self.A
+
+ def get_edge(self, layout):
+ if layout == 'openpose':
+ # self.num_node = 18
+ self.num_node = 15
+ self_link = [(i, i) for i in range(self.num_node)]
+ # neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,
+ # 11),
+ # (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),
+ # (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]
+ neighbor_link = [(4, 3), (3, 2), (2, 1),
+ (7, 6), (6, 5), (5, 1),
+ (1, 0),
+ (14, 13), (13, 12), (12, 8),
+ (11, 10), (10, 9), (9, 8),
+ (8, 1),]
+ self.edge = self_link + neighbor_link
+ self.center = 1
+ elif layout == 'smpl':
+ self.num_node = 24
+ self_link = [(i, i) for i in range(self.num_node)]
+ kt = pkl.load(open(self.kintree_path, "rb"))
+ neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])]
+ self.edge = self_link + neighbor_link
+ self.center = 0
+ elif layout == 'smpl_noglobal':
+ self.num_node = 23
+ self_link = [(i, i) for i in range(self.num_node)]
+ kt = pkl.load(open(self.kintree_path, "rb"))
+ neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])]
+ # remove the root joint
+ neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0]
+ neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+ self.edge = self_link + neighbor_link
+ self.center = 0
+ elif layout == 'ntu-rgb+d':
+ self.num_node = 25
+ self_link = [(i, i) for i in range(self.num_node)]
+ neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),
+ (6, 5), (7, 6), (8, 7), (9, 21), (10, 9),
+ (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
+ (16, 15), (17, 1), (18, 17), (19, 18), (20, 19),
+ (22, 23), (23, 8), (24, 25), (25, 12)]
+ neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+ self.edge = self_link + neighbor_link
+ self.center = 21 - 1
+ elif layout == 'ntu_edge':
+ self.num_node = 24
+ self_link = [(i, i) for i in range(self.num_node)]
+ neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),
+ (8, 7), (9, 2), (10, 9), (11, 10), (12, 11),
+ (13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
+ (18, 17), (19, 18), (20, 19), (21, 22), (22, 8),
+ (23, 24), (24, 12)]
+ neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]
+ self.edge = self_link + neighbor_link
+ self.center = 2
+ # elif layout=='customer settings'
+ # pass
+ else:
+ raise NotImplementedError("This Layout is not supported")
+
+ def get_adjacency(self, strategy):
+ valid_hop = range(0, self.max_hop + 1, self.dilation)
+ adjacency = np.zeros((self.num_node, self.num_node))
+ for hop in valid_hop:
+ adjacency[self.hop_dis == hop] = 1
+ normalize_adjacency = normalize_digraph(adjacency)
+
+ if strategy == 'uniform':
+ A = np.zeros((1, self.num_node, self.num_node))
+ A[0] = normalize_adjacency
+ self.A = A
+ elif strategy == 'distance':
+ A = np.zeros((len(valid_hop), self.num_node, self.num_node))
+ for i, hop in enumerate(valid_hop):
+ A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop]
+ self.A = A
+ elif strategy == 'spatial':
+ A = []
+ for hop in valid_hop:
+ a_root = np.zeros((self.num_node, self.num_node))
+ a_close = np.zeros((self.num_node, self.num_node))
+ a_further = np.zeros((self.num_node, self.num_node))
+ for i in range(self.num_node):
+ for j in range(self.num_node):
+ if self.hop_dis[j, i] == hop:
+ if self.hop_dis[j, self.center] == self.hop_dis[
+ i, self.center]:
+ a_root[j, i] = normalize_adjacency[j, i]
+ elif self.hop_dis[j, self.
+ center] > self.hop_dis[i, self.
+ center]:
+ a_close[j, i] = normalize_adjacency[j, i]
+ else:
+ a_further[j, i] = normalize_adjacency[j, i]
+ if hop == 0:
+ A.append(a_root)
+ else:
+ A.append(a_root + a_close)
+ A.append(a_further)
+ A = np.stack(A)
+ self.A = A
+ else:
+ raise NotImplementedError("This Strategy is not supported")
+
+
+def get_hop_distance(num_node, edge, max_hop=1):
+ A = np.zeros((num_node, num_node))
+ for i, j in edge:
+ A[j, i] = 1
+ A[i, j] = 1
+
+ # compute hop steps
+ hop_dis = np.zeros((num_node, num_node)) + np.inf
+ transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
+ arrive_mat = (np.stack(transfer_mat) > 0)
+ for d in range(max_hop, -1, -1):
+ hop_dis[arrive_mat[d]] = d
+ return hop_dis
+
+
+def normalize_digraph(A):
+ Dl = np.sum(A, 0)
+ num_node = A.shape[0]
+ Dn = np.zeros((num_node, num_node))
+ for i in range(num_node):
+ if Dl[i] > 0:
+ Dn[i, i] = Dl[i]**(-1)
+ AD = np.dot(A, Dn)
+ return AD
+
+
+def normalize_undigraph(A):
+ Dl = np.sum(A, 0)
+ num_node = A.shape[0]
+ Dn = np.zeros((num_node, num_node))
+ for i in range(num_node):
+ if Dl[i] > 0:
+ Dn[i, i] = Dl[i]**(-0.5)
+ DAD = np.dot(np.dot(Dn, A), Dn)
+ return DAD
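For reference (not part of the patch), the two helpers at the bottom of `graph.py` are the standard ST-GCN adjacency normalizations, with degree matrix $D_{ii} = \sum_j A_{ji}$:

$$
\mathrm{normalize\_digraph}(A) = A D^{-1}, \qquad
\mathrm{normalize\_undigraph}(A) = D^{-1/2} A D^{-1/2}.
$$

Only `normalize_digraph` is used by `get_adjacency` above.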
diff --git a/model/mdm.py b/model/mdm.py
index 4739a8a7..14fd5bda 100644
--- a/model/mdm.py
+++ b/model/mdm.py
@@ -143,8 +143,6 @@ def forward(self, x, timesteps, y=None):
x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
timesteps: [batch_size] (int)
"""
- assert (y is not None) == (self.cond_mode != 'no_cond'
- ), "must specify y if and only if the model is class-conditional"
bs, njoints, nfeats, nframes = x.shape
emb = self.embed_timestep(timesteps) # [1, bs, d]
diff --git a/prepare/download_recognition_unconstrained_models.sh b/prepare/download_recognition_unconstrained_models.sh
new file mode 100644
index 00000000..487471d4
--- /dev/null
+++ b/prepare/download_recognition_unconstrained_models.sh
@@ -0,0 +1,8 @@
+mkdir -p assets/actionrecognition/
+cd assets/actionrecognition/
+
+echo -e "Downloading the HumanAct12 action recognition model, adjusted for the unconstrained setting."
+gdown "1xfigimkPxKt3a8zvn_ME_NAR6CyTqneK"
+echo -e
+
+echo -e "Downloading done!"
diff --git a/prepare/download_unconstrained_datasets.sh b/prepare/download_unconstrained_datasets.sh
new file mode 100644
index 00000000..8abf7e92
--- /dev/null
+++ b/prepare/download_unconstrained_datasets.sh
@@ -0,0 +1,10 @@
+mkdir -p dataset/HumanAct12Poses
+cd dataset/HumanAct12Poses
+
+echo "The datasets will be stored in the 'dataset' folder\n"
+
+# HumanAct12 poses unconstrained
+echo "Downloading the HumanAct12 unconstrained poses dataset"
+gdown "1KqOBTtLFgkvWSZb8ao-wdBMG7sTP3Q7d"
+
+echo "Downloading done!"
diff --git a/sample/generate.py b/sample/generate.py
index 3c93dc7f..9176d5bb 100644
--- a/sample/generate.py
+++ b/sample/generate.py
@@ -137,8 +137,11 @@ def main():
jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
get_rotations_back=False)
- text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text'
- all_text += model_kwargs['y'][text_key]
+ if args.unconstrained:
+ all_text += ['unconstrained'] * args.num_samples
+ else:
+ text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text'
+ all_text += model_kwargs['y'][text_key]
all_motions.append(sample.cpu().numpy())
all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy())
@@ -170,41 +173,75 @@ def main():
sample_files = []
num_samples_in_out_file = 7
+
+ sample_print_template, row_print_template, all_print_template, \
+ sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained)
+
for sample_i in range(args.num_samples):
rep_files = []
for rep_i in range(args.num_repetitions):
caption = all_text[rep_i*args.batch_size + sample_i]
length = all_lengths[rep_i*args.batch_size + sample_i]
motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:length]
- save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i)
+ save_file = sample_file_template.format(sample_i, rep_i)
+ print(sample_print_template.format(caption, sample_i, rep_i, save_file))
animation_save_path = os.path.join(out_path, save_file)
- print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]')
plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps)
# Credit for visualization: https://github.com/EricGuo5513/text-to-motion
-
rep_files.append(animation_save_path)
- all_rep_save_file = os.path.join(out_path, 'sample{:02d}.mp4'.format(sample_i))
- ffmpeg_rep_files = [f' -i {f} ' for f in rep_files]
- hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions}' if args.num_repetitions > 1 else ''
- ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_file}'
- os.system(ffmpeg_rep_cmd)
- print(f'[({sample_i}) "{caption}" | all repetitions | -> {all_rep_save_file}]')
- sample_files.append(all_rep_save_file)
-
- if (sample_i+1) % num_samples_in_out_file == 0 or sample_i+1 == args.num_samples:
- all_sample_save_file = os.path.join(out_path, f'samples_{(sample_i - len(sample_files) + 1):02d}_to_{sample_i:02d}.mp4')
- ffmpeg_rep_files = [f' -i {f} ' for f in sample_files]
- vstack_args = f' -filter_complex vstack=inputs={len(sample_files)}' if len(sample_files) > 1 else ''
- ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{vstack_args} {all_sample_save_file}'
- os.system(ffmpeg_rep_cmd)
- print(f'[(samples {(sample_i - len(sample_files) + 1):02d} to {sample_i:02d}) | all repetitions | -> {all_sample_save_file}]')
- sample_files = []
+ sample_files = save_multiple_samples(args, out_path,
+ row_print_template, all_print_template, row_file_template, all_file_template,
+ caption, num_samples_in_out_file, rep_files, sample_files, sample_i)
abs_path = os.path.abspath(out_path)
print(f'[Done] Results are at [{abs_path}]')
+def save_multiple_samples(args, out_path, row_print_template, all_print_template, row_file_template, all_file_template,
+ caption, num_samples_in_out_file, rep_files, sample_files, sample_i):
+ all_rep_save_file = row_file_template.format(sample_i)
+ all_rep_save_path = os.path.join(out_path, all_rep_save_file)
+ ffmpeg_rep_files = [f' -i {f} ' for f in rep_files]
+ hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions}' if args.num_repetitions > 1 else ''
+ ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_path}'
+ os.system(ffmpeg_rep_cmd)
+ print(row_print_template.format(caption, sample_i, all_rep_save_file))
+ sample_files.append(all_rep_save_path)
+ if (sample_i + 1) % num_samples_in_out_file == 0 or sample_i + 1 == args.num_samples:
+ # all_sample_save_file = f'samples_{(sample_i - len(sample_files) + 1):02d}_to_{sample_i:02d}.mp4'
+ all_sample_save_file = all_file_template.format(sample_i - len(sample_files) + 1, sample_i)
+ all_sample_save_path = os.path.join(out_path, all_sample_save_file)
+ print(all_print_template.format(sample_i - len(sample_files) + 1, sample_i, all_sample_save_file))
+ ffmpeg_rep_files = [f' -i {f} ' for f in sample_files]
+ vstack_args = f' -filter_complex vstack=inputs={len(sample_files)}' if len(sample_files) > 1 else ''
+ ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(
+ ffmpeg_rep_files) + f'{vstack_args} {all_sample_save_path}'
+ os.system(ffmpeg_rep_cmd)
+ sample_files = []
+ return sample_files
+
+
+def construct_template_variables(unconstrained):
+ row_file_template = 'sample{:02d}.mp4'
+ all_file_template = 'samples_{:02d}_to_{:02d}.mp4'
+ if unconstrained:
+ sample_file_template = 'row{:02d}_col{:02d}.mp4'
+ sample_print_template = '[{} row #{:02d} column #{:02d} | -> {}]'
+ row_file_template = row_file_template.replace('sample', 'row')
+ row_print_template = '[{} row #{:02d} | all columns | -> {}]'
+ all_file_template = all_file_template.replace('samples', 'rows')
+ all_print_template = '[rows {:02d} to {:02d} | -> {}]'
+ else:
+ sample_file_template = 'sample{:02d}_rep{:02d}.mp4'
+ sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
+ row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
+ all_print_template = '[samples {:02d} to {:02d} | all repetitions | -> {}]'
+
+ return sample_print_template, row_print_template, all_print_template, \
+ sample_file_template, row_file_template, all_file_template
+
+
def load_dataset(args, max_frames, n_frames):
data = get_dataset_loader(name=args.dataset,
batch_size=args.batch_size,
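To illustrate the refactored naming scheme in `sample/generate.py` (a sketch only, assuming the module imports cleanly in your environment): the templates returned by `construct_template_variables` switch from sample/repetition file names to row/column file names in the unconstrained case.

```python
from sample.generate import construct_template_variables

# Unconstrained: outputs are addressed as rows/columns of the visual grid.
_, _, _, sample_tpl, row_tpl, all_tpl = construct_template_variables(unconstrained=True)
print(sample_tpl.format(0, 1))  # row00_col01.mp4
print(row_tpl.format(0))        # row00.mp4
print(all_tpl.format(0, 6))     # rows_00_to_06.mp4

# Conditioned: the original sample/repetition naming is kept.
_, _, _, sample_tpl, row_tpl, all_tpl = construct_template_variables(unconstrained=False)
print(sample_tpl.format(0, 1))  # sample00_rep01.mp4
```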
diff --git a/train/training_loop.py b/train/training_loop.py
index 1590c81c..fea86e1f 100644
--- a/train/training_loop.py
+++ b/train/training_loop.py
@@ -164,6 +164,8 @@ def run_loop(self):
self.evaluate()
def evaluate(self):
+ if not self.args.eval_during_training:
+ return
start_eval = time.time()
if self.eval_wrapper is not None:
print('Running evaluation loop: [Should take about 90 min]')
@@ -187,12 +189,15 @@ def evaluate(self):
elif self.dataset in ['humanact12', 'uestc']:
eval_args = SimpleNamespace(num_seeds=self.args.eval_rep_times, num_samples=self.args.eval_num_samples,
batch_size=self.args.eval_batch_size, device=self.device, guidance_param = 1,
- dataset=self.dataset, cond_mode='action',
+ dataset=self.dataset, unconstrained=self.args.unconstrained,
model_path=os.path.join(self.save_dir, self.ckpt_file_name()))
eval_dict = eval_humanact12_uestc.evaluate(eval_args, model=self.model, diffusion=self.diffusion, data=self.data.dataset)
print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}')
for k, v in eval_dict["feats"].items():
- self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval')
+ if 'unconstrained' not in k:
+ self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval')
+ else:
+ self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval Unconstrained')
end_eval = time.time()
print(f'Evaluation time: {round(end_eval-start_eval)/60}min')
diff --git a/utils/model_util.py b/utils/model_util.py
index f4ef3516..fd697b07 100644
--- a/utils/model_util.py
+++ b/utils/model_util.py
@@ -20,7 +20,12 @@ def get_model_args(args, data):
# default args
clip_version = 'ViT-B/32'
action_emb = 'tensor'
- cond_mode = 'text' if args.dataset in ['kit', 'humanml'] else 'action'
+ if args.unconstrained:
+ cond_mode = 'no_cond'
+ elif args.dataset in ['kit', 'humanml']:
+ cond_mode = 'text'
+ else:
+ cond_mode = 'action'
if hasattr(data.dataset, 'num_actions'):
num_actions = data.dataset.num_actions
else:
diff --git a/utils/parser_util.py b/utils/parser_util.py
index 3514a0a9..bdd3b6c4 100644
--- a/utils/parser_util.py
+++ b/utils/parser_util.py
@@ -21,11 +21,17 @@ def parse_and_load_from_model(parser):
assert os.path.exists(args_path), 'Arguments json file was not found!'
with open(args_path, 'r') as fr:
model_args = json.load(fr)
+
for a in args_to_overwrite:
if a in model_args.keys():
- args.__dict__[a] = model_args[a]
+ setattr(args, a, model_args[a])
+
+        elif 'cond_mode' in model_args and model_args['cond_mode'] == 'no_cond': # backward compatibility
+ setattr(args, 'unconstrained', True)
+
else:
print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a]))
+
if args.cond_mask_prob == 0:
args.guidance_param = 1
return args
@@ -83,6 +89,9 @@ def add_model_options(parser):
group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.")
group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.")
group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.")
+ group.add_argument("--unconstrained", action='store_true',
+ help="Model is trained unconditionally. That is, it is constrained by neither text nor action. "
+ "Currently tested on HumanAct12 only.")