Skip to content

Commit

Permalink
support for non fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
migperfer committed May 2, 2021
1 parent 1441508 commit 2673aac
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions Installation/nnAudio/Spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,8 +864,8 @@ def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84,
wsin = torch.tensor(kernel_sin * window)
wcos = torch.tensor(kernel_cos * window)

cqt_kernels_real = torch.tensor(cqt_kernels.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(cqt_kernels.imag.astype(np.float32))
cqt_kernels_real = torch.tensor(cqt_kernels.real)
cqt_kernels_imag = torch.tensor(cqt_kernels.imag)

if trainable_STFT:
wsin = torch.nn.Parameter(wsin, requires_grad=trainable_STFT)
Expand Down Expand Up @@ -1075,8 +1075,8 @@ def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, b
fft_basis = fft(basis)[:,:self.n_fft//2+1] # Convert CQT kenral from time domain to freq domain

# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32))
cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32))
cqt_kernels_real = torch.tensor(fft_basis.real)
cqt_kernels_imag = torch.tensor(fft_basis.imag)

if verbose==True:
print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
Expand Down Expand Up @@ -1621,8 +1621,8 @@ def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, f

self.basis = basis
# These cqt_kernel is already in the frequency domain
cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)
cqt_kernels_real = torch.tensor(basis.real).unsqueeze(1)
cqt_kernels_imag = torch.tensor(basis.imag).unsqueeze(1)

if trainable:
cqt_kernels_real = torch.nn.Parameter(cqt_kernels_real, requires_grad=trainable)
Expand Down Expand Up @@ -2177,7 +2177,7 @@ def __init__(self,fr=2, fs=16000, hop_length=320,
self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True) # it won't be used but will be returned
self.pad_value = ((self.N-window_size))
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
h = scipy.signal.blackmanharris(window_size) # window function for STFT
self.register_buffer('h',torch.tensor(h))

# variables for CFP
Expand All @@ -2194,8 +2194,8 @@ def __init__(self,fr=2, fs=16000, hop_length=320,

# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix))

def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
Expand Down Expand Up @@ -2318,7 +2318,7 @@ def __init__(self,fr=2, fs=16000, hop_length=320,
self.f = fs*np.linspace(0, 0.5, np.round(self.N//2), endpoint=True) # it won't be used but will be returned
self.pad_value = ((self.N-window_size))
# Create window function, always blackmanharris?
h = scipy.signal.blackmanharris(window_size).astype(np.float32) # window function for STFT
h = scipy.signal.blackmanharris(window_size) # window function for STFT
self.register_buffer('h',torch.tensor(h))

# variables for CFP
Expand All @@ -2335,8 +2335,8 @@ def __init__(self,fr=2, fs=16000, hop_length=320,

# filters for the final step
freq2logfreq_matrix, quef2logfreq_matrix = self.create_logfreq_matrix(self.f, self.q, fr, fc, tc, NumPerOct, fs)
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix.astype(np.float32)))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix.astype(np.float32)))
self.register_buffer('freq2logfreq_matrix',torch.tensor(freq2logfreq_matrix))
self.register_buffer('quef2logfreq_matrix',torch.tensor(quef2logfreq_matrix))

def _CFP(self, spec):
spec = torch.relu(spec).pow(self.g[0])
Expand Down

0 comments on commit 2673aac

Please sign in to comment.