forked from GuyTevet/motion-diffusion-model
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bb8960a
commit ce26c8e
Showing
26 changed files
with
1,474 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import torch | ||
|
||
|
||
def calculate_accuracy(model, motion_loader, num_labels, classifier, device):
    """Score ``classifier`` on every batch yielded by ``motion_loader``.

    ``model`` and ``device`` are part of the signature but are not read by
    this function.

    Each batch is indexed with the keys ``"output_xyz"``, ``"lengths"`` and
    ``"y"``. The classifier output is reduced with a max over dim 1 to obtain
    one predicted class index per sample.

    Returns:
        ``(accuracy, confusion)`` where ``confusion`` is a
        ``(num_labels, num_labels)`` long tensor indexed as
        ``[true label][predicted label]`` and ``accuracy`` is its trace
        divided by its sum, converted to a Python number.
    """
    confusion = torch.zeros(num_labels, num_labels, dtype=torch.long)

    with torch.no_grad():
        for batch in motion_loader:
            scores = classifier(batch["output_xyz"], lengths=batch["lengths"])
            # The index of the highest score along dim 1 is the prediction.
            _, predictions = torch.max(scores, dim=1)
            for true_label, predicted in zip(batch["y"], predictions):
                confusion[true_label][predicted] += 1

    n_correct = torch.trace(confusion)
    n_total = torch.sum(confusion)
    return (n_correct / n_total).item(), confusion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
import numpy as np | ||
|
||
|
||
# from action2motion | ||
def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained=False,
                                      *, diversity_times=200, multimodality_times=20):
    """Estimate diversity and multimodality of a set of activations.

    Diversity is the mean ``torch.dist`` over ``diversity_times`` randomly
    drawn pairs of rows of ``activations``. Multimodality is accumulated over
    randomly drawn pairs that share a label, ``multimodality_times`` pairs for
    every label present in ``labels``. Indices are drawn with
    ``np.random.randint``, so results depend on NumPy's global random state.

    Args:
        activations: tensor indexed as ``[motion, feature]``.
        labels: one label per row of ``activations``; ignored when
            ``unconstrained`` is true.
        num_labels: total number of label classes.
        unconstrained: when true, only diversity is computed and
            multimodality is returned as NaN.
        diversity_times: number of random pairs used for diversity.
        multimodality_times: number of same-label pairs drawn per label.

    Returns:
        ``(diversity, multimodality)`` as Python floats.

    Raises:
        ValueError: if ``activations`` has no rows or a sample count is < 1.
    """
    num_motions = activations.shape[0]
    if num_motions == 0:
        raise ValueError("activations must contain at least one motion")
    if diversity_times < 1 or multimodality_times < 1:
        raise ValueError("diversity_times and multimodality_times must be >= 1")

    first_indices = np.random.randint(0, num_motions, diversity_times)
    second_indices = np.random.randint(0, num_motions, diversity_times)
    diversity = 0
    for first_idx, second_idx in zip(first_indices, second_indices):
        diversity += torch.dist(activations[first_idx, :],
                                activations[second_idx, :])
    diversity /= diversity_times

    if unconstrained:
        return float(diversity), float("nan")

    # Work on a CPU NumPy copy of the labels: they are only used to index the
    # NumPy quota array and to compare classes, and NumPy cannot index with a
    # tensor that lives on another device.
    label_ids = torch.as_tensor(labels).long().cpu().numpy()

    # A label that never appears keeps a zero quota and is never sampled.
    label_quotas = np.zeros(num_labels)
    label_quotas[np.unique(label_ids)] = multimodality_times

    multimodality = 0
    while np.any(label_quotas > 0):
        first_idx = np.random.randint(0, num_motions)
        first_label = label_ids[first_idx]
        if not label_quotas[first_label]:
            continue

        # Rejection-sample a partner with the same label. This terminates
        # because first_idx itself is always a valid partner.
        second_idx = np.random.randint(0, num_motions)
        while label_ids[second_idx] != first_label:
            second_idx = np.random.randint(0, num_motions)

        label_quotas[first_label] -= 1
        multimodality += torch.dist(activations[first_idx, :],
                                    activations[second_idx, :])

    # NOTE(review): the denominator counts every class, including classes that
    # are absent from `labels` and therefore contributed no pairs. Kept as-is
    # so metric values stay comparable; confirm this is intended.
    multimodality /= (multimodality_times * num_labels)

    # float() also covers the case where no pair was drawn and the
    # accumulator is still a plain Python number.
    return float(diversity), float(multimodality)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
import numpy as np | ||
from .models import load_classifier, load_classifier_for_fid | ||
from .accuracy import calculate_accuracy | ||
from .fid import calculate_fid | ||
from .diversity import calculate_diversity_multimodality | ||
|
||
|
||
class A2MEvaluation:
    """Compute action2motion-style metrics over a set of motion loaders.

    Two classifiers are loaded at construction time: one produces the
    activations used for FID, diversity and multimodality, the other is used
    for classification accuracy. Both are put in eval mode.
    """

    def __init__(self, device):
        dataset_opt = {"input_size_raw": 72, "joints_num": 24, "num_classes": 12}

        self.input_size_raw = dataset_opt["input_size_raw"]
        self.num_classes = dataset_opt["num_classes"]
        self.device = device

        self.gru_classifier_for_fid = load_classifier_for_fid(self.input_size_raw, self.num_classes, device).eval()
        self.gru_classifier = load_classifier(self.input_size_raw, self.num_classes, device).eval()

    def compute_features(self, model, motionloader):
        """Collect classifier activations (and labels) for every batch.

        Counterpart of action2motion's ``calculate_activations_labels``.

        Args:
            model: object whose ``cond_mode`` attribute says whether labels
                are available (anything other than ``'no_cond'``).
            motionloader: iterable of batches indexed with ``"output_xyz"``,
                ``"lengths"`` and, when conditioned, ``"y"``.

        Returns:
            ``(activations, labels)``; activations are concatenated along
            dim 0. ``labels`` is a concatenated tensor when conditioned and
            an empty list otherwise.
        """
        conditioned = model.cond_mode != 'no_cond'
        activations = []
        labels = []
        with torch.no_grad():
            for batch in motionloader:
                activations.append(self.gru_classifier_for_fid(batch["output_xyz"], lengths=batch["lengths"]))
                if conditioned:
                    labels.append(batch["y"])
        activations = torch.cat(activations, dim=0)
        if conditioned:
            labels = torch.cat(labels, dim=0)
        return activations, labels

    @staticmethod
    def calculate_activation_statistics(activations):
        """Return the per-feature mean and the covariance of ``activations``.

        Rows are treated as observations and columns as variables
        (``rowvar=False``). The tensor is moved to CPU and converted to NumPy.
        """
        activations = activations.cpu().numpy()
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma

    def evaluate(self, model, loaders):
        """Compute accuracy, diversity, multimodality and FID per loader.

        Args:
            model: object exposing ``cond_mode``; with ``'no_cond'`` accuracy
                is reported as NaN and diversity runs unconstrained.
            loaders: mapping from a name to a motion loader. It must contain
                a ``"gt"`` entry, whose statistics are the FID reference.

        Returns:
            Dict with keys ``accuracy_<name>``, ``diversity_<name>``,
            ``multimodality_<name>`` and ``fid_<name>`` for every loader.

        Raises:
            KeyError: if ``loaders`` has no ``"gt"`` entry.
        """
        # Fail before the expensive feature extraction rather than after it.
        if "gt" not in loaders:
            raise KeyError("loaders must contain a 'gt' entry; it provides the FID reference statistics")

        def print_logs(metric, key):
            print(f"Computing action2motion {metric} on the {key} loader ...")

        conditioned = model.cond_mode != 'no_cond'
        metrics = {}
        computedfeats = {}

        for key, loader in loaders.items():
            metric = "accuracy"
            print_logs(metric, key)
            mkey = f"{metric}_{key}"
            if conditioned:
                metrics[mkey], _ = calculate_accuracy(model, loader,
                                                      self.num_classes,
                                                      self.gru_classifier, self.device)
            else:
                metrics[mkey] = np.nan

            # features for diversity
            print_logs("features", key)
            feats, labels = self.compute_features(model, loader)
            print_logs("stats", key)
            stats = self.calculate_activation_statistics(feats)

            computedfeats[key] = {"feats": feats,
                                  "labels": labels,
                                  "stats": stats}

            print_logs("diversity", key)
            ret = calculate_diversity_multimodality(feats, labels, self.num_classes,
                                                    unconstrained=not conditioned)
            metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret

        # The ground-truth statistics are the FID reference. The "gt" entry
        # stays in computedfeats, so fid_gt (gt against itself) is reported.
        gtstats = computedfeats["gt"]["stats"]
        for key, computed in computedfeats.items():
            metrics[f"fid_{key}"] = float(calculate_fid(gtstats, computed["stats"]))

        return metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
from scipy import linalg | ||
|
||
|
||
# from action2motion | ||
def calculate_fid(statistics_1, statistics_2):
    """Return the Frechet distance between two sets of statistics.

    Element 0 of each argument is forwarded as the mean and element 1 as the
    covariance to ``calculate_frechet_distance``.
    """
    mean_1, cov_1 = statistics_1[0], statistics_1[1]
    mean_2, cov_2 = statistics_2[0], statistics_2[1]
    return calculate_frechet_distance(mean_1, cov_1, mean_2, cov_2)
|
||
|
||
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the first distribution.
    -- sigma1: Covariance matrix of the first distribution.
    -- mu2   : Mean vector of the second distribution.
    -- sigma2: Covariance matrix of the second distribution.
    -- eps   : Value added to the diagonal of both covariances when the
               square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError: if the means or the covariances differ in shape.
    -- ValueError: if the matrix square root has a diagonal imaginary part
                   that is not close to zero (atol=1e-3).
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean = _sqrtm_without_display(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


def _sqrtm_without_display(matrix):
    """Return the matrix square root without SciPy's diagnostic printing.

    Tolerates SciPy versions whose ``sqrtm`` no longer accepts ``disp`` or
    returns the root alone instead of a ``(root, error_estimate)`` tuple.
    """
    try:
        result = linalg.sqrtm(matrix, disp=False)
    except TypeError:
        return linalg.sqrtm(matrix)
    # Only unpack a real tuple: unpacking a bare 2x2 array would silently
    # split it into its rows.
    return result[0] if isinstance(result, tuple) else result
Oops, something went wrong.