Skip to content

Commit

Permalink
Factor out a common gaussian_frechet_distance function
Browse files Browse the repository at this point in the history
Summary: Instead of having duplicate implementations for the Frechet Distance between two Gaussians for FID and FAD, let's just use the same implementation.

Reviewed By: JKSenthil

Differential Revision: D56520860

fbshipit-source-id: 3c6423a648f41576be2fd731be61bfcf21da1fc7
  • Loading branch information
alanhdu authored and facebook-github-bot committed Apr 24, 2024
1 parent 3ea2f6e commit e138259
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 61 deletions.
28 changes: 8 additions & 20 deletions torcheval/metrics/audio/fad.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,12 @@
from typing import Any, Callable, Iterable, Optional, Union

import torch

try:
from torchaudio.functional import frechet_distance

_TORCHAUDIO_AVAILABLE = True
except ImportError:
_TORCHAUDIO_AVAILABLE = False

from torcheval.metrics.functional.frechet import gaussian_frechet_distance
from torcheval.metrics.metric import Metric

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.


def _validate_torchaudio_available() -> None:
    """Guard for code paths that depend on TorchAudio.

    Raises:
        RuntimeError: if the module-level ``_TORCHAUDIO_AVAILABLE`` flag is falsy.
    """
    # Guard-clause form: return early on the happy path.
    if _TORCHAUDIO_AVAILABLE:
        return
    raise RuntimeError(
        "TorchAudio is required. Please make sure ``torchaudio`` is installed."
    )


class FrechetAudioDistance(Metric[torch.Tensor]):
"""Computes the Fréchet distance between predicted and target audio waveforms.
Expand All @@ -50,8 +36,6 @@ def __init__(
embedding_dim: int,
device: Optional[torch.device] = None,
) -> None:
_validate_torchaudio_available()

super().__init__(device=device)

self.preproc = preproc
Expand Down Expand Up @@ -120,8 +104,13 @@ def compute(self: "FrechetAudioDistance") -> torch.Tensor:
pred_cov = self.pred_cov_partial / (self.pred_n - 1) - pred_mean.T @ (
pred_mean
) * self.pred_n / (self.pred_n - 1)
return frechet_distance(
pred_mean.squeeze(0), pred_cov, target_mean.squeeze(0), target_cov
return gaussian_frechet_distance(
pred_mean.squeeze(0),
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
pred_cov,
target_mean.squeeze(0),
# pyre-fixme[6]: For 4th argument expected `Tensor` but got `float`.
target_cov,
)

@torch.inference_mode()
Expand Down Expand Up @@ -165,7 +154,6 @@ def with_vggish(device: Optional[torch.device] = None) -> "FrechetAudioDistance"
Returns:
FrechetAudioDistance: Instance of FrechetAudioDistance preloaded with TorchAudio's pretrained VGGish model.
"""
_validate_torchaudio_available()
try:
from torchaudio.prototype.pipelines import VGGISH
except ImportError:
Expand Down
2 changes: 2 additions & 0 deletions torcheval/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
multilabel_recall_at_fixed_precision,
topk_multilabel_accuracy,
)
from torcheval.metrics.functional.frechet import gaussian_frechet_distance
from torcheval.metrics.functional.image import peak_signal_noise_ratio
from torcheval.metrics.functional.ranking import (
click_through_rate,
Expand Down Expand Up @@ -78,6 +79,7 @@
"bleu_score",
"click_through_rate",
"frequency_at_k",
"gaussian_frechet_distance",
"hit_rate",
"mean",
"mean_squared_error",
Expand Down
56 changes: 56 additions & 0 deletions torcheval/metrics/functional/frechet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import torch


def gaussian_frechet_distance(
    mu_x: torch.Tensor, cov_x: torch.Tensor, mu_y: torch.Tensor, cov_y: torch.Tensor
) -> torch.Tensor:
    r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.

    The Fréchet distance is also known as the Wasserstein-2 distance.

    Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
    and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as

    .. math::
        F(X, Y) = || \mu_X - \mu_Y ||_2^2
        + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)

    Args:
        mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
        cov_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
        mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
        cov_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.

    Returns:
        torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`.

    Raises:
        ValueError: if a mean is not one-dimensional, a covariance is not
            two-dimensional, or the two means (or the two covariances) differ in shape.
    """
    if mu_x.ndim != 1:
        msg = f"Input mu_x must be one-dimensional; got dimension {mu_x.ndim}."
        raise ValueError(msg)
    if mu_y.ndim != 1:
        msg = f"Input mu_y must be one-dimensional; got dimension {mu_y.ndim}."
        raise ValueError(msg)
    if cov_x.ndim != 2:
        msg = f"Input cov_x must be two-dimensional; got dimension {cov_x.ndim}."
        raise ValueError(msg)
    if cov_y.ndim != 2:
        # Name the argument that actually failed (previously reported as cov_x).
        msg = f"Input cov_y must be two-dimensional; got dimension {cov_y.ndim}."
        raise ValueError(msg)
    if mu_x.shape != mu_y.shape:
        msg = f"Inputs mu_x and mu_y must have the same shape; got {mu_x.shape} and {mu_y.shape}."
        raise ValueError(msg)
    if cov_x.shape != cov_y.shape:
        msg = f"Inputs cov_x and cov_y must have the same shape; got {cov_x.shape} and {cov_y.shape}."
        raise ValueError(msg)

    # Squared Euclidean distance between the means.
    a = (mu_x - mu_y).square().sum()
    # Tr(cov_x) + Tr(cov_y).
    b = cov_x.trace() + cov_y.trace()
    # Tr(sqrt(cov_x @ cov_y)) is evaluated as the sum of the square roots of
    # the product's eigenvalues. eigvals returns a complex tensor, so only the
    # real part of each square root is kept before summing.
    c = torch.linalg.eigvals(cov_x @ cov_y).sqrt().real.sum()
    return a + b - 2 * c
43 changes: 2 additions & 41 deletions torcheval/metrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torcheval.metrics.functional.frechet import gaussian_frechet_distance
from torcheval.metrics.metric import Metric

if find_spec("torchvision") is not None:
Expand Down Expand Up @@ -196,51 +197,11 @@ def compute(self: TFrechetInceptionDistance) -> Tensor:
fake_cov = fake_cov_num / (self.num_fake_images - 1)

# Compute the Frechet Distance between the distributions
fid = self._calculate_frechet_distance(
fid = gaussian_frechet_distance(
real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov
)
return fid

def _calculate_frechet_distance(
self: TFrechetInceptionDistance,
mu1: Tensor,
sigma1: Tensor,
mu2: Tensor,
sigma2: Tensor,
) -> Tensor:
"""
Calculate the Frechet Distance between two multivariate Gaussian distributions.
Args:
mu1 (Tensor): The mean of the first distribution.
sigma1 (Tensor): The covariance matrix of the first distribution.
mu2 (Tensor): The mean of the second distribution.
sigma2 (Tensor): The covariance matrix of the second distribution.
Returns:
tensor: The Frechet Distance between the two distributions.
"""

# Compute the squared distance between the means
mean_diff = mu1 - mu2
mean_diff_squared = mean_diff.square().sum(dim=-1)

# Calculate the sum of the traces of both covariance matrices
trace_sum = sigma1.trace() + sigma2.trace()

# Compute the eigenvalues of the matrix product of the real and fake covariance matrices
sigma_mm = torch.matmul(sigma1, sigma2)
eigenvals = torch.linalg.eigvals(sigma_mm)

# Take the square root of each eigenvalue and take its sum
sqrt_eigenvals_sum = eigenvals.sqrt().real.sum(dim=-1)

# Calculate the FID using the squared distance between the means,
# the sum of the traces of the covariance matrices, and the sum of the square roots of the eigenvalues
fid = mean_diff_squared + trace_sum - 2 * sqrt_eigenvals_sum

return fid

def _FID_parameter_check(
self: TFrechetInceptionDistance,
model: Optional[nn.Module],
Expand Down

0 comments on commit e138259

Please sign in to comment.