Skip to content

Commit

Permalink
add C-SI-SNR (Lightning-AI#1785)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
quancs and pre-commit-ci[bot] authored May 22, 2023
1 parent 6014ade commit d4a3932
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))


- Added `ComplexScaleInvariantSignalNoiseRatio` for audio package ([#1785](https://github.com/Lightning-AI/torchmetrics/pull/1785))


- Added `Running` wrapper for calculating running statistics ([#1752](https://github.com/Lightning-AI/torchmetrics/pull/1752))


Expand Down
23 changes: 23 additions & 0 deletions docs/source/audio/complex_scale_invariant_signal_noise_ratio.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

########################################################
Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)
########################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.ComplexScaleInvariantSignalNoiseRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.complex_scale_invariant_signal_noise_ratio
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
.. _sdr ref2: https://arxiv.org/abs/2110.06440
.. _Scale-invariant signal-to-distortion ratio: https://arxiv.org/abs/1811.02508
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Complex scale-invariant signal-to-noise ratio: https://arxiv.org/abs/2011.09162
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
.. _ranking ref1: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.
from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.audio.snr import (
ComplexScaleInvariantSignalNoiseRatio,
ScaleInvariantSignalNoiseRatio,
SignalNoiseRatio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE

__all__ = [
Expand All @@ -22,6 +26,7 @@
"SignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"ComplexScaleInvariantSignalNoiseRatio",
]

if _PESQ_AVAILABLE:
Expand Down
122 changes: 120 additions & 2 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@

from torch import Tensor, tensor

from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.functional.audio.snr import (
complex_scale_invariant_signal_noise_ratio,
scale_invariant_signal_noise_ratio,
signal_noise_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"]
__doctest_skip__ = [
"SignalNoiseRatio.plot",
"ScaleInvariantSignalNoiseRatio.plot",
"ComplexScaleInvariantSignalNoiseRatio.plot",
]


class SignalNoiseRatio(Metric):
Expand Down Expand Up @@ -151,6 +159,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):
if target and preds have a different shape
Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
Expand Down Expand Up @@ -225,3 +234,112 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class ComplexScaleInvariantSignalNoiseRatio(Metric):
    """Calculate `Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR) metric for evaluating quality of audio.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): real float tensor with shape ``(..., frequency, time, 2)`` or complex
      tensor with shape ``(..., frequency, time)``
    - ``target`` (:class:`~torch.Tensor`): real float tensor with shape ``(..., frequency, time, 2)`` or complex
      tensor with shape ``(..., frequency, time)``

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``c_si_snr`` (:class:`~torch.Tensor`): float scalar tensor with average C-SI-SNR value over samples

    Args:
        zero_mean: if to zero mean target and preds or not
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``zero_mean`` is not a bool
        RuntimeError:
            If ``preds`` or ``target`` does not have the shape ``(..., frequency, time, 2)`` (after being
            converted to real if it is complex).
        RuntimeError:
            If ``preds`` and ``target`` do not have the same shape.

    Example:
        >>> import torch
        >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn((1,257,100,2))
        >>> target = torch.randn((1,257,100,2))
        >>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
        >>> c_si_snr(preds, target)
        tensor(-63.4849)
    """

    is_differentiable = True
    # Running sum of the per-sample C-SI-SNR values and the number of values accumulated so far.
    sum: Tensor
    num: Tensor
    higher_is_better = True
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def __init__(
        self,
        zero_mean: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not isinstance(zero_mean, bool):
            raise ValueError(f"Expected argument `zero_mean` to be an bool, but got {zero_mean}")
        self.zero_mean = zero_mean

        # Both states are reduced with "sum" so that `compute` yields the mean over all accumulated values.
        self.add_state("sum", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets."""
        values = complex_scale_invariant_signal_noise_ratio(preds=preds, target=target, zero_mean=self.zero_mean)

        self.sum += values.sum()
        self.num += values.numel()

    def compute(self) -> Tensor:
        """Compute metric as the mean of all accumulated per-sample values."""
        return self.sum / self.num

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
            >>> metric = ComplexScaleInvariantSignalNoiseRatio()
            >>> metric.update(torch.rand(1,257,100,2), torch.rand(1,257,100,2))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
            >>> metric = ComplexScaleInvariantSignalNoiseRatio()
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(1,257,100,2), torch.rand(1,257,100,2)))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)
16 changes: 14 additions & 2 deletions src/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@
# limitations under the License.
from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.functional.audio.snr import (
complex_scale_invariant_signal_noise_ratio,
scale_invariant_signal_noise_ratio,
signal_noise_ratio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE

__all__ = []
__all__ = [
"permutation_invariant_training",
"pit_permutate",
"scale_invariant_signal_distortion_ratio",
"signal_distortion_ratio",
"scale_invariant_signal_noise_ratio",
"signal_noise_ratio",
"complex_scale_invariant_signal_noise_ratio",
]

if _PESQ_AVAILABLE:
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
Expand Down
42 changes: 42 additions & 0 deletions src/torchmetrics/functional/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,45 @@ def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
tensor(15.0918)
"""
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=True)


def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    """`Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR).

    Complex inputs are first viewed as real tensors with a trailing dimension of size 2. The last three
    dimensions ``(frequency, time, 2)`` are then flattened into a single one and the result is handed to
    ``scale_invariant_signal_distortion_ratio``.

    Args:
        preds: real float tensor with shape ``(..., frequency, time, 2)`` or complex tensor with shape
            ``(..., frequency, time)``
        target: real float tensor with shape ``(..., frequency, time, 2)`` or complex tensor with shape
            ``(..., frequency, time)``
        zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics

    Returns:
        Float tensor with shape ``(...,)`` of C-SI-SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` or ``target`` does not have the shape ``(..., frequency, time, 2)`` (after being
            converted to real if it is complex).
        RuntimeError:
            If ``preds`` and ``target`` do not have the same shape.

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn((1,257,100,2))
        >>> target = torch.randn((1,257,100,2))
        >>> complex_scale_invariant_signal_noise_ratio(preds, target)
        tensor([-63.4849])
    """
    if preds.is_complex():
        preds = torch.view_as_real(preds)
    if target.is_complex():
        target = torch.view_as_real(target)

    if (preds.ndim < 3 or preds.shape[-1] != 2) or (target.ndim < 3 or target.shape[-1] != 2):
        # The second literal must be an f-string as well, otherwise the placeholders are printed verbatim.
        raise RuntimeError(
            "Predictions and targets are expected to have the shape (..., frequency, time, 2),"
            f" but got {preds.shape} and {target.shape}."
        )

    # Collapse (frequency, time, 2) into one axis so every sample becomes a single real-valued vector.
    preds = preds.reshape(*preds.shape[:-3], -1)
    target = target.reshape(*target.shape[:-3], -1)

    return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)
97 changes: 97 additions & 0 deletions tests/unittests/audio/test_c_si_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.io import wavfile
from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

# Container pairing the random predictions with their targets.
Input = namedtuple("Input", "preds target")

# One real-valued spectrogram-like tensor per batch: (num_batches, batch_size, frequency, time, real/imag).
_INPUT_SHAPE = (NUM_BATCHES, BATCH_SIZE, 129, 20, 2)

inputs = Input(preds=torch.rand(*_INPUT_SHAPE), target=torch.rand(*_INPUT_SHAPE))


@pytest.mark.parametrize(
    "preds, target, ref_metric, zero_mean",
    [
        # `ref_metric` is always None: no reference implementation is compared in this class.
        (inputs.preds, inputs.target, None, True),
        (inputs.preds, inputs.target, None, False),
    ],
)
class TestComplexSISNR(MetricTester):
    """Test class for `ComplexScaleInvariantSignalNoiseRatio` metric."""

    atol = 1e-2

    def test_c_si_snr_differentiability(self, preds, target, ref_metric, zero_mean):
        """Test the differentiability of the metric, according to its `is_differentiable` attribute."""
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=ComplexScaleInvariantSignalNoiseRatio,
            metric_functional=complex_scale_invariant_signal_noise_ratio,
            metric_args={"zero_mean": zero_mean},
        )

    # NOTE: the two method names below keep their original `sdr` spelling so existing test IDs stay stable;
    # the metric under test is C-SI-SNR, which is what the xfail reasons now say.
    def test_c_si_sdr_half_cpu(self, preds, target, ref_metric, zero_mean):
        """Test dtype support of the metric on CPU."""
        pytest.xfail("C-SI-SNR metric does not support cpu + half precision")

    def test_c_si_sdr_half_gpu(self, preds, target, ref_metric, zero_mean):
        """Test dtype support of the metric on GPU."""
        pytest.xfail("C-SI-SNR metric does not support gpu + half precision")


def test_on_real_audio():
    """Check the functional metric against fixed reference values on real speech recordings."""

    def _load_stft(path):
        # The sample rate returned by `wavfile.read` is not needed for the STFT.
        _, samples = wavfile.read(path)
        signal = torch.tensor(samples, dtype=torch.float32)
        return torch.stft(signal, n_fft=256, hop_length=128, return_complex=True)

    ref_stft = _load_stft(_SAMPLE_AUDIO_SPEECH)
    deg_stft = _load_stft(_SAMPLE_AUDIO_SPEECH_BAB_DB)

    # (zero_mean, expected C-SI-SNR) pairs, evaluated in the same order as before.
    expected_values = (
        (False, 0.03019072115421295),
        (True, 0.030391741544008255),
    )
    for zero_mean, expected in expected_values:
        v = complex_scale_invariant_signal_noise_ratio(deg_stft, ref_stft, zero_mean=zero_mean)
        assert torch.allclose(v, torch.tensor(expected, dtype=v.dtype), atol=1e-4), v


def test_error_on_incorrect_shape(metric_class=ComplexScaleInvariantSignalNoiseRatio):
    """Test that error is raised on incorrect shapes of input."""
    metric = metric_class()
    # `match` is a regular expression. The parentheses and dots of the literal shape must be escaped,
    # otherwise "(..., frequency, time, 2)*" is an optional group and the shape text is never checked.
    with pytest.raises(
        RuntimeError,
        match=r"Predictions and targets are expected to have the shape \(\.\.\., frequency, time, 2\)",
    ):
        metric(torch.randn(100), torch.randn(50))


def test_error_on_different_shape(metric_class=ComplexScaleInvariantSignalNoiseRatio):
    """Test that error is raised on different shapes of input."""
    metric = metric_class()
    # Both inputs pass the (..., frequency, time, 2) check but differ in the time dimension.
    # `match` is a regex searched in the message, so no trailing wildcard is needed; the former "*"
    # only quantified the final "e".
    with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
        metric(torch.randn(129, 100, 2), torch.randn(129, 101, 2))
7 changes: 7 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchmetrics import MetricCollection
from torchmetrics.aggregation import MaxMetric, MeanMetric, MinMetric, SumMetric
from torchmetrics.audio import (
ComplexScaleInvariantSignalNoiseRatio,
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
ShortTimeObjectiveIntelligibility,
Expand Down Expand Up @@ -283,6 +284,12 @@
ScaleInvariantSignalDistortionRatio, _rand_input, _rand_input, id="scale_invariant_signal_distortion_ratio"
),
pytest.param(SignalNoiseRatio, _rand_input, _rand_input, id="signal_noise_ratio"),
pytest.param(
ComplexScaleInvariantSignalNoiseRatio,
lambda: torch.randn(10, 3, 5, 2),
lambda: torch.randn(10, 3, 5, 2),
id="complex scale invariant signal noise ratio",
),
pytest.param(ScaleInvariantSignalNoiseRatio, _rand_input, _rand_input, id="scale_invariant_signal_noise_ratio"),
pytest.param(
partial(ShortTimeObjectiveIntelligibility, fs=8000, extended=False),
Expand Down

0 comments on commit d4a3932

Please sign in to comment.