diff --git a/README.md b/README.md index 329f2ff3..9b6fe304 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,10 @@ If you find this code useful in your research, please cite: ## News +📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks. + Note that the environment has changed slightly to support the new code. If you already have an installed environment, run `bash prepare/download_unconstrained_assets.sh; conda install -y -c anaconda scikit-learn` to update it. + 📢 **3/Nov/22** - Added in-between and upper-body editing. 📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks. @@ -33,10 +37,6 @@ If you find this code useful in your research, please cite: 📢 **6/Oct/22** - First release - sampling and rendering using pre-trained models. -## ETAs - -* Unconstrained Motion: Nov 22 - ## Getting started @@ -76,15 +76,23 @@ bash prepare/download_glove.sh
- Text to Motion, Unconstrained + Action to Motion ```bash bash prepare/download_smpl_files.sh -bash prepare/download_a2m_datasets.sh bash prepare/download_recognition_models.sh ```
+
+ Unconstrained + +```bash +bash prepare/download_smpl_files.sh +bash prepare/download_recognition_unconstrained_models.sh +``` +
+ ### 2. Get data
@@ -125,12 +133,21 @@ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
Action to Motion -**UESTC, HumanAct12** : +**UESTC, HumanAct12** ```bash bash prepare/download_a2m_datasets.sh ```
+
+ Unconstrained + +**HumanAct12** +```bash +bash prepare/download_unconstrained_datasets.sh +``` +
+ ### 3. Download the pretrained models Download the model(s) you wish to use, then unzip and place them in `./save/`. @@ -171,6 +188,15 @@ Download the model(s) you wish to use, then unzip and place them in `./save/`.
+
+ Unconstrained + +**HumanAct12** + +[humanact12_unconstrained](https://drive.google.com/file/d/1uG68m200pZK3pD-zTmPXu5XkgNpx_mEx/view?usp=share_link) + +
+ ## Motion Synthesis
@@ -217,6 +243,16 @@ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --tex ```
+
+ Unconstrained + +```shell +python -m sample.generate --model_path ./save/unconstrained/model000450000.pt --num_samples 10 --num_repetitions 3 +``` + +In this mode, `num_samples * num_repetitions` motions are generated and visually arranged in a grid of `num_samples` rows by `num_repetitions` columns (see the example output layout below). + +
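For reference, in the unconstrained case the clips are named by grid position rather than by caption. Based on the file-name templates added to `sample/generate.py` in this patch, the output directory (whose name depends on the run arguments) will contain files along these lines:

```
row00_col00.mp4    # one motion: row 0 (sample), column 0 (repetition)
row00.mp4          # all columns of row 0, stacked horizontally
rows_00_to_06.mp4  # up to 7 rows stacked vertically into a single grid video
```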
**You may also define:** * `--device` id. @@ -317,6 +353,14 @@ python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} - ``` +
+ Unconstrained + +```shell +python -m train.train_mdm --save_dir save/my_name --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --unconstrained +``` +
+ * Use `--device` to define GPU id. * Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default). * Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard). @@ -349,19 +393,32 @@ python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000 * The output of this script for the pre-trained models (as was reported in the paper) is provided in the checkpoints zip file. ```shell ---model <path-to-model-ckpt> --eval_mode full +python -m eval.eval_humanact12_uestc --model <path-to-model-ckpt> --eval_mode full ``` where `path-to-model-ckpt` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user. +
+ Unconstrained + +* Takes about 3 hours (on a single GPU). + +```shell +python -m eval.eval_humanact12_uestc --model ./save/unconstrained/model000450000.pt --eval_mode full +``` + +By default, precision and recall are skipped to save computation time. To compute them as well, edit `eval/a2m/gru_eval.py` and change `fast=True` to `fast=False`. + +
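Under the hood, that flag is simply forwarded to the `evaluate_unconstrained_metrics` helper this patch adds in `eval/unconstrained/evaluate.py` (it expects the recognition model and dataset fetched by the prepare scripts above). A minimal sketch of calling it directly, assuming motions have already been sampled and reduced to the 15-joint layout used in `eval/a2m/gru_eval.py`; the `.npy` path is hypothetical:

```python
import numpy as np
from eval.unconstrained.evaluate import evaluate_unconstrained_metrics

# Hypothetical dump of sampled motions, already mapped to the 15-joint layout
# (shape roughly [num_motions, 15 joints, 3 coords, num_frames]).
generated_motions = np.load('generated_motions.npy')

# fast=True skips precision/recall; pass fast=False to compute them as well.
metrics = evaluate_unconstrained_metrics(generated_motions, device='cuda', fast=True)
print(metrics['fid'], metrics['kid'], metrics['diversity_gen'])
```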
+ ## Acknowledgments This code is standing on the shoulders of giants. We want to thank the following contributors that our code is based on: -[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl). +[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl), [MoDi](https://github.com/sigal-raab/MoDi). ## License This code is distributed under an [MIT LICENSE](LICENSE). diff --git a/environment.yml b/environment.yml index 11092210..dceaae7a 100644 --- a/environment.yml +++ b/environment.yml @@ -1,6 +1,7 @@ name: mdm channels: - pytorch + - anaconda - conda-forge - defaults dependencies: @@ -9,9 +10,9 @@ dependencies: - beautifulsoup4=4.11.1=pyha770c72_0 - blas=1.0=mkl - brotlipy=0.7.0=py37h540881e_1004 - - ca-certificates=2022.9.24=ha878542_0 + - ca-certificates=2022.07.19=h06a4308_0 - catalogue=2.0.8=py37h89c1867_0 - - certifi=2022.9.24=pyhd8ed1ab_0 + - certifi=2022.6.15=py37h06a4308_0 - cffi=1.15.1=py37h74dc2b5_0 - charset-normalizer=2.1.1=pyhd8ed1ab_0 - colorama=0.4.5=pyhd8ed1ab_0 @@ -37,6 +38,7 @@ dependencies: - idna=3.4=pyhd8ed1ab_0 - intel-openmp=2021.4.0=h06a4308_3561 - jinja2=3.1.2=pyhd8ed1ab_1 + - joblib=1.1.0=pyhd3eb1b0_0 - jpeg=9b=h024ee3a_2 - kiwisolver=1.4.2=py37h295c915_0 - langcodes=3.3.0=pyhd8ed1ab_0 @@ -87,6 +89,7 @@ dependencies: - qt=5.9.7=h5867ecd_1 - readline=8.1.2=h7f8727e_1 - requests=2.28.1=pyhd8ed1ab_1 + - scikit-learn=1.0.2=py37h51133e4_1 - scipy=1.7.3=py37h6c91a56_2 - setuptools=63.4.1=py37h06a4308_0 - shellingham=1.5.0=pyhd8ed1ab_0 @@ -98,6 +101,7 @@ dependencies: - spacy-legacy=3.0.10=pyhd8ed1ab_0 - spacy-loggers=1.0.3=pyhd8ed1ab_0 - sqlite=3.39.3=h5082296_0 + - threadpoolctl=2.2.0=pyh0d69192_0 - tk=8.6.12=h1ccaba5_0 - torchaudio=0.7.2=py37 - torchvision=0.8.2=py37_cu110 @@ -113,10 +117,8 @@ dependencies: - pip: - blis==0.7.8 - chumpy==0.70 - - clearml==1.7.1 - click==8.1.3 - confection==0.0.2 - - filelock==3.8.0 - ftfy==6.1.1 - importlib-metadata==5.0.0 - lxml==4.9.1 diff --git a/eval/a2m/action2motion/diversity.py b/eval/a2m/action2motion/diversity.py index 5ebe8a45..c2011080 100644 --- a/eval/a2m/action2motion/diversity.py +++ b/eval/a2m/action2motion/diversity.py @@ -2,6 +2,21 @@ import numpy as np +#adapted from action2motion +def calculate_diversity(activations): + diversity_times = 200 + num_motions = len(activations) + + diversity = 0 + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + return diversity + # from action2motion def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False): diversity_times = 200 diff --git a/eval/a2m/gru_eval.py b/eval/a2m/gru_eval.py index 1595f0b9..32267399 100644 --- a/eval/a2m/gru_eval.py +++ b/eval/a2m/gru_eval.py @@ -1,5 +1,7 @@ import copy import os + +import numpy as np from tqdm import tqdm import torch import functools @@ -8,13 +10,14 @@ from utils.fixseed import fixseed 
from data_loaders.tensors import collate from eval.a2m.action2motion.evaluate import A2MEvaluation +from eval.unconstrained.evaluate import evaluate_unconstrained_metrics from .tools import save_metrics, format_metrics -from data_loaders.get_data import get_dataset from utils import dist_util +num_samples_unconstrained = 1000 class NewDataloader: - def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, dataset, num_samples: int=-1): + def __init__(self, mode, model, diffusion, dataiterator, device, unconstrained, num_samples: int=-1): assert mode in ["gen", "gt"] self.batches = [] sample_fn = diffusion.p_sample_loop @@ -37,7 +40,7 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data translation=True, jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None, get_rotations_back=False) batch["lengths"] = model_kwargs['y']['lengths'].to(device) - if cond_mode != 'no_cond': # proceed only if not running unconstrained + if not unconstrained: # proceed only if not running unconstrained batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices self.batches.append(batch) @@ -49,7 +52,6 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data def __iter__(self): return iter(self.batches) - def evaluate(args, model, diffusion, data): num_frames = 60 @@ -87,25 +89,43 @@ def evaluate(args, model, diffusion, data): shuffle=False, num_workers=8, collate_fn=collate) new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device, - cond_mode=args.cond_mode, dataset=args.dataset, - num_samples=args.num_samples) + unconstrained=args.unconstrained, num_samples=args.num_samples) motionloader = new_data_loader(mode="gen", dataiterator=dataiterator) gt_motionloader = new_data_loader("gt", dataiterator=dataiterator) gt_motionloader2 = new_data_loader("gt", dataiterator=dataiterator2) # Action2motionEvaluation loaders = {"gen": motionloader, - # "recons": reconstructedloader, "gt": gt_motionloader, "gt2": gt_motionloader2} a2mmetrics[seed] = a2mevaluation.evaluate(model, loaders) + del loaders + if args.unconstrained: # unconstrained + dataset_unconstrained = copy.deepcopy(data) + dataset_unconstrained.reset_shuffle() + dataset_unconstrained.shuffle() + dataiterator_unconstrained = DataLoader(dataset_unconstrained, batch_size=args.batch_size, + shuffle=False, num_workers=8, collate_fn=collate) + motionloader_unconstrained = new_data_loader(mode="gen", dataiterator=dataiterator_unconstrained, num_samples=num_samples_unconstrained) + + generated_motions = [] + for motion in motionloader_unconstrained: + idx = [15, 12, 16, 18, 20, 17, 19, 21, 0, 1, 4, 7, 2, 5, 8] + motion = motion['output_xyz'][:, idx, :, :] + generated_motions.append(motion.cpu().numpy()) + generated_motions = np.concatenate(generated_motions) + unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True) + unconstrained_metrics = {k+'_unconstrained': v for k, v in unconstrained_metrics.items()} + except KeyboardInterrupt: string = "Saving the evaluation before exiting.."
print(string) metrics = {"feats": {key: [format_metrics(a2mmetrics[seed])[key] for seed in a2mmetrics.keys()] for key in a2mmetrics[allseeds[0]]}} + if args.unconstrained: + metrics["feats"] = {**metrics["feats"], **unconstrained_metrics} return metrics diff --git a/eval/eval_humanact12_uestc.py b/eval/eval_humanact12_uestc.py index d77370d1..a6fa2b21 100644 --- a/eval/eval_humanact12_uestc.py +++ b/eval/eval_humanact12_uestc.py @@ -61,8 +61,6 @@ def main(): else: args.num_samples = 1000 args.num_seeds = 20 - args.cond_mode = 'action' # temporary code till 'unconstrained' is implemented - data_loader = get_dataset_loader(name=args.dataset, num_frames=60, batch_size=args.batch_size,) diff --git a/eval/unconstrained/evaluate.py b/eval/unconstrained/evaluate.py new file mode 100644 index 00000000..0559628e --- /dev/null +++ b/eval/unconstrained/evaluate.py @@ -0,0 +1,111 @@ +from eval.unconstrained.models.stgcn import STGCN +import pandas as pd +import os.path as osp +import os +import datetime + +import torch + +from torch.utils.data import DataLoader +import numpy as np +import sys as _sys +from eval.a2m.action2motion.fid import calculate_fid +from eval.a2m.action2motion.diversity import calculate_diversity +from eval.unconstrained.metrics.kid import calculate_kid +from eval.unconstrained.metrics.precision_recall import precision_and_recall +from matplotlib import pyplot as plt + +TEST = False + + +def initialize_model(device, modelpath): + num_classes = 12 + model = STGCN(in_channels=3, + num_class=num_classes, + graph_args={"layout": 'openpose', "strategy": "spatial"}, + edge_importance_weighting=True, + device=device) + model = model.to(device) + state_dict = torch.load(modelpath, map_location=device) + model.load_state_dict(state_dict) + model.eval() + return model + +def calculate_activation_statistics(activations): + activations = activations.cpu().detach().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + +def compute_features(model, iterator, device): + activations = [] + predictions = [] + with torch.no_grad(): + for i, batch in enumerate(iterator): + batch_for_model = {} + batch_for_model['x'] = batch.to(device).float() + model(batch_for_model) + activations.append(batch_for_model['features']) + predictions.append(batch_for_model['yhat']) + # labels.append(batch_for_model['y']) + activations = torch.cat(activations, dim=0) + predictions = torch.cat(predictions, dim=0) + return activations, predictions + + +def evaluate_unconstrained_metrics(generated_motions, device, fast): + + act_rec_model_path = './assets/actionrecognition/humanact12_gru_modi_struct.pth.tar' + dataset_path = './dataset/HumanAct12Poses/humanact12_modi_struct.npy' + + # initialize model + act_rec_model = initialize_model(device, act_rec_model_path) + + generated_motions -= generated_motions[:, 8:9, :, :] # locate root joint of all frames at origin + + iterator_generated = DataLoader(generated_motions, batch_size=64, shuffle=False, num_workers=8) + + # compute features of generated motions + generated_features, generated_predictions = compute_features(act_rec_model, iterator_generated, device=device) + generated_stats = calculate_activation_statistics(generated_features) + + + # dataset motions + motion_data_raw = np.load(dataset_path, allow_pickle=True) + motion_data = motion_data_raw[:, :15] # data has 16 joints for back compitability with older formats + motion_data -= motion_data[:, 8:9, :, :] # locate root joint of all frames at origin + 
iterator_dataset = DataLoader(motion_data, batch_size=64, shuffle=False, num_workers=8) + + # compute features of dataset motions + dataset_features, dataset_predictions = compute_features(act_rec_model, iterator_dataset, device=device) + real_stats = calculate_activation_statistics(dataset_features) + + print("evaluation resutls:\n") + + fid = calculate_fid(generated_stats, real_stats) + print(f"FID score: {fid}\n") + + print("calculating KID...") + kid = calculate_kid(dataset_features.cpu(), generated_features.cpu()) + (m, s) = kid + print('KID : %.3f (%.3f)\n' % (m, s)) + + dataset_diversity = calculate_diversity(dataset_features) + generated_diversity = calculate_diversity(generated_features) + print(f"Diversity of generated motions: {generated_diversity}") + print(f"Diversity of dataset motions: {dataset_diversity}\n") + + if fast: + print("Skipping precision-recall calculation\n") + precision = recall = None + else: + print("calculating precision recall...") + precision, recall = precision_and_recall(generated_features, dataset_features) + print(f"precision: {precision}") + print(f"recall: {recall}\n") + + metrics = {'fid': fid, 'kid': kid[0], 'diversity_gen': generated_diversity.cpu().item(), 'diversity_gt': dataset_diversity.cpu().item(), + 'precision': precision, 'recall':recall} + return metrics + diff --git a/eval/unconstrained/metrics/kid.py b/eval/unconstrained/metrics/kid.py new file mode 100644 index 00000000..f56c63f3 --- /dev/null +++ b/eval/unconstrained/metrics/kid.py @@ -0,0 +1,136 @@ +import torch +import numpy as np +from tqdm import tqdm +from sklearn.metrics.pairwise import polynomial_kernel +import sys + +# from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py +def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000, + ret_var=True, output=sys.stdout, **kernel_args): + m = min(codes_g.shape[0], codes_r.shape[0]) + mmds = np.zeros(n_subsets) + if ret_var: + vars = np.zeros(n_subsets) + choice = np.random.choice + + replace = subset_size < len(codes_g) + with tqdm(range(n_subsets), desc='MMD', file=output, disable=True) as bar: + for i in bar: + g = codes_g[choice(len(codes_g), subset_size, replace=replace)] + r = codes_r[choice(len(codes_r), subset_size, replace=replace)] + o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) + if ret_var: + mmds[i], vars[i] = o + else: + mmds[i] = o + bar.set_postfix({'mean': mmds[:i+1].mean()}) + return (mmds, vars) if ret_var else mmds + +def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, + var_at_m=None, ret_var=True): + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = codes_g + Y = codes_r + + K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, K_XY, K_YY, + var_at_m=var_at_m, ret_var=ret_var) + +def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, + mmd_est='unbiased', block_size=1024, + var_at_m=None, ret_var=True): + # based on + # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + # but changed to not compute the full kernel matrix at once + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + if var_at_m is None: + var_at_m = m + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute 
them explicitly + if unit_diagonal: + diag_X = diag_Y = 1 + sum_diag_X = sum_diag_Y = m + sum_diag2_X = sum_diag2_Y = m + else: + diag_X = np.diagonal(K_XX) + diag_Y = np.diagonal(K_YY) + + sum_diag_X = diag_X.sum() + sum_diag_Y = diag_Y.sum() + + sum_diag2_X = _sqn(diag_X) + sum_diag2_Y = _sqn(diag_Y) + + Kt_XX_sums = K_XX.sum(axis=1) - diag_X + Kt_YY_sums = K_YY.sum(axis=1) - diag_Y + K_XY_sums_0 = K_XY.sum(axis=0) + K_XY_sums_1 = K_XY.sum(axis=1) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + if mmd_est == 'biased': + mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + + (Kt_YY_sum + sum_diag_Y) / (m * m) + - 2 * K_XY_sum / (m * m)) + else: + assert mmd_est in {'unbiased', 'u-statistic'} + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) + if mmd_est == 'unbiased': + mmd2 -= 2 * K_XY_sum / (m * m) + else: + mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1)) + + if not ret_var: + return mmd2 + + Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X + Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y + K_XY_2_sum = _sqn(K_XY) + + dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) + dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) + + m1 = m - 1 + m2 = m - 2 + zeta1_est = ( + 1 / (m * m1 * m2) * ( + _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) + - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + + 1 / (m * m * m1) * ( + _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) + - 2 / m**4 * K_XY_sum**2 + - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + ) + zeta2_est = ( + 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) + - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + + 2 / (m * m) * K_XY_2_sum + - 2 / m**4 * K_XY_sum**2 + - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + ) + var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est + + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) + + return mmd2, var_est + + +def _sqn(arr): + flat = np.ravel(arr) + return flat.dot(flat) + +def calculate_kid(real_activations, generated_activations): + kid_values = polynomial_mmd_averages(real_activations, generated_activations, n_subsets=100) + results = (kid_values[0].mean(), kid_values[0].std()) + return results diff --git a/eval/unconstrained/metrics/precision_recall.py b/eval/unconstrained/metrics/precision_recall.py new file mode 100644 index 00000000..2edd4a91 --- /dev/null +++ b/eval/unconstrained/metrics/precision_recall.py @@ -0,0 +1,55 @@ +# based on https://github.com/blandocs/improved-precision-and-recall-metric-pytorch/blob/master/functions.py +import os, torch +import numpy as np +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +# self.batch_size = args.batch_size +# self.cpu = args.cpu +# self.data_size = args.data_size + +def precision_and_recall(generated_features, real_features): + k = 3 + + data_num = min(len(generated_features), len(real_features)) + print(f'data num: {data_num}') + + if data_num <= 0: + print("there is no data") + return + generated_features = generated_features[:data_num] + real_features = real_features[:data_num] + + # get precision and recall + precision = manifold_estimate(real_features, generated_features, k) + recall = manifold_estimate(generated_features, real_features, k) + + return precision, recall + +def manifold_estimate( A_features, B_features, k): + A_features = list(A_features) + B_features = list(B_features) + KNN_list_in_A = {} + for A in tqdm(A_features, ncols=80): + pairwise_distances 
= np.zeros(shape=(len(A_features))) + + for i, A_prime in enumerate(A_features): + d = torch.norm((A - A_prime), 2) + pairwise_distances[i] = d + + v = np.partition(pairwise_distances, k)[k] + KNN_list_in_A[A] = v + + n = 0 + + for B in tqdm(B_features, ncols=80): + for A_prime in A_features: + d = torch.norm((B - A_prime), 2) + if d <= KNN_list_in_A[A_prime]: + n += 1 + break + + return n / len(B_features) + + diff --git a/eval/unconstrained/models/stgcn.py b/eval/unconstrained/models/stgcn.py new file mode 100644 index 00000000..641d8148 --- /dev/null +++ b/eval/unconstrained/models/stgcn.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from eval.a2m.recognition.models.stgcnutils.tgcn import ConvTemporalGraphical +from eval.unconstrained.models.stgcnutils.graph import Graph + +__all__ = ["STGCN"] + + +class STGCN(nn.Module): + r"""Spatial temporal graph convolutional networks. + Args: + in_channels (int): Number of channels in the input data + num_class (int): Number of classes for the classification task + graph_args (dict): The arguments for building the graph + edge_importance_weighting (bool): If ``True``, adds a learnable + importance weighting to the edges of the graph + **kwargs (optional): Other parameters for graph convolution units + Shape: + - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` + - Output: :math:`(N, num_class)` where + :math:`N` is a batch size, + :math:`T_{in}` is a length of input sequence, + :math:`V_{in}` is the number of graph nodes, + :math:`M_{in}` is the number of instance in a frame. + """ + + def __init__(self, in_channels, num_class, graph_args, + edge_importance_weighting, device, **kwargs): + super().__init__() + + self.device = device + self.num_class = num_class + + self.losses = ["accuracy", "cross_entropy", "mixed"] + self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') + + # load graph + self.graph = Graph(**graph_args) + A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) + self.register_buffer('A', A) + + # build networks + spatial_kernel_size = A.size(0) + temporal_kernel_size = 9 + kernel_size = (temporal_kernel_size, spatial_kernel_size) + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1), track_running_stats=False) + # self.data_bn = nn.InstanceNorm1d(in_channels * A.size(1)) + kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} + self.st_gcn_networks = nn.ModuleList(( + st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), + st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 64, kernel_size, 1, **kwargs), + # st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 128, kernel_size, 2, **kwargs), + st_gcn(128, 128, kernel_size, 1, **kwargs), + # st_gcn(128, 128, kernel_size, 1, **kwargs), + st_gcn(128, 256, kernel_size, 2, **kwargs), + # st_gcn(256, 256, kernel_size, 1, **kwargs), + # st_gcn(256, 256, kernel_size, 1, **kwargs), + )) + + # initialize parameters for edge importance weighting + if edge_importance_weighting: + self.edge_importance = nn.ParameterList([ + nn.Parameter(torch.ones(self.A.size())) + for i in self.st_gcn_networks + ]) + else: + self.edge_importance = [1] * len(self.st_gcn_networks) + + # fcn for prediction + self.fcn = nn.Conv2d(256, num_class, kernel_size=1) + + def forward(self, batch): + # TODO: use mask + # Received batch["x"] as + # Batch(48), Joints(23), Quat(4), Time(157 + # Expecting: + # Batch, Quat:4, Time, Joints, 1 + x = 
batch["x"].permute(0, 2, 3, 1).unsqueeze(4).contiguous() + + # data normalization + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + # forward + for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): + x, _ = gcn(x, self.A * importance) + + # compute feature + # _, c, t, v = x.size() + # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + # batch["features"] = features + + # global pooling + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(N, M, -1, 1, 1).mean(dim=1) + + # features + batch["features"] = x.squeeze() + + # prediction + x = self.fcn(x) + x = x.view(x.size(0), -1) + batch["yhat"] = x + return batch + + def compute_accuracy(self, batch): + confusion = torch.zeros(self.num_class, self.num_class, dtype=int) + yhat = batch["yhat"].max(dim=1).indices + ygt = batch["y"] + for label, pred in zip(ygt, yhat): + confusion[label][pred] += 1 + accuracy = torch.trace(confusion)/torch.sum(confusion) + return accuracy + + def compute_loss(self, batch): + cross_entropy = self.criterion(batch["yhat"], batch["y"]) + mixed_loss = cross_entropy + + acc = self.compute_accuracy(batch) + losses = {"cross_entropy": cross_entropy.item(), + "mixed": mixed_loss.item(), + "accuracy": acc.item()} + return mixed_loss, losses + + +class st_gcn(nn.Module): + r"""Applies a spatial temporal graph convolution over an input graph sequence. + Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel + stride (int, optional): Stride of the temporal convolution. Default: 1 + dropout (int, optional): Dropout rate of the final output. Default: 0 + residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True`` + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. 
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dropout=0, + residual=True): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + padding = ((kernel_size[0] - 1) // 2, 0) + + self.gcn = ConvTemporalGraphical(in_channels, out_channels, + kernel_size[1]) + + self.tcn = nn.Sequential( + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding, + ), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True), + ) + + if not residual: + self.residual = lambda x: 0 + + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + + else: + self.residual = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels), + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A): + + res = self.residual(x) + x, A = self.gcn(x, A) + x = self.tcn(x) + res + + return self.relu(x), A + + +if __name__ == "__main__": + model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True, graph_args={"layout": "smpl_noglobal", "strategy": "spatial"}) + # Batch, in_channels, time, vertices, M + inp = torch.rand(10, 3, 16, 23, 1) + out = model(inp) + print(out.shape) + import pdb + pdb.set_trace() diff --git a/eval/unconstrained/models/stgcnutils/graph.py b/eval/unconstrained/models/stgcnutils/graph.py new file mode 100644 index 00000000..8ad28036 --- /dev/null +++ b/eval/unconstrained/models/stgcnutils/graph.py @@ -0,0 +1,185 @@ +import numpy as np +import pickle as pkl + +from utils.config import SMPL_KINTREE_PATH + + +class Graph: + """ The Graph to model the skeletons extracted by the openpose + Args: + strategy (string): must be one of the follow candidates + - uniform: Uniform Labeling + - distance: Distance Partitioning + - spatial: Spatial Configuration + For more information, please refer to the section 'Partition Strategies' + in our paper (https://arxiv.org/abs/1801.07455). + layout (string): must be one of the follow candidates + - openpose: Is consists of 18 joints. For more information, please + refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output + - ntu-rgb+d: Is consists of 25 joints. For more information, please + refer to https://github.com/shahroudy/NTURGB-D + - smpl: Consists of 24/23 joints with without global rotation. 
+ max_hop (int): the maximal distance between two connected nodes + dilation (int): controls the spacing between the kernel points + """ + + def __init__(self, + layout='openpose', + strategy='uniform', + kintree_path=SMPL_KINTREE_PATH, + max_hop=1, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.kintree_path = kintree_path + + self.get_edge(layout) + self.hop_dis = get_hop_distance( + self.num_node, self.edge, max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'openpose': + # self.num_node = 18 + self.num_node = 15 + self_link = [(i, i) for i in range(self.num_node)] + # neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, + # 11), + # (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), + # (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] + neighbor_link = [(4, 3), (3, 2), (2, 1), + (7, 6), (6, 5), (5, 1), + (1, 0), + (14, 13), (13, 12), (12, 8), + (11, 10), (10, 9), (9, 8), + (8, 1),] + self.edge = self_link + neighbor_link + self.center = 1 + elif layout == 'smpl': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'smpl_noglobal': + self.num_node = 23 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + # remove the root joint + neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'ntu-rgb+d': + self.num_node = 25 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), + (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), + (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), + (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), + (22, 23), (23, 8), (24, 25), (25, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 21 - 1 + elif layout == 'ntu_edge': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), + (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), + (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), + (23, 24), (24, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 2 + # elif layout=='customer settings' + # pass + else: + raise NotImplementedError("This Layout is not supported") + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = 
np.zeros((self.num_node, self.num_node)) + a_close = np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[ + i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self. + center] > self.hop_dis[i, self. + center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise NotImplementedError("This Strategy is not supported") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + + # compute hop steps + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) + return DAD diff --git a/model/mdm.py b/model/mdm.py index 4739a8a7..14fd5bda 100644 --- a/model/mdm.py +++ b/model/mdm.py @@ -143,8 +143,6 @@ def forward(self, x, timesteps, y=None): x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper timesteps: [batch_size] (int) """ - assert (y is not None) == (self.cond_mode != 'no_cond' - ), "must specify y if and only if the model is class-conditional" bs, njoints, nfeats, nframes = x.shape emb = self.embed_timestep(timesteps) # [1, bs, d] diff --git a/prepare/download_recognition_unconstrained_models.sh b/prepare/download_recognition_unconstrained_models.sh new file mode 100644 index 00000000..487471d4 --- /dev/null +++ b/prepare/download_recognition_unconstrained_models.sh @@ -0,0 +1,8 @@ +mkdir -p assets/actionrecognition/ +cd assets/actionrecognition/ + +echo -e "Downloading the HumanAct12 action recognition model, adjusted for the unconstrained setting." +gdown "1xfigimkPxKt3a8zvn_ME_NAR6CyTqneK" +echo -e + +echo -e "Downloading done!" diff --git a/prepare/download_unconstrained_datasets.sh b/prepare/download_unconstrained_datasets.sh new file mode 100644 index 00000000..8abf7e92 --- /dev/null +++ b/prepare/download_unconstrained_datasets.sh @@ -0,0 +1,10 @@ +mkdir -p dataset/HumanAct12Poses +cd dataset/HumanAct12Poses + +echo "The datasets will be stored in the 'dataset' folder\n" + +# HumanAct12 poses unconstrained +echo "Downloading the HumanAct12 unconstrained poses dataset" +gdown "1KqOBTtLFgkvWSZb8ao-wdBMG7sTP3Q7d" + +echo "Downloading done!" 
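A side note on the graph utilities added above in `eval/unconstrained/models/stgcnutils/graph.py`: every partition strategy starts from the hop-distance matrix and the column-normalized adjacency `A @ D^-1` built by `get_hop_distance` and `normalize_digraph`. A standalone sketch of the normalization on a hypothetical 3-node chain (not the real 15-joint skeleton), ignoring the zero-degree guard the real helper includes:

```python
import numpy as np

# Hypothetical chain 0-1-2 with self-links, mirroring Graph.get_edge above.
edges = [(0, 0), (1, 1), (2, 2), (0, 1), (1, 2)]
A = np.zeros((3, 3))
for i, j in edges:
    A[i, j] = A[j, i] = 1

# normalize_digraph scales each column j by 1/degree(j), i.e. A @ D^-1.
Dn = np.diag(1.0 / A.sum(axis=0))
AD = A @ Dn

print(AD.sum(axis=0))  # -> [1. 1. 1.]: every column of the result sums to one
```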
diff --git a/sample/generate.py b/sample/generate.py index 3c93dc7f..9176d5bb 100644 --- a/sample/generate.py +++ b/sample/generate.py @@ -137,8 +137,11 @@ def main(): jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None, get_rotations_back=False) - text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' - all_text += model_kwargs['y'][text_key] + if args.unconstrained: + all_text += ['unconstrained'] * args.num_samples + else: + text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' + all_text += model_kwargs['y'][text_key] all_motions.append(sample.cpu().numpy()) all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) @@ -170,41 +173,75 @@ def main(): sample_files = [] num_samples_in_out_file = 7 + + sample_print_template, row_print_template, all_print_template, \ + sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained) + for sample_i in range(args.num_samples): rep_files = [] for rep_i in range(args.num_repetitions): caption = all_text[rep_i*args.batch_size + sample_i] length = all_lengths[rep_i*args.batch_size + sample_i] motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:length] - save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i) + save_file = sample_file_template.format(sample_i, rep_i) + print(sample_print_template.format(caption, sample_i, rep_i, save_file)) animation_save_path = os.path.join(out_path, save_file) - print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]') plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps) # Credit for visualization: https://github.com/EricGuo5513/text-to-motion - rep_files.append(animation_save_path) - all_rep_save_file = os.path.join(out_path, 'sample{:02d}.mp4'.format(sample_i)) - ffmpeg_rep_files = [f' -i {f} ' for f in rep_files] - hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions}' if args.num_repetitions > 1 else '' - ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_file}' - os.system(ffmpeg_rep_cmd) - print(f'[({sample_i}) "{caption}" | all repetitions | -> {all_rep_save_file}]') - sample_files.append(all_rep_save_file) - - if (sample_i+1) % num_samples_in_out_file == 0 or sample_i+1 == args.num_samples: - all_sample_save_file = os.path.join(out_path, f'samples_{(sample_i - len(sample_files) + 1):02d}_to_{sample_i:02d}.mp4') - ffmpeg_rep_files = [f' -i {f} ' for f in sample_files] - vstack_args = f' -filter_complex vstack=inputs={len(sample_files)}' if len(sample_files) > 1 else '' - ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{vstack_args} {all_sample_save_file}' - os.system(ffmpeg_rep_cmd) - print(f'[(samples {(sample_i - len(sample_files) + 1):02d} to {sample_i:02d}) | all repetitions | -> {all_sample_save_file}]') - sample_files = [] + sample_files = save_multiple_samples(args, out_path, + row_print_template, all_print_template, row_file_template, all_file_template, + caption, num_samples_in_out_file, rep_files, sample_files, sample_i) abs_path = os.path.abspath(out_path) print(f'[Done] Results are at [{abs_path}]') +def save_multiple_samples(args, out_path, row_print_template, all_print_template, row_file_template, all_file_template, + caption, num_samples_in_out_file, rep_files, sample_files, sample_i): + all_rep_save_file = row_file_template.format(sample_i) + all_rep_save_path = os.path.join(out_path, all_rep_save_file) 
+ ffmpeg_rep_files = [f' -i {f} ' for f in rep_files] + hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions}' if args.num_repetitions > 1 else '' + ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_path}' + os.system(ffmpeg_rep_cmd) + print(row_print_template.format(caption, sample_i, all_rep_save_file)) + sample_files.append(all_rep_save_path) + if (sample_i + 1) % num_samples_in_out_file == 0 or sample_i + 1 == args.num_samples: + # all_sample_save_file = f'samples_{(sample_i - len(sample_files) + 1):02d}_to_{sample_i:02d}.mp4' + all_sample_save_file = all_file_template.format(sample_i - len(sample_files) + 1, sample_i) + all_sample_save_path = os.path.join(out_path, all_sample_save_file) + print(all_print_template.format(sample_i - len(sample_files) + 1, sample_i, all_sample_save_file)) + ffmpeg_rep_files = [f' -i {f} ' for f in sample_files] + vstack_args = f' -filter_complex vstack=inputs={len(sample_files)}' if len(sample_files) > 1 else '' + ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join( + ffmpeg_rep_files) + f'{vstack_args} {all_sample_save_path}' + os.system(ffmpeg_rep_cmd) + sample_files = [] + return sample_files + + +def construct_template_variables(unconstrained): + row_file_template = 'sample{:02d}.mp4' + all_file_template = 'samples_{:02d}_to_{:02d}.mp4' + if unconstrained: + sample_file_template = 'row{:02d}_col{:02d}.mp4' + sample_print_template = '[{} row #{:02d} column #{:02d} | -> {}]' + row_file_template = row_file_template.replace('sample', 'row') + row_print_template = '[{} row #{:02d} | all columns | -> {}]' + all_file_template = all_file_template.replace('samples', 'rows') + all_print_template = '[rows {:02d} to {:02d} | -> {}]' + else: + sample_file_template = 'sample{:02d}_rep{:02d}.mp4' + sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]' + row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]' + all_print_template = '[samples {:02d} to {:02d} | all repetitions | -> {}]' + + return sample_print_template, row_print_template, all_print_template, \ + sample_file_template, row_file_template, all_file_template + + def load_dataset(args, max_frames, n_frames): data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, diff --git a/train/training_loop.py b/train/training_loop.py index 1590c81c..fea86e1f 100644 --- a/train/training_loop.py +++ b/train/training_loop.py @@ -164,6 +164,8 @@ def run_loop(self): self.evaluate() def evaluate(self): + if not self.args.eval_during_training: + return start_eval = time.time() if self.eval_wrapper is not None: print('Running evaluation loop: [Should take about 90 min]') @@ -187,12 +189,15 @@ def evaluate(self): elif self.dataset in ['humanact12', 'uestc']: eval_args = SimpleNamespace(num_seeds=self.args.eval_rep_times, num_samples=self.args.eval_num_samples, batch_size=self.args.eval_batch_size, device=self.device, guidance_param = 1, - dataset=self.dataset, cond_mode='action', + dataset=self.dataset, unconstrained=self.args.unconstrained, model_path=os.path.join(self.save_dir, self.ckpt_file_name())) eval_dict = eval_humanact12_uestc.evaluate(eval_args, model=self.model, diffusion=self.diffusion, data=self.data.dataset) print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}') for k, v in eval_dict["feats"].items(): - self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval') + if 'unconstrained' not in k: + 
self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval') + else: + self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(), iteration=self.step, group_name='Eval Unconstrained') end_eval = time.time() print(f'Evaluation time: {round(end_eval-start_eval)/60}min') diff --git a/utils/model_util.py b/utils/model_util.py index f4ef3516..fd697b07 100644 --- a/utils/model_util.py +++ b/utils/model_util.py @@ -20,7 +20,12 @@ def get_model_args(args, data): # default args clip_version = 'ViT-B/32' action_emb = 'tensor' - cond_mode = 'text' if args.dataset in ['kit', 'humanml'] else 'action' + if args.unconstrained: + cond_mode = 'no_cond' + elif args.dataset in ['kit', 'humanml']: + cond_mode = 'text' + else: + cond_mode = 'action' if hasattr(data.dataset, 'num_actions'): num_actions = data.dataset.num_actions else: diff --git a/utils/parser_util.py b/utils/parser_util.py index 3514a0a9..bdd3b6c4 100644 --- a/utils/parser_util.py +++ b/utils/parser_util.py @@ -21,11 +21,17 @@ def parse_and_load_from_model(parser): assert os.path.exists(args_path), 'Arguments json file was not found!' with open(args_path, 'r') as fr: model_args = json.load(fr) + for a in args_to_overwrite: if a in model_args.keys(): - args.__dict__[a] = model_args[a] + setattr(args, a, model_args[a]) + + elif 'cond_mode' in model_args and model_args['cond_mode'] == 'no_cond': # backward compitability + setattr(args, 'unconstrained', True) + else: print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a])) + if args.cond_mask_prob == 0: args.guidance_param = 1 return args @@ -83,6 +89,9 @@ def add_model_options(parser): group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.") group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.") group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.") + group.add_argument("--unconstrained", action='store_true', + help="Model is trained unconditionally. That is, it is constrained by neither text nor action. " + "Currently tested on HumanAct12 only.")
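Taken together, the `utils/model_util.py` and `utils/parser_util.py` changes make `--unconstrained` the single switch that determines the conditioning mode, with old checkpoints saved as `cond_mode='no_cond'` mapped back to it when their arguments are loaded. A standalone restatement of that branching (the helper name `resolve_cond_mode` is illustrative, not part of the repo):

```python
def resolve_cond_mode(unconstrained: bool, dataset: str) -> str:
    """Mirrors the branching added to get_model_args in utils/model_util.py."""
    if unconstrained:
        return 'no_cond'              # neither text nor action conditioning
    if dataset in ['kit', 'humanml']:
        return 'text'                 # text-to-motion datasets
    return 'action'                   # action-to-motion datasets (humanact12, uestc)

assert resolve_cond_mode(True, 'humanact12') == 'no_cond'
assert resolve_cond_mode(False, 'humanml') == 'text'
assert resolve_cond_mode(False, 'uestc') == 'action'
```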