Skip to content

Commit

Permalink
Update Predict to optionally output motion parameters with Replicate API
Browse files Browse the repository at this point in the history
  • Loading branch information
evanatherton committed Jan 31, 2023
1 parent db5478a commit bfe8ea6
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 7 deletions.
30 changes: 23 additions & 7 deletions sample/predict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import subprocess
import typing
from typing import Any, List, Optional
from argparse import Namespace

import torch
from cog import BasePredictor, Input, Path
from cog import BasePredictor, Input, Path, BaseModel

import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.get_data import get_dataset_loader
Expand All @@ -14,6 +14,7 @@
from model.cfg_sampler import ClassifierFreeSampleModel
from utils import dist_util
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from visualize.motions2hik import motions2hik
from sample.generate import construct_template_variables

"""
Expand All @@ -22,6 +23,11 @@
"""


class ModelOutput(BaseModel):
    """
    Structured result returned by the predictor.

    Only the field that matches the requested output format is populated; the
    other one is left as None. Both fields therefore carry an explicit None
    default so the model can be constructed with a single keyword argument.
    """

    # JSON serializable motion data (the dict built by motions2hik).
    json_file: Optional[Any] = None
    # Paths of the rendered animation files.
    animation: Optional[List[Path]] = None


def get_args():
args = Namespace()
args.fps = 20
Expand Down Expand Up @@ -78,8 +84,16 @@ def predict(
self,
prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
num_repetitions: int = Input(default=3, description="How many"),

) -> typing.List[Path]:
output_format: str = Input(
description='Choose the format of the output, either an animation or a json file of the animation data.\
The json format is: {"thetas": [...], "root_translation": [...], "joint_map": [...]}, where "thetas" \
is an [nframes x njoints x 3] array of joint rotations in degrees, "root_translation" is an [nframes x 3] \
array of (X, Y, Z) positions of the root, and "joint_map" is a list mapping the SMPL joint index to the\
corresponding HumanIK joint name',
default="animation",
choices=["animation", "json_file"],
),
) -> ModelOutput:
args = self.args
args.num_repetitions = int(num_repetitions)

Expand Down Expand Up @@ -126,11 +140,14 @@ def predict(

all_motions = sample.cpu().numpy()

if output_format == 'json_file':
data_dict = motions2hik(all_motions)
return ModelOutput(json_file=data_dict)

caption = str(prompt)

skeleton = paramUtil.t2m_kinematic_chain


sample_print_template, row_print_template, all_print_template, \
sample_file_template, row_file_template, all_file_template = construct_template_variables(
args.unconstrained)
Expand All @@ -147,5 +164,4 @@ def predict(

replicate_fnames.append(Path(save_file))

return replicate_fnames

return ModelOutput(animation=replicate_fnames)
103 changes: 103 additions & 0 deletions visualize/motions2hik.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import torch

from utils.rotation_conversions import rotation_6d_to_matrix, matrix_to_euler_angles
from visualize.simplify_loc2rot import joints2smpl

"""
Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder
by converting joint positions to joint rotations in degrees. Based on visualize.vis_utils.npy2obj
"""

# Mapping of SMPL joint index to HIK joint Name.
# The position of a name in this list is the SMPL joint index it corresponds to
# (index 0 -> 'Hips', ..., index 21 -> 'RightHand'), so the order must not be changed.
JOINT_MAP = [
    'Hips',
    'LeftUpLeg',
    'RightUpLeg',
    'Spine',
    'LeftLeg',
    'RightLeg',
    'Spine1',
    'LeftFoot',
    'RightFoot',
    'Spine2',
    'LeftToeBase',
    'RightToeBase',
    'Neck',
    'LeftShoulder',
    'RightShoulder',
    'Head',
    'LeftArm',
    'RightArm',
    'LeftForeArm',
    'RightForeArm',
    'LeftHand',
    'RightHand'
]


def motions2hik(motions, device=0, cuda=True):
    """
    Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder
    by converting joint positions to joint rotations in degrees. Based on visualize.vis_utils.npy2obj

    :param motions: numpy array containing MDM model output [num_reps, num_joints, num_feats, num_frames].
        With num_feats == 3 the values are xyz joint positions and rotations are fitted with SMPLify.
        With num_feats == 6 the values are used directly as a 6D rotation representation, where the
        last joint slot holds the root translation in its first three features.
    :param device: device id forwarded to joints2smpl (only used when num_feats == 3)
    :param cuda: cuda flag forwarded to joints2smpl (only used when num_feats == 3)
    :returns: JSON serializable dict to be used with the Replicate API implementation, with keys
        'joint_map', 'thetas' [nreps, nframes, njoints, 3 (deg)] and 'root_translation' [nreps, nframes, 3 (xyz)]
    :raises ValueError: if num_feats is neither 3 nor 6
    """

    nreps, njoints, nfeats, nframes = motions.shape
    if nfeats not in (3, 6):
        raise ValueError(f'Unsupported number of features per joint: {nfeats} (expected 3 or 6)')

    # SMPLify is only needed to fit rotations to xyz joint positions, so it is
    # not constructed when the input already is a 6D rotation representation.
    j2s = joints2smpl(num_frames=nframes, device_id=device, cuda=cuda) if nfeats == 3 else None

    thetas = []
    root_translation = []
    for rep_idx in range(nreps):
        if nfeats == 3:
            rep_motions = motions[rep_idx].transpose(2, 0, 1)  # [nframes, njoints, 3]
            print(f'Running SMPLify for repetition [{rep_idx + 1}] of {nreps}, it may take a few minutes.')
            motion_tensor, _ = j2s.joint2smpl(rep_motions)
            # Indexed below as [1, njoints + 1, 6, nframes] - TODO confirm against joints2smpl.
            motion = motion_tensor.cpu().numpy()
        else:
            # Keep a leading axis of size 1 so the 6D input has the same layout
            # as the SMPLify output indexed below: [1, njoints, 6, nframes].
            motion = motions[[rep_idx]]

        # Convert 6D rotation representation to Euler angles. Every joint slot
        # except the last one holds a rotation; the conversion helper accepts
        # arbitrary leading dimensions, so all frames are converted in one call.
        thetas_6d = motion[0, :-1, :, :nframes].transpose(2, 0, 1)  # [nframes, njoints, 6]
        thetas.append(_rotation_6d_to_euler(thetas_6d))  # [nframes, njoints, 3]

        # The last joint slot carries the root translation in its first 3 features.
        root_translation.append(motion[0, -1, :3, :nframes].transpose(1, 0))  # [nframes, 3]

    # Stack along a new leading repetition axis. No slicing is applied here:
    # the leading axis is repetitions, not frames, so every repetition is kept.
    data_dict = {
        'joint_map': JOINT_MAP,
        'thetas': np.stack(thetas, axis=0).tolist() if thetas else [],  # [nreps, nframes, njoints, 3 (deg)]
        'root_translation': np.stack(root_translation, axis=0).tolist() if root_translation else [],  # [nreps, nframes, 3 (xyz)]
    }

    return data_dict


def _rotation_6d_to_euler(d6):
    """
    Turn a 6D rotation representation (Zhou et al. [1]) into euler angles.

    The 6D values are first expanded to rotation matrices, which are then
    decomposed into euler angles using the 'XYZ' convention and converted
    from radians to degrees.
    :param d6: numpy Array 6D rotation representation, of size (*, 6)
    :returns: euler angles in degrees as a numpy array with shape (*, 3)
    """
    matrices = rotation_6d_to_matrix(torch.tensor(d6))
    radians = matrix_to_euler_angles(matrices, 'XYZ')
    return torch.rad2deg(radians).numpy()

0 comments on commit bfe8ea6

Please sign in to comment.