forked from GuyTevet/motion-diffusion-model
Commit
added evaluation for humanml and kit
Showing 22 changed files with 2,464 additions and 38 deletions.
262 changes: 262 additions & 0 deletions
data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py
@@ -0,0 +1,262 @@
import torch
# np and nn are used below; imported explicitly rather than relying on the wildcard import
import numpy as np
import torch.nn as nn
from data_loaders.humanml.networks.modules import *
from data_loaders.humanml.networks.trainers import CompTrainerV6
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
from utils import dist_util

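
# build_models assembles the original text-to-motion (T2M) generator components
# used by this evaluation: a BiGRU text encoder, a sequence prior, a VAE
# decoder, an attention layer, movement conv encoder/decoder, and a pretrained
# motion-length estimator loaded from the 'length_est_bigru' checkpoint.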
def build_models(opt):
    if opt.text_enc_mod == 'bigru':
        text_encoder = TextEncoderBiGRU(word_size=opt.dim_word,
                                        pos_size=opt.dim_pos_ohot,
                                        hidden_size=opt.dim_text_hidden,
                                        device=opt.device)
        text_size = opt.dim_text_hidden * 2
    else:
        raise Exception("Text Encoder Mode not Recognized!!!")

    seq_prior = TextDecoder(text_size=text_size,
                            input_size=opt.dim_att_vec + opt.dim_movement_latent,
                            output_size=opt.dim_z,
                            hidden_size=opt.dim_pri_hidden,
                            n_layers=opt.n_layers_pri)

    seq_decoder = TextVAEDecoder(text_size=text_size,
                                 input_size=opt.dim_att_vec + opt.dim_z + opt.dim_movement_latent,
                                 output_size=opt.dim_movement_latent,
                                 hidden_size=opt.dim_dec_hidden,
                                 n_layers=opt.n_layers_dec)

    att_layer = AttLayer(query_dim=opt.dim_pos_hidden,
                         key_dim=text_size,
                         value_dim=opt.dim_att_vec)

    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, opt.dim_pose)

    len_estimator = MotionLenEstimatorBiGRU(opt.dim_word, opt.dim_pos_ohot, 512, opt.num_classes)

    # latent_dis = LatentDis(input_size=opt.dim_z * 2)
    checkpoints = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_est_bigru', 'model', 'latest.tar'), map_location=opt.device)
    len_estimator.load_state_dict(checkpoints['estimator'])
    len_estimator.to(opt.device)
    len_estimator.eval()

    # return text_encoder, text_decoder, att_layer, vae_pri, vae_dec, vae_pos, motion_dis, movement_dis, latent_dis
    return text_encoder, seq_prior, seq_decoder, att_layer, movement_enc, movement_dec, len_estimator

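
# CompV6GeneratedDataset wraps the T2M generator as a torch Dataset: at
# construction time it samples one motion per test caption (plus repeated
# samples for selected captions, used by the multimodality metric) and serves
# them in the format the T2M evaluators expect.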
class CompV6GeneratedDataset(Dataset):

    def __init__(self, opt, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)

        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        text_enc, seq_pri, seq_dec, att_layer, mov_enc, mov_dec, len_estimator = build_models(opt)
        trainer = CompTrainerV6(opt, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=mov_enc)
        epoch, it, sub_ep, schedule_len = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
        generated_motion = []
        mm_generated_motions = []
        mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
        mm_idxs = np.sort(mm_idxs)
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6
        # print(mm_idxs)

        print('Loading model: Epoch %03d Schedule_len %03d' % (epoch, schedule_len))
        trainer.eval_mode()
        trainer.to(opt.device)
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split('_')
                word_emb = word_emb.detach().to(opt.device).float()
                pos_ohot = pos_ohot.detach().to(opt.device).float()

                pred_dis = len_estimator(word_emb, pos_ohot, cap_lens)
                pred_dis = nn.Softmax(-1)(pred_dis).squeeze()

                mm_num_now = len(mm_generated_motions)
                is_mm = True if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) else False

                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
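                # sample a movement length from the estimator's distribution,
                # redrawing (at most twice) if it falls below the dataset minimum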
                for t in range(repeat_times):
                    mov_length = torch.multinomial(pred_dis, 1, replacement=True)
                    if mov_length < min_mov_length:
                        mov_length = torch.multinomial(pred_dis, 1, replacement=True)
                    if mov_length < min_mov_length:
                        mov_length = torch.multinomial(pred_dis, 1, replacement=True)

                    m_lens = mov_length * opt.unit_length
                    pred_motions, _, _ = trainer.generate(word_emb, pos_ohot, cap_lens, m_lens,
                                                          m_lens[0]//opt.unit_length, opt.dim_pose)
                    if t == 0:
                        # print(m_lens)
                        # print(text_data)
                        sub_dict = {'motion': pred_motions[0].cpu().numpy(),
                                    'length': m_lens[0].item(),
                                    'cap_len': cap_lens[0].item(),
                                    'caption': caption[0],
                                    'tokens': tokens}
                        generated_motion.append(sub_dict)

                    if is_mm:
                        mm_motions.append({
                            'motion': pred_motions[0].cpu().numpy(),
                            'length': m_lens[0].item()
                        })
                if is_mm:
                    mm_generated_motions.append({'caption': caption[0],
                                                 'tokens': tokens,
                                                 'cap_len': cap_lens[0].item(),
                                                 'mm_motions': mm_motions})

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        if m_length < self.opt.max_motion_length:
            motion = np.concatenate([motion,
                                     np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
                                     ], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)

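
# CompMDMGeneratedDataset plays the same role for MDM itself: it runs the
# diffusion sampler over the evaluation loader in batches (optionally with a
# classifier-free guidance scale) and stores the generated motions, plus
# repeated samples for the multimodality metric, in the same evaluator-friendly format.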
class CompMDMGeneratedDataset(Dataset):

    def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=None):
        self.dataloader = dataloader
        self.dataset = dataloader.dataset
        assert mm_num_samples < len(dataloader.dataset)
        use_ddim = False  # FIXME - hardcoded
        clip_denoised = False  # FIXME - hardcoded
        self.max_motion_length = max_motion_length
        sample_fn = (
            diffusion.p_sample_loop if not use_ddim else diffusion.ddim_sample_loop
        )

        real_num_batches = len(dataloader)
        if num_samples_limit is not None:
            real_num_batches = num_samples_limit // dataloader.batch_size + 1
        print('real_num_batches', real_num_batches)

        generated_motion = []
        mm_generated_motions = []
        if mm_num_samples > 0:
            mm_idxs = np.random.choice(real_num_batches, mm_num_samples // dataloader.batch_size + 1, replace=False)
            mm_idxs = np.sort(mm_idxs)
        else:
            mm_idxs = []
        print('mm_idxs', mm_idxs)
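
        # mm_idxs marks the batches that will be sampled mm_num_repeats times
        # for the multimodality metric; every other batch is sampled once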

        model.eval()

        with torch.no_grad():
            for i, (motion, model_kwargs) in tqdm(enumerate(dataloader)):

                if num_samples_limit is not None and len(generated_motion) >= num_samples_limit:
                    break

                tokens = [t.split('_') for t in model_kwargs['y']['tokens']]

                # add CFG scale to batch
                if scale != 1.:
                    model_kwargs['y']['scale'] = torch.ones(motion.shape[0],
                                                            device=dist_util.dev()) * scale

                mm_num_now = len(mm_generated_motions) // dataloader.batch_size
                is_mm = i in mm_idxs
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):

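                    # run the full reverse-diffusion loop to generate one batch of motions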
                    sample = sample_fn(
                        model,
                        motion.shape,
                        clip_denoised=clip_denoised,
                        model_kwargs=model_kwargs,
                        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                        init_image=None,
                        progress=False,
                        dump_steps=None,
                        noise=None,
                        const_noise=False,
                        # when experimenting with guidance_scale we want to neutralize the effect of noise on generation
                    )

                    if t == 0:
                        sub_dicts = [{'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(),
                                      'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
                                      'caption': model_kwargs['y']['text'][bs_i],
                                      'tokens': tokens[bs_i],
                                      'cap_len': len(tokens[bs_i]),
                                      } for bs_i in range(dataloader.batch_size)]
                        generated_motion += sub_dicts

                    if is_mm:
                        mm_motions += [{'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(),
                                        'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
                                        } for bs_i in range(dataloader.batch_size)]

                if is_mm:
                    mm_generated_motions += [{
                        'caption': model_kwargs['y']['text'][bs_i],
                        'tokens': tokens[bs_i],
                        'cap_len': len(tokens[bs_i]),
                        'mm_motions': mm_motions[bs_i::dataloader.batch_size],  # collect all repeats of this batch element from the repeat_times * batch_size generated motions
                    } for bs_i in range(dataloader.batch_size)]

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.w_vectorizer = dataloader.dataset.w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']

        if self.dataset.mode == 'eval':
            normed_motion = motion
            denormed_motion = self.dataset.t2m_dataset.inv_transform(normed_motion)
            renormed_motion = (denormed_motion - self.dataset.mean_for_eval) / self.dataset.std_for_eval  # according to T2M norms
            motion = renormed_motion
            # This step is needed because T2M evaluators expect their norm convention

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
27 changes: 27 additions & 0 deletions
data_loaders/humanml/motion_loaders/dataset_motion_loader.py
@@ -0,0 +1,27 @@
from t2m.data.dataset import Text2MotionDatasetV2, collate_fn
from t2m.utils.word_vectorizer import WordVectorizer
import numpy as np
from os.path import join as pjoin
from torch.utils.data import DataLoader
from t2m.utils.get_opt import get_opt

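# get_dataset_motion_loader builds the ground-truth test-split DataLoader that
# the generated motions are evaluated against (HumanML3D 't2m' or KIT 'kit').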
def get_dataset_motion_loader(opt_path, batch_size, device):
    opt = get_opt(opt_path, device)

    # Configurations of the T2M dataset and the KIT dataset are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        print('Loading dataset %s ...' % opt.dataset_name)

        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        w_vectorizer = WordVectorizer('./glove', 'our_vab')
        split_file = pjoin(opt.data_root, 'test.txt')
        dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
                                collate_fn=collate_fn, shuffle=True)
    else:
        raise KeyError('Dataset not Recognized !!')

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset