forked from KinWaiCheuk/nnAudio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Spectrogram.py divided into multiple files inside nnAudio.features
- Loading branch information
Showing
9 changed files
with
2,441 additions
and
2,685 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
Module containing all the spectrogram classes | ||
""" | ||
|
||
# 0.2.0 | ||
from ..librosa_functions import * | ||
from ..utils import * | ||
from .cfp import * | ||
from .cqt import * | ||
from .gammatone import * | ||
from .griffin_lim import * | ||
from .mel import * | ||
from .stft import * |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import torch.nn as nn | ||
import torch | ||
import numpy as np | ||
from time import time | ||
from ..utils import * | ||
|
||
|
||
class Gammatonegram(nn.Module):
    """
    Calculate the Gammatonegram of the input signal.

    Input signal should be in either of the following shapes. 1. ``(len_audio)``,
    2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be
    inferred automatically if the input follows these 3 shapes. This class inherits from
    ``nn.Module``, therefore, the usage is the same as ``nn.Module``.

    Parameters
    ----------
    sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and
        ``fmax``. Setting the correct sampling rate is very important for calculating the
        correct frequency. Default value is 44100.
    n_fft : int
        The window size for the STFT. Default value is 2048.
    n_bins : int
        The number of Gammatone filter banks. The filter banks map the ``n_fft`` STFT bins to
        Gammatone bins. Default value is 64.
    hop_length : int
        The hop (or stride) size. Default value is 512.
    window : str
        The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to
        the scipy documentation for possible windowing functions. The default value is 'hann'.
    center : bool
        Putting the STFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the STFT kernel, if ``True``, the time index is the center of
        the STFT kernel. Default value is ``True``.
    pad_mode : str
        The padding method used when ``center`` is ``True``. Either ``'constant'`` (zero
        padding) or ``'reflect'``. Default value is 'reflect'.
    power : float
        Exponent applied to the STFT magnitude before the Gammatone filter banks are applied.
        Default value is 2.0.
    htk : bool
        Accepted for interface compatibility; it is not used by this class.
    fmin : float
        The starting frequency for the lowest Gammatone filter bank. Default value is 20.0.
    fmax : float
        The ending frequency for the highest Gammatone filter bank. Default value is ``None``.
    norm : int
        Accepted for interface compatibility; it is not used by this class.
    trainable_bins : bool
        Determine if the Gammatone filter banks are trainable or not. If ``True``, the filter
        banks are registered as an ``nn.Parameter`` and will be updated during model training.
        Default value is ``False``.
    trainable_STFT : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the STFT kernels are
        registered as ``nn.Parameter`` and will be updated during model training.
        Default value is ``False``.
    verbose : bool
        If ``True``, it shows layer information. If ``False``, it suppresses all prints.

    Returns
    -------
    spectrogram : torch.tensor
        It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.

    Examples
    --------
    >>> spec_layer = Spectrogram.Gammatonegram()
    >>> specs = spec_layer(x)
    """

    def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True,
                 pad_mode='reflect', power=2.0, htk=False, fmin=20.0, fmax=None, norm=1,
                 trainable_bins=False, trainable_STFT=False, verbose=True):
        super(Gammatonegram, self).__init__()
        self.stride = hop_length
        self.center = center
        self.pad_mode = pad_mode
        self.n_fft = n_fft
        self.power = power

        # Create the Fourier kernels (filter windows) for the STFT.
        start = time()
        wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window,
                                                                  freq_scale='no', sr=sr)
        wsin = torch.tensor(wsin, dtype=torch.float)
        wcos = torch.tensor(wcos, dtype=torch.float)
        # Captured before ``start`` is reused so each message reports its own duration.
        stft_elapsed = time() - start

        # Parameters are trainable; buffers are not, but both follow the module across
        # devices (needed e.g. for nn.DataParallel).
        if trainable_STFT:
            self.register_parameter('wsin', nn.Parameter(wsin, requires_grad=trainable_STFT))
            self.register_parameter('wcos', nn.Parameter(wcos, requires_grad=trainable_STFT))
        else:
            self.register_buffer('wsin', wsin)
            self.register_buffer('wcos', wcos)

        # Create the kernel for the Gammatone spectrogram.
        start = time()
        gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
        # Cast to float32 so the basis has the same dtype as the STFT kernels above;
        # ``forward`` multiplies it with a spectrogram computed from those kernels.
        gammatone_basis = torch.tensor(gammatone_basis, dtype=torch.float)
        gammatone_elapsed = time() - start

        if verbose:
            print("STFT filter created, time used = {:.4f} seconds".format(stft_elapsed))
            print("Gammatone filter created, time used = {:.4f} seconds".format(gammatone_elapsed))

        if trainable_bins:
            self.register_parameter('gammatone_basis',
                                    nn.Parameter(gammatone_basis, requires_grad=trainable_bins))
        else:
            self.register_buffer('gammatone_basis', gammatone_basis)

    def _get_padding(self):
        """Return the padding layer implied by ``center``/``pad_mode``, or ``None`` if not centered."""
        if not self.center:
            return None
        if self.pad_mode == 'constant':
            return nn.ConstantPad1d(self.n_fft // 2, 0)
        if self.pad_mode == 'reflect':
            return nn.ReflectionPad1d(self.n_fft // 2)
        raise ValueError(
            "Unsupported pad_mode {!r}; expected 'constant' or 'reflect'".format(self.pad_mode))

    def forward(self, x):
        """
        Convert a batch of waveforms to Gammatonegrams.

        Parameters
        ----------
        x : torch tensor
            Input signal in one of the shapes listed in the class docstring.

        Raises
        ------
        ValueError
            If ``center`` is ``True`` and ``pad_mode`` is neither ``'constant'`` nor ``'reflect'``.
        """
        # Resolve (and validate) the padding first so a bad ``pad_mode`` fails with a clear error
        # instead of an unbound local variable further down.
        padding = self._get_padding()

        x = broadcast_dim(x)
        if padding is not None:
            x = padding(x)

        # STFT via conv1d: magnitude from the sine/cosine responses, raised to ``power``.
        spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2)
                          + conv1d(x, self.wcos, stride=self.stride).pow(2)) ** self.power

        # Project the STFT bins onto the Gammatone filter banks.
        gammatonespec = torch.matmul(self.gammatone_basis, spec)
        return gammatonespec
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import torch.nn as nn | ||
import torch | ||
import numpy as np | ||
from time import time | ||
from ..utils import * | ||
|
||
|
||
class Griffin_Lim(nn.Module):
    """
    Converting Magnitude spectrograms back to waveforms based on the "fast Griffin-Lim"[1].
    This Griffin Lim is a direct clone from librosa.griffinlim.

    [1] Perraudin, N., Balazs, P., & Søndergaard, P. L. “A fast Griffin-Lim algorithm,”
    IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), Oct. 2013.

    Parameters
    ----------
    n_fft : int
        The window size for the FFT.
    n_iter : int
        The number of iterations for Griffin-Lim. The default value is ``32``
    hop_length : int
        The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
        Please make sure the value is the same as the forward STFT.
    win_length : int
        The length of the window function. Default value is ``None`` which is equivalent to
        ``n_fft``. Please make sure the value is the same as the forward STFT.
    window : str
        The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
        scipy documentation for possible windowing functions. The default value is 'hann'.
        Please make sure the value is the same as the forward STFT.
    center : bool
        Putting the iSTFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
        the iSTFT kernel. Default value is ``True``.
        Please make sure the value is the same as the forward STFT.
    pad_mode : str
        The padding method passed to ``torch.stft`` inside the iteration.
        Default value is 'reflect'.
    momentum : float
        The momentum for the update rule. The default value is ``0.99``.
    device : str
        Choose which device to initialize this layer. Default value is 'cpu'
    """

    def __init__(self,
                 n_fft,
                 n_iter=32,
                 hop_length=None,
                 win_length=None,
                 window='hann',
                 center=True,
                 pad_mode='reflect',
                 momentum=0.99,
                 device='cpu'):
        super().__init__()

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.center = center
        self.pad_mode = pad_mode
        self.momentum = momentum
        self.device = device
        self.win_length = n_fft if win_length is None else win_length
        self.hop_length = n_fft // 4 if hop_length is None else hop_length

        # Creating window function for stft and istft later
        self.w = torch.tensor(get_window(window,
                                         int(self.win_length),
                                         fftbins=True),
                              device=device).float()

    @staticmethod
    def _has_complex_stft():
        """Return ``True`` when the installed torch should use the complex stft/istft API."""
        try:
            major, minor = (int(part) for part in torch.__version__.split('+')[0].split('.')[:2])
        except ValueError:
            # Unparseable version string: fall back to feature detection.
            return hasattr(torch, 'view_as_complex')
        # Real-valued (..., 2) spectrograms are deprecated from torch 1.8 onwards.
        return (major, minor) >= (1, 8)

    def _istft(self, spec, window, use_complex):
        """Invert a ``(batch, freq, time, 2)`` real/imag spectrogram to waveforms."""
        if use_complex:
            spec = torch.view_as_complex(spec.contiguous())
        return torch.istft(spec,
                           self.n_fft,
                           self.hop_length,
                           win_length=self.win_length,
                           window=window,
                           center=self.center)

    def _stft(self, wav, window, use_complex):
        """Compute the STFT of ``wav`` and return it with a trailing real/imag dimension of 2."""
        if not use_complex:
            return torch.stft(wav,
                              self.n_fft,
                              self.hop_length,
                              win_length=self.win_length,
                              window=window,
                              center=self.center,
                              pad_mode=self.pad_mode)
        spec = torch.stft(wav,
                          self.n_fft,
                          self.hop_length,
                          win_length=self.win_length,
                          window=window,
                          center=self.center,
                          pad_mode=self.pad_mode,
                          return_complex=True)
        return torch.view_as_real(spec)

    def forward(self, S):
        """
        Convert a batch of magnitude spectrograms to waveforms.

        Parameters
        ----------
        S : torch tensor
            Spectrogram of the shape ``(batch, n_fft//2+1, timesteps)``
        """

        assert S.dim() == 3, "Please make sure your input is in the shape of (batch, freq_bins, timesteps)"

        # Work on the input's device so the layer also works when ``S`` does not live on
        # the device given at construction time.
        device = S.device
        window = self.w.to(device)
        use_complex = self._has_complex_stft()

        # Initializing Random Phase, stored as (real, imag) in the last dimension
        rand_phase = torch.randn(*S.shape, device=device)
        angles = torch.stack((torch.cos(2 * np.pi * rand_phase),
                              torch.sin(2 * np.pi * rand_phase)), dim=-1)

        # Initializing the rebuilt spectrogram
        rebuilt = torch.zeros_like(angles)
        damping = self.momentum / (1 + self.momentum)

        for _ in range(self.n_iter):
            tprev = rebuilt  # Saving previous rebuilt spec

            # spec2wav conversion
            inverse = self._istft(S.unsqueeze(-1) * angles, window, use_complex)
            # wav2spec conversion; uses the same ``center`` as the istft so that the number of
            # frames stays consistent between the two transforms.
            rebuilt = self._stft(inverse, window, use_complex)

            # Phase update rule (a new tensor is created instead of writing into ``angles``)
            angles = rebuilt - damping * tprev

            # Phase normalization; the epsilon avoids division by zero
            angles = angles.div(torch.sqrt(angles.pow(2).sum(-1)).unsqueeze(-1) + 1e-16)

        # Using the final phase to reconstruct the waveforms
        return self._istft(S.unsqueeze(-1) * angles, window, use_complex)
Oops, something went wrong.