support unconstrained training, sampling and evaluating
sigal-raab committed Nov 4, 2022
1 parent f00e01d commit fc1439f
Showing 17 changed files with 921 additions and 49 deletions.
75 changes: 66 additions & 9 deletions README.md
@@ -24,6 +24,10 @@ If you find this code useful in your research, please cite:

## News

📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks.
Note the slight environment changes required by the new code. If you already have an installed environment, run `bash prepare/download_unconstrained_assets.sh; conda install -y -c anaconda scikit-learn` to adapt it.

📢 **3/Nov/22** - Added in-between and upper-body editing.

📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks.
@@ -33,10 +37,6 @@ If you find this code useful in your research, please cite:

📢 **6/Oct/22** - First release - sampling and rendering using pre-trained models.

## ETAs

* Unconstrained Motion: Nov 22


## Getting started

@@ -76,15 +76,23 @@ bash prepare/download_glove.sh
</details>

<details>
<summary><b>Text to Motion, Unconstrained</b></summary>
<summary><b>Action to Motion</b></summary>

```bash
bash prepare/download_smpl_files.sh
bash prepare/download_a2m_datasets.sh
bash prepare/download_recognition_models.sh
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```bash
bash prepare/download_smpl_files.sh
bash prepare/download_recognition_unconstrained_models.sh
```
</details>

### 2. Get data

<details>
@@ -125,12 +133,21 @@ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
<details>
<summary><b>Action to Motion</b></summary>

**UESTC, HumanAct12** :
**UESTC, HumanAct12**
```bash
bash prepare/download_a2m_datasets.sh
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

**HumanAct12**
```bash
bash prepare/download_unconstrained_datasets.sh
```
</details>

### 3. Download the pretrained models

Download the model(s) you wish to use, then unzip and place them in `./save/`.
@@ -171,6 +188,15 @@ Download the model(s) you wish to use, then unzip and place them in `./save/`.

</details>

<details>
<summary><b>Unconstrained</b></summary>

**HumanAct12**

[humanact12_unconstrained](https://drive.google.com/file/d/1uG68m200pZK3pD-zTmPXu5XkgNpx_mEx/view?usp=share_link)

</details>


## Motion Synthesis
<details>
@@ -217,6 +243,16 @@ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --tex
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```shell
python -m sample.generate --model_path ./save/unconstrained/model000450000.pt --num_samples 10 --num_repetitions 3
```

In total, `num_samples * num_repetitions` motions are generated, and they are visually organized in a grid of `num_samples` rows by `num_repetitions` columns.

</details>

**You may also define:**
* `--device` id.
@@ -317,6 +353,14 @@ python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} -
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```shell
python -m train.train_mdm --save_dir save/my_name --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --unconstrained
```
</details>

* Use `--device` to define GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
* Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard).
@@ -349,19 +393,32 @@ python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000
* The output of this script for the pre-trained models (as was reported in the paper) is provided in the checkpoints zip file.

```shell
--model <path-to-model-ckpt> --eval_mode full
python -m eval.eval_humanact12_uestc --model <path-to-model-ckpt> --eval_mode full
```
where `path-to-model-ckpt` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user.

</details>


<details>
<summary><b>Unconstrained</b></summary>

* Takes about 3 hours (on a single GPU)

```shell
python -m eval.eval_humanact12_uestc --model ./save/unconstrained/model000450000.pt --eval_mode full
```

Precision and recall are not computed, to save computation time. If you wish to compute them, edit `eval/a2m/gru_eval.py` and change `fast=True` to `fast=False`.
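As a minimal sketch of that edit (the call is taken from the `gru_eval.py` diff further down; only the flag value differs, and the excerpt assumes the surrounding variables of `evaluate`):

```python
# eval/a2m/gru_eval.py -- inside the unconstrained branch of evaluate()
# set fast=False to additionally compute precision and recall
unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=False)
```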

</details>

## Acknowledgments

This code stands on the shoulders of giants. We want to thank the following contributors,
whose work our code is based on:

[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl).
[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl), [MoDi](https://github.com/sigal-raab/MoDi).

## License
This code is distributed under an [MIT LICENSE](LICENSE).
10 changes: 6 additions & 4 deletions environment.yml
@@ -1,6 +1,7 @@
name: mdm
channels:
- pytorch
- anaconda
- conda-forge
- defaults
dependencies:
@@ -9,9 +10,9 @@ dependencies:
- beautifulsoup4=4.11.1=pyha770c72_0
- blas=1.0=mkl
- brotlipy=0.7.0=py37h540881e_1004
- ca-certificates=2022.9.24=ha878542_0
- ca-certificates=2022.07.19=h06a4308_0
- catalogue=2.0.8=py37h89c1867_0
- certifi=2022.9.24=pyhd8ed1ab_0
- certifi=2022.6.15=py37h06a4308_0
- cffi=1.15.1=py37h74dc2b5_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- colorama=0.4.5=pyhd8ed1ab_0
@@ -37,6 +38,7 @@ dependencies:
- idna=3.4=pyhd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- jinja2=3.1.2=pyhd8ed1ab_1
- joblib=1.1.0=pyhd3eb1b0_0
- jpeg=9b=h024ee3a_2
- kiwisolver=1.4.2=py37h295c915_0
- langcodes=3.3.0=pyhd8ed1ab_0
@@ -87,6 +89,7 @@ dependencies:
- qt=5.9.7=h5867ecd_1
- readline=8.1.2=h7f8727e_1
- requests=2.28.1=pyhd8ed1ab_1
- scikit-learn=1.0.2=py37h51133e4_1
- scipy=1.7.3=py37h6c91a56_2
- setuptools=63.4.1=py37h06a4308_0
- shellingham=1.5.0=pyhd8ed1ab_0
@@ -98,6 +101,7 @@ dependencies:
- spacy-legacy=3.0.10=pyhd8ed1ab_0
- spacy-loggers=1.0.3=pyhd8ed1ab_0
- sqlite=3.39.3=h5082296_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.7.2=py37
- torchvision=0.8.2=py37_cu110
@@ -113,10 +117,8 @@ dependencies:
- pip:
- blis==0.7.8
- chumpy==0.70
- clearml==1.7.1
- click==8.1.3
- confection==0.0.2
- filelock==3.8.0
- ftfy==6.1.1
- importlib-metadata==5.0.0
- lxml==4.9.1
15 changes: 15 additions & 0 deletions eval/a2m/action2motion/diversity.py
@@ -2,6 +2,21 @@
import numpy as np


# adapted from action2motion
def calculate_diversity(activations):
    diversity_times = 200
    num_motions = len(activations)

    diversity = 0

    first_indices = np.random.randint(0, num_motions, diversity_times)
    second_indices = np.random.randint(0, num_motions, diversity_times)
    for first_idx, second_idx in zip(first_indices, second_indices):
        diversity += torch.dist(activations[first_idx, :],
                                activations[second_idx, :])
    diversity /= diversity_times
    return diversity

# from action2motion
def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False):
    diversity_times = 200
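As a usage sketch for the new `calculate_diversity` helper above (not part of this commit; the feature size of 256 is an arbitrary assumption, and the import path simply follows the file location shown):

```python
import torch

from eval.a2m.action2motion.diversity import calculate_diversity

# 1000 synthetic activation vectors stand in for real feature-extractor outputs
activations = torch.randn(1000, 256)

# averages torch.dist over 200 randomly drawn pairs of activations
diversity = calculate_diversity(activations)
print(float(diversity))
```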
34 changes: 27 additions & 7 deletions eval/a2m/gru_eval.py
@@ -1,5 +1,7 @@
import copy
import os

import numpy as np
from tqdm import tqdm
import torch
import functools
@@ -8,13 +10,14 @@
from utils.fixseed import fixseed
from data_loaders.tensors import collate
from eval.a2m.action2motion.evaluate import A2MEvaluation
from eval.unconstrained.evaluate import evaluate_unconstrained_metrics
from .tools import save_metrics, format_metrics
from data_loaders.get_data import get_dataset
from utils import dist_util

num_samples_unconstrained = 1000

class NewDataloader:
def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, dataset, num_samples: int=-1):
def __init__(self, mode, model, diffusion, dataiterator, device, unconstrained, num_samples: int=-1):
assert mode in ["gen", "gt"]
self.batches = []
sample_fn = diffusion.p_sample_loop
@@ -37,7 +40,7 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data
translation=True, jointstype='smpl', vertstrans=True, betas=None,
beta=0, glob_rot=None, get_rotations_back=False)
batch["lengths"] = model_kwargs['y']['lengths'].to(device)
if cond_mode != 'no_cond': # proceed only if not running unconstrained
if not unconstrained:  # proceed only if not running unconstrained
batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices
self.batches.append(batch)

@@ -49,7 +52,6 @@ def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, data
def __iter__(self):
return iter(self.batches)


def evaluate(args, model, diffusion, data):
num_frames = 60

@@ -87,25 +89,43 @@ def evaluate(args, model, diffusion, data):
shuffle=False, num_workers=8, collate_fn=collate)

new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device,
cond_mode=args.cond_mode, dataset=args.dataset,
num_samples=args.num_samples)
unconstrained=args.unconstrained, num_samples=args.num_samples)
motionloader = new_data_loader(mode="gen", dataiterator=dataiterator)
gt_motionloader = new_data_loader("gt", dataiterator=dataiterator)
gt_motionloader2 = new_data_loader("gt", dataiterator=dataiterator2)

# Action2motionEvaluation
loaders = {"gen": motionloader,
# "recons": reconstructedloader,
"gt": gt_motionloader,
"gt2": gt_motionloader2}

a2mmetrics[seed] = a2mevaluation.evaluate(model, loaders)

del loaders

if args.unconstrained: # unconstrained
dataset_unconstrained = copy.deepcopy(data)
dataset_unconstrained.reset_shuffle()
dataset_unconstrained.shuffle()
dataiterator_unconstrained = DataLoader(dataset_unconstrained, batch_size=args.batch_size,
shuffle=False, num_workers=8, collate_fn=collate)
motionloader_unconstrained = new_data_loader(mode="gen", dataiterator=dataiterator_unconstrained, num_samples=num_samples_unconstrained)

generated_motions = []
for motion in motionloader_unconstrained:
idx = [15, 12, 16, 18, 20, 17, 19, 21, 0, 1, 4, 7, 2, 5, 8]
motion = motion['output_xyz'][:, idx, :, :]
generated_motions.append(motion.cpu().numpy())
generated_motions = np.concatenate(generated_motions)
unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True)
unconstrained_metrics = {k+'_unconstrained': v for k, v in unconstrained_metrics.items()}

except KeyboardInterrupt:
string = "Saving the evaluation before exiting.."
print(string)

metrics = {"feats": {key: [format_metrics(a2mmetrics[seed])[key] for seed in a2mmetrics.keys()] for key in a2mmetrics[allseeds[0]]}}
if args.unconstrained:
metrics["feats"] = {**metrics["feats"], **unconstrained_metrics}

return metrics
2 changes: 0 additions & 2 deletions eval/eval_humanact12_uestc.py
@@ -61,8 +61,6 @@ def main():
else:
args.num_samples = 1000
args.num_seeds = 20
args.cond_mode = 'action' # temporary code till 'unconstrained' is implemented


data_loader = get_dataset_loader(name=args.dataset, num_frames=60, batch_size=args.batch_size,)

Expand Down