Skip to content

Commit

Permalink
Spectrogram.py divided into multiple files inside nnAudio.features
Browse files Browse the repository at this point in the history
  • Loading branch information
migperfer committed Oct 30, 2021
1 parent 1c9d082 commit 7398c9e
Show file tree
Hide file tree
Showing 9 changed files with 2,441 additions and 2,685 deletions.
2,688 changes: 4 additions & 2,684 deletions Installation/nnAudio/Spectrogram.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions Installation/nnAudio/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Module containing all the spectrogram classes
"""

# 0.2.0
from ..librosa_functions import *
from ..utils import *
from .cfp import *
from .cqt import *
from .gammatone import *
from .griffin_lim import *
from .mel import *
from .stft import *
396 changes: 396 additions & 0 deletions Installation/nnAudio/features/cfp.py

Large diffs are not rendered by default.

992 changes: 992 additions & 0 deletions Installation/nnAudio/features/cqt.py

Large diffs are not rendered by default.

120 changes: 120 additions & 0 deletions Installation/nnAudio/features/gammatone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch.nn as nn
import torch
import numpy as np
from time import time
from ..utils import *


class Gammatonegram(nn.Module):
    """
    Calculate the Gammatonegram of the input signal.

    Input signal should be in either of the following shapes. 1. ``(len_audio)``,
    2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be
    inferred automatically if the input follows these 3 shapes. This class inherits from
    ``nn.Module``, therefore, the usage is the same as ``nn.Module``.

    Parameters
    ----------
    sr : int
        The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and
        ``fmax``. Setting the correct sampling rate is very important for calculating the correct
        frequency. Default value is 44100.
    n_fft : int
        The window size for the STFT. Default value is 2048.
    n_bins : int
        The number of Gammatone filter banks. The filter banks map the ``n_fft`` bins to
        Gammatone bins. Default value is 64.
    hop_length : int
        The hop (or stride) size. Default value is 512.
    window : str
        The windowing function for the STFT. It is forwarded to ``create_fourier_kernels``.
        The default value is 'hann'.
    center : bool
        Putting the STFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the STFT kernel, if ``True``, the time index is the center of
        the STFT kernel. Default value is ``True``.
    pad_mode : str
        The padding method used when ``center=True``. Either 'constant' or 'reflect'.
        Default value is 'reflect'.
    power : float
        Exponent applied to the STFT magnitude before the Gammatone filter bank is applied.
        Default value is 2.0.
    htk : bool
        Accepted for interface compatibility; it is not used by this layer.
    fmin : float
        The starting frequency for the lowest Gammatone filter bank. Default value is 20.0.
    fmax : float
        The ending frequency for the highest Gammatone filter bank. It is forwarded unchanged
        (including the default ``None``) to the ``gammatone`` helper.
    norm : int
        Accepted for interface compatibility; it is not used by this layer.
    trainable_bins : bool
        Determine if the Gammatone filter banks are trainable or not. If ``True``, the filter
        banks are registered as a ``nn.Parameter`` and will be updated during model training.
        Default value is ``False``.
    trainable_STFT : bool
        Determine if the STFT kernels are trainable or not. If ``True``, the STFT kernels are
        registered as ``nn.Parameter`` and will be updated during model training.
        Default value is ``False``.
    verbose : bool
        If ``True``, it shows layer information. If ``False``, it suppresses all prints.

    Returns
    -------
    spectrogram : torch.tensor
        It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins, time_steps)``.

    Examples
    --------
    >>> from nnAudio.features import Gammatonegram
    >>> spec_layer = Gammatonegram()
    >>> specs = spec_layer(x)
    """

    def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True,
                 pad_mode='reflect', power=2.0, htk=False, fmin=20.0, fmax=None, norm=1,
                 trainable_bins=False, trainable_STFT=False, verbose=True):
        super(Gammatonegram, self).__init__()
        self.stride = hop_length
        self.center = center
        self.pad_mode = pad_mode
        self.n_fft = n_fft
        self.power = power

        # Create filter windows for stft
        start = time()
        wsin, wcos, self.bins2freq, _, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window,
                                                                  freq_scale='no', sr=sr)

        wsin = torch.tensor(wsin, dtype=torch.float)
        wcos = torch.tensor(wcos, dtype=torch.float)

        # Parameters when trainable, buffers otherwise, so the kernels always follow
        # the module across devices and are compatible with nn.DataParallel.
        if trainable_STFT:
            wsin = nn.Parameter(wsin, requires_grad=trainable_STFT)
            wcos = nn.Parameter(wcos, requires_grad=trainable_STFT)
            self.register_parameter('wsin', wsin)
            self.register_parameter('wcos', wcos)
        else:
            self.register_buffer('wsin', wsin)
            self.register_buffer('wcos', wcos)

        # Capture the STFT timing now: ``start`` is reused for the Gammatone timing below.
        stft_time = time() - start

        # Creating kernel for Gammatone spectrogram
        start = time()
        gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
        gammatone_basis = torch.tensor(gammatone_basis)
        gammatone_time = time() - start

        if verbose:
            print("STFT filter created, time used = {:.4f} seconds".format(stft_time))
            print("Gammatone filter created, time used = {:.4f} seconds".format(gammatone_time))

        if trainable_bins:
            gammatone_basis = nn.Parameter(gammatone_basis, requires_grad=trainable_bins)
            self.register_parameter('gammatone_basis', gammatone_basis)
        else:
            self.register_buffer('gammatone_basis', gammatone_basis)

    def forward(self, x):
        """
        Convert a batch of waveforms to Gammatonegrams.

        Parameters
        ----------
        x : torch tensor
            Input signal; it is reshaped by ``broadcast_dim`` before processing.

        Raises
        ------
        ValueError
            If ``center`` is ``True`` and ``pad_mode`` is neither 'constant' nor 'reflect'.
        """
        x = broadcast_dim(x)
        if self.center:
            if self.pad_mode == 'constant':
                padding = nn.ConstantPad1d(self.n_fft // 2, 0)
            elif self.pad_mode == 'reflect':
                padding = nn.ReflectionPad1d(self.n_fft // 2)
            else:
                raise ValueError(
                    "Unsupported pad_mode {!r}; expected 'constant' or 'reflect'".format(self.pad_mode))

            x = padding(x)

        # Doing STFT by using conv1d: magnitude from the sine/cosine kernels, raised to ``power``
        spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2)
                          + conv1d(x, self.wcos, stride=self.stride).pow(2)) ** self.power

        # Project the linear-frequency bins onto the Gammatone filter bank
        gammatonespec = torch.matmul(self.gammatone_basis, spec)
        return gammatonespec
134 changes: 134 additions & 0 deletions Installation/nnAudio/features/griffin_lim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import torch.nn as nn
import torch
import numpy as np
from time import time
from ..utils import *


class Griffin_Lim(nn.Module):
    """
    Converting Magnitude spectrograms back to waveforms based on the "fast Griffin-Lim"[1].
    This Griffin Lim is a direct clone from librosa.griffinlim.

    [1] Perraudin, N., Balazs, P., & Søndergaard, P. L. “A fast Griffin-Lim algorithm,”
    IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4), Oct. 2013.

    Parameters
    ----------
    n_fft : int
        The FFT size used by the STFT/iSTFT.
    n_iter : int
        The number of iterations for Griffin-Lim. The default value is ``32``
    hop_length : int
        The hop (or stride) size. Default value is ``None`` which is equivalent to ``n_fft//4``.
        Please make sure the value is the same as the forward STFT.
    win_length : int
        The window size. Default value is ``None`` which is equivalent to ``n_fft``.
        Please make sure the value is the same as the forward STFT.
    window : str
        The windowing function for iSTFT. It uses ``scipy.signal.get_window``, please refer to
        scipy documentation for possible windowing functions. The default value is 'hann'.
        Please make sure the value is the same as the forward STFT.
    center : bool
        Putting the iSTFT kernel at the center of the time-step or not. If ``False``, the time
        index is the beginning of the iSTFT kernel, if ``True``, the time index is the center of
        the iSTFT kernel. Default value is ``True``.
        Please make sure the value is the same as the forward STFT.
    pad_mode : str
        The padding method passed to ``torch.stft`` when re-analysing the estimated waveform.
        Default value is 'reflect'.
    momentum : float
        The momentum for the update rule. The default value is ``0.99``.
    device : str
        Choose which device to initialize this layer. Default value is 'cpu'
    """

    def __init__(self,
                 n_fft,
                 n_iter=32,
                 hop_length=None,
                 win_length=None,
                 window='hann',
                 center=True,
                 pad_mode='reflect',
                 momentum=0.99,
                 device='cpu'):
        super().__init__()

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.center = center
        self.pad_mode = pad_mode
        self.momentum = momentum
        self.device = device
        self.win_length = n_fft if win_length is None else win_length
        self.hop_length = n_fft // 4 if hop_length is None else hop_length

        # Creating window function for stft and istft later
        self.w = torch.tensor(get_window(window,
                                         int(self.win_length),
                                         fftbins=True),
                              device=device).float()

    def forward(self, S):
        """
        Convert a batch of magnitude spectrograms to waveforms.

        Parameters
        ----------
        S : torch tensor
            Spectrogram of the shape ``(batch, n_fft//2+1, timesteps)``

        Returns
        -------
        torch tensor
            The waveforms produced by the final ``torch.istft`` call.
        """

        assert S.dim() == 3, "Please make sure your input is in the shape of (batch, freq_bins, timesteps)"

        # Work on the device of the input so the layer also works when the input
        # lives on a different device than the one chosen at construction time.
        window = self.w.to(S.device)

        # Initializing Random Phase as unit-magnitude complex numbers.
        # Complex tensors are used because the real-valued ``(..., 2)`` layout is
        # deprecated for torch.stft / torch.istft.
        rand_phase = torch.randn(*S.shape, device=S.device)
        angles = torch.complex(torch.cos(2 * np.pi * rand_phase),
                               torch.sin(2 * np.pi * rand_phase))

        # Initializing the rebuilt spectrogram
        rebuilt = torch.zeros_like(angles)

        # Weight of the previous estimate in the accelerated update rule
        alpha = self.momentum / (1 + self.momentum)

        for _ in range(self.n_iter):
            tprev = rebuilt  # Saving previous rebuilt spectrogram

            # spec2wav conversion
            inverse = torch.istft(S * angles,
                                  self.n_fft,
                                  self.hop_length,
                                  win_length=self.win_length,
                                  window=window,
                                  center=self.center)
            # wav2spec conversion; ``center`` must match the iSTFT above, otherwise
            # the number of frames differs from ``S`` when ``center=False``.
            rebuilt = torch.stft(inverse,
                                 self.n_fft,
                                 self.hop_length,
                                 win_length=self.win_length,
                                 window=window,
                                 center=self.center,
                                 pad_mode=self.pad_mode,
                                 return_complex=True)

            # Phase update rule
            angles = rebuilt - alpha * tprev

            # Phase normalization (epsilon avoids division by zero)
            angles = angles / (angles.abs() + 1e-16)

        # Using the final phase to reconstruct the waveforms
        inverse = torch.istft(S * angles,
                              self.n_fft,
                              self.hop_length,
                              win_length=self.win_length,
                              window=window,
                              center=self.center)
        return inverse
Loading

0 comments on commit 7398c9e

Please sign in to comment.