Paths, comments and README update. Addition of test file.
S.I. Mimilakis committed Dec 5, 2017
1 parent 499471c commit 88a7644
Showing 7 changed files with 114 additions and 57 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -14,6 +14,7 @@ Listening Examples : https://js-mim.github.io/mss_pytorch/
- TorchVision : torchvision==0.1.9
- Other : wave (used for wav file reading), pyglet (used only for audio playback), pickle (for storing some results)
- Trained Models : https://doi.org/10.5281/zenodo.1064805 [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1064805.svg)](https://doi.org/10.5281/zenodo.1064805)
Download and place them under "results/results_inference/" (a quick sanity-check sketch follows this list).
- MIR_Eval : mir_eval=='0.4' (This is used only for unofficial cross-validation. For the reported evaluation please refer to: https://github.com/faroit/dsdtools)


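As a quick sanity check after downloading, the following sketch (not part of the repository) verifies that the archive was extracted to the expected location; the file names are the ones loaded by processes_scripts/main_script.py in this commit.

import os

# Pre-trained weights expected by processes_scripts/main_script.py
expected = ['torch_sps_encoder.pytorch', 'torch_sps_decoder.pytorch',
            'torch_sps_sp_decoder.pytorch', 'torch_sps_se.pytorch']

missing = [f for f in expected
           if not os.path.isfile(os.path.join('results', 'results_inference', f))]
print('All pre-trained models found.' if not missing
      else 'Missing model files: ' + ', '.join(missing))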
6 changes: 4 additions & 2 deletions helpers/io_methods.py
@@ -35,7 +35,7 @@ class AudioIO:
IO.AudioIO.sound(x,fs)
"""
# Normalisation parameters for wavreading and writing
# Normalisation parameters for wavreading and writing
normFact = {'int8' : (2**7) -1,
'int16': (2**15)-1,
'int24': (2**23)-1,
@@ -416,4 +416,6 @@ def stop():

# Listen to stereo processed
AudioIO.sound(x2*g,fs)
AudioIO.audioWrite(x2, fs, 16, 'myNewWavFile.wav', 'wav')
AudioIO.audioWrite(x2, fs, 16, 'myNewWavFile.wav', 'wav')

# EOF
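For reference, a hedged usage sketch of the AudioIO helpers touched above; the call signatures follow the docstring example and the calls added in this commit, and the file path is simply the test file introduced here.

from helpers.io_methods import AudioIO as Io

# Read a mono wav file (as done in helpers/nnet_helpers.py)
x, fs = Io.wavRead('results/test_files/test.wav', mono=True)

# Play it back (as in the AudioIO docstring example)
Io.sound(x, fs)

# Write it back out as a 16-bit wav file (as in the example at the end of io_methods.py)
Io.audioWrite(x, fs, 16, 'myNewWavFile.wav', 'wav')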
72 changes: 36 additions & 36 deletions helpers/masking_methods.py
@@ -3,11 +3,12 @@
__copyright__ = 'MacSeNet'

import numpy as np
from scipy.fftpack import fft, ifft
from tf_methods import TimeFrequencyDecomposition as TF
from scipy.fftpack import fft


class FrequencyMasking:
"""Class containing various time-frequency masking methods, for processing Time-Frequency representations.
"""Class containing various time-frequency masking methods,
for processing Time-Frequency representations.
"""

def __init__(self, mX, sTarget, nResidual, psTarget = [], pnResidual = [], alpha = 1.2, method = 'Wiener'):
@@ -28,68 +29,66 @@ def __init__(self, mX, sTarget, nResidual, psTarget = [], pnResidual = [], alpha
self._amountiter = 0

def __call__(self, reverse = False):

if (self._method == 'Phase'):
if self._method == 'Phase':
if not self._pTarget.size or not self._pTarget.size:
raise ValueError('Phase-sensitive masking cannot be performed without phase information.')
else:
FrequencyMasking.phaseSensitive(self)
if not(reverse) :
if not reverse :
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'IRM'):
elif self._method == 'IRM':
FrequencyMasking.IRM(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'IAM'):
elif self._method == 'IAM':
FrequencyMasking.IAM(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'IBM'):
elif self._method == 'IBM':
FrequencyMasking.IBM(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'UBBM'):
elif self._method == 'UBBM':
FrequencyMasking.UBBM(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)


elif (self._method == 'Wiener'):
elif self._method == 'Wiener':
FrequencyMasking.Wiener(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'alphaWiener'):
elif self._method == 'alphaWiener':
FrequencyMasking.alphaHarmonizableProcess(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'expMask'):
elif self._method == 'expMask':
FrequencyMasking.ExpM(self)
if not(reverse) :
if not reverse:
FrequencyMasking.applyMask(self)
else :
FrequencyMasking.applyReverseMask(self)

elif (self._method == 'MWF'):
elif self._method == 'MWF':
print('Multichannel Wiener Filtering')
FrequencyMasking.MWF(self)
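
From the dispatch above, a minimal usage sketch with toy arrays; the assumption that calling the instance returns the masked magnitude estimate (and the complementary estimate when reverse=True) is not visible in this hunk and is marked as such below.

import numpy as np
from helpers.masking_methods import FrequencyMasking

mX = np.random.rand(100, 2049)         # mixture magnitude spectrogram (toy)
sTarget = np.random.rand(100, 2049)    # estimated target (e.g. singing voice) magnitude
nResidual = np.random.rand(100, 2049)  # estimated residual (e.g. background) magnitude

mask = FrequencyMasking(mX, sTarget, nResidual, alpha=1.2, method='alphaWiener')
v_hat = mask()              # assumed: masked mixture, i.e. the target estimate
b_hat = mask(reverse=True)  # assumed: the complementary (residual) estimate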

@@ -427,29 +426,30 @@ def applyReverseMask(self):
def _IS(self, Xhat):
""" Compute the Itakura-Saito distance between the observed magnitude spectrum
and the estimated one.
Args:
mX : (2D ndarray) Input Magnitude Spectrogram
Xhat : (2D ndarray) Estimated Magnitude Spectrogram
Returns:
dis : (float) Average Itakura-Saito distance
"""
Args:
mX : (2D ndarray) Input Magnitude Spectrogram
Xhat : (2D ndarray) Estimated Magnitude Spectrogram
Returns:
dis : (float) Average Itakura-Saito distance
"""
r1 = (np.abs(self._mX)**self._alpha + self._eps) / (np.abs(Xhat) + self._eps)
lg = np.log((np.abs(self._mX)**self._alpha + self._eps)) - np.log((np.abs(Xhat) + self._eps))
return np.mean(r1 - lg - 1.)

def _dIS(self, Xhat):
""" Computation of the first derivative of Itakura-Saito function. As appears in :
Cedric Fevotte and Jerome Idier, "Algorithms for nonnegative matrix factorization
with the beta-divergence", in CoRR, vol. abs/1010.1763, 2010.
Args:
mX : (2D ndarray) Input Magnitude Spectrogram
Xhat : (2D ndarray) Estimated Magnitude Spectrogram
Returns:
dis' : (float) Average of first derivative of Itakura-Saito distance.
"""
Cedric Fevotte and Jerome Idier, "Algorithms for nonnegative matrix factorization
with the beta-divergence", in CoRR, vol. abs/1010.1763, 2010.
Args:
mX : (2D ndarray) Input Magnitude Spectrogram
Xhat : (2D ndarray) Estimated Magnitude Spectrogram
Returns:
dis' : (float) Average of first derivative of Itakura-Saito distance.
"""
dis = (np.abs(Xhat + self._eps) ** (-2.)) * (np.abs(Xhat) - np.abs(self._mX)**self._alpha)
return (np.mean(dis))


if __name__ == "__main__":

# Small test
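The _IS helper above computes an (alpha-power) Itakura-Saito distance; a self-contained numerical sketch of the same formula, with eps and alpha taken from the defaults that appear elsewhere in this diff:

import numpy as np

eps = np.finfo(np.float32).tiny   # same machine tiny as defined in helpers/tf_methods.py
alpha = 1.2                       # default alpha of FrequencyMasking

mX = np.abs(np.random.randn(4, 8))    # observed magnitude spectrogram (toy)
Xhat = np.abs(np.random.randn(4, 8))  # estimated magnitude spectrogram (toy)

r = (mX ** alpha + eps) / (Xhat + eps)
d_is = np.mean(r - np.log(r) - 1.)    # log(r) equals the log-difference used in _IS
print(d_is)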
69 changes: 60 additions & 9 deletions helpers/nnet_helpers.py
@@ -3,21 +3,19 @@
__copyright__ = 'MacSeNet'

# imports
from .io_methods import AudioIO as Io
from .masking_methods import FrequencyMasking as Fm
from helpers.io_methods import AudioIO as Io
from helpers.masking_methods import FrequencyMasking as Fm
from mir_eval import separation as bss_eval
from numpy.lib import stride_tricks
from helpers import iterative_inference as it_infer
from losses import loss_functions as loss
import matplotlib.pyplot as plt
import pickle as pickle
import tf_methods as tf
import numpy as np
import os

# definitions
mixtures_path = '/home/avdata/audio/own/dsd100/DSD100/Mixtures/'
sources_path = '/home/avdata/audio/own/dsd100/DSD100/Sources/'
mixtures_path = 'DSD100/Mixtures/'
sources_path = 'DSD100/Sources/'
keywords = ['bass.wav', 'drums.wav', 'other.wav', 'vocals.wav', 'mixture.wav']
foldersList = ['Dev', 'Test']
save_path = 'results/GRU_sskip_filt/inference_m3_i10plus/'
@@ -31,6 +29,26 @@


def prepare_overlap_sequences(ms, vs, bk, l_size, o_lap, bsize):
"""
Method to prepare overlapping sequences of the given magnitude spectra.
Args:
ms : (2D Array) Mixture magnitude spectra (Time frames times Frequency sub-bands).
vs : (2D Array) Singing voice magnitude spectra (Time frames times Frequency sub-bands).
bk : (2D Array) Background magnitude spectra (Time frames times Frequency sub-bands).
l_size : (int) Length of the time-sequence.
o_lap : (int) Overlap between spectrogram time-sequences
(to recover the missing information from the context information).
bsize : (int) Batch size.
Returns:
ms : (3D Array) Mixture magnitude spectra training data
reshaped into overlapping sequences.
vs : (3D Array) Singing voice magnitude spectra training data
reshaped into overlapping sequences.
bk : (3D Array) Background magnitude spectra training data
reshaped into overlapping sequences.
"""
trim_frame = ms.shape[0] % (l_size - o_lap)
trim_frame -= (l_size - o_lap)
trim_frame = np.abs(trim_frame)
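The remainder of prepare_overlap_sequences is cut off here, so as a hedged illustration only, this is the stride pattern the docstring describes, on toy dimensions (sequence length and overlap chosen to match the inference call in main_script.py):

import numpy as np

T, F = 1000, 2049        # time frames x frequency sub-bands (toy)
l_size, o_lap = 60, 20   # sequence length and overlap

ms = np.random.rand(T, F).astype(np.float32)  # stand-in mixture magnitude spectrogram

step = l_size - o_lap
starts = range(0, ms.shape[0] - l_size + 1, step)
sequences = np.stack([ms[s:s + l_size, :] for s in starts])
print(sequences.shape)   # (number of overlapping sequences, l_size, F)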
@@ -68,6 +86,12 @@ def get_data(current_set, set_size, wsz=2049, N=4096, hop=384, T=100, L=20, B=16
Args:
current_set : (int) An integer denoting the current training set.
set_size : (int) The number of files a set has.
wsz : (int) Window size in samples.
N : (int) The FFT size.
hop : (int) Hop size in samples.
T : (int) Length of the time-sequence.
L : (int) Number of context frames from the time-sequence.
B : (int) Batch size.
Returns:
ms_train : (3D Array) Mixture magnitude training data, for the current set.
@@ -127,6 +151,20 @@ def get_data(current_set, set_size, wsz=2049, N=4096, hop=384, T=100, L=20, B=16


def test_eval(nnet, B, T, N, L, wsz, hop):
"""
Method to test the model on the test data. Writes the outcomes in ".wav" format and
stores them under the defined results path. Optionally, it performs BSS-Eval using the
mir_eval Python toolbox (used only for comparison to the BSS-Eval MATLAB implementation).
The evaluation results are stored under the defined save path.
Args:
nnet : (List) A list containing the Pytorch modules of the skip-filtering model.
B : (int) Batch size.
T : (int) Length of the time-sequence.
N : (int) The FFT size.
L : (int) Number of context frames from the time-sequence.
wsz : (int) Window size in samples.
hop : (int) Hop size in samples.
"""
nnet[0].eval()
nnet[1].eval()
nnet[2].eval()
@@ -234,14 +272,27 @@ def my_res(mx, vx, L, wsz):


def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
"""
Method to test the model on some data. Writes the outcomes in ".wav" format and
stores them under the defined results path.
Args:
nnet : (List) A list containing the Pytorch modules of the skip-filtering model.
seqlen : (int) Length of the time-sequence.
olap : (int) Overlap between spectrogram time-sequences
(to recover the missing information from the context information).
wsz : (int) Window size in samples.
N : (int) The FFT size.
hop : (int) Hop size in samples.
B : (int) Batch size.
"""
nnet[0].eval()
nnet[1].eval()
nnet[2].eval()
nnet[3].eval()
L = olap/2
seg = 2
w = tf.hamming(wsz, True)
x, fs = Io.wavRead('/home/mis/Documents/Python/Projects/SourceSeparation/testFiles/supreme_test3.wav', mono=True)
x, fs = Io.wavRead('results/test_files/test.wav', mono=True)

mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)

@@ -282,8 +333,8 @@ def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):

x = x[olap/2 * hop:]

Io.audioWrite(y_recb, 44100, 16, 'results/test_sv.mp3', 'mp3')
Io.audioWrite(x[:len(y_recb)], 44100, 16, 'results/test_mix.mp3', 'mp3')
Io.wavWrite(y_recb, 44100, 16, 'results/test_files/test_sv.wav')
Io.wavWrite(x[:len(y_recb)], 44100, 16, 'results/test_files/test_mix.wav')

return None

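For clarity, the positional call to test_nnet at the bottom of processes_scripts/main_script.py maps onto the parameters documented above as follows (keyword-annotated form only; sfiltnet is the list of modules returned by main()):

from helpers import nnet_helpers

# Equivalent to: nnet_helpers.test_nnet(sfiltnet, 60, 10*2, 2049, 4096, 384, 16)
nnet_helpers.test_nnet(sfiltnet,
                       seqlen=60,  # length of the time-sequence
                       olap=20,    # overlap between spectrogram time-sequences
                       wsz=2049,   # window size in samples
                       N=4096,     # FFT size
                       hop=384,    # hop size in samples
                       B=16)       # batch size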
4 changes: 4 additions & 0 deletions helpers/tf_methods.py
@@ -2,19 +2,23 @@
__author__ = 'S.I. Mimilakis'
__copyright__ = 'MacSeNet'

# imports
import math
import numpy as np
from scipy.fftpack import fft, ifft
from scipy.signal import hamming

# definition
eps = np.finfo(np.float32).tiny


class TimeFrequencyDecomposition:
""" A Class that performs time-frequency decompositions by means of a
Discrete Fourier Transform, using Fast Fourier Transform algorithm
by SciPy, MDCT with modified type IV bases, PQMF,
and Fractional Fast Fourier Transform.
"""

@staticmethod
def DFT(x, w, N):
""" Discrete Fourier Transformation(Analysis) of a given real input signal
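As a rough, self-contained sketch of the analysis step this class provides (the zero-padding/centering details of the actual DFT method are not shown in this hunk), using the same SciPy routines imported above:

import numpy as np
from scipy.fftpack import fft
from scipy.signal import hamming

N = 4096                      # FFT size, as used throughout the repository
wsz = 2049                    # window size in samples
w = hamming(wsz)              # analysis window
x = np.random.randn(wsz)      # one frame of a real input signal (toy)

X = fft(x * w, n=N)           # windowed DFT of the frame
mX = np.abs(X[:N // 2 + 1])   # magnitude spectrum (positive frequencies)
pX = np.angle(X[:N // 2 + 1]) # phase spectrum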
19 changes: 9 additions & 10 deletions processes_scripts/main_script.py
@@ -9,7 +9,6 @@
from tqdm import tqdm
from helpers import visualize, nnet_helpers
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau as RedLR
from modules import cls_sparse_skip_filt as s_s_net
from losses import loss_functions
from helpers import iterative_inference as it_infer
@@ -44,7 +43,7 @@ def main(training, apply_sparsity):
mask_loss_threshold = 1.5 # Scalar indicating the threshold for the time-frequency masking module
good_loss_threshold = 0.25 # Scalar indicating the threshold for the source enhancement module

# Data
# Data (predefined by the DSD100 dataset and the non-instrumental/non-bleeding stems of MedleyDB)
totTrainFiles = 116
numFilesPerTr = 4

@@ -164,26 +163,26 @@ def main(training, apply_sparsity):
else:
print('------- Loading pre-trained model -------')
print('------- Loading inference weights -------')
encoder.load_state_dict(torch.load('results/results_inference/torch_sps_encoder_40_m3_i10.pytorch'))
decoder.load_state_dict(torch.load('results/results_inference/torch_sps_decoder_40_m3_i10.pytorch'))
sp_decoder.load_state_dict(torch.load('results/results_inference/torch_sps_sp_decoder_40_m3_i10.pytorch'))
source_enhancement.load_state_dict(torch.load('results/results_inference/torch_sps_se_40_m3_i10.pytorch'))
encoder.load_state_dict(torch.load('results/results_inference/torch_sps_encoder.pytorch'))
decoder.load_state_dict(torch.load('results/results_inference/torch_sps_decoder.pytorch'))
sp_decoder.load_state_dict(torch.load('results/results_inference/torch_sps_sp_decoder.pytorch'))
source_enhancement.load_state_dict(torch.load('results/results_inference/torch_sps_se.pytorch'))
print('------------- Done -------------')

return encoder, decoder, sp_decoder, source_enhancement


if __name__ == '__main__':
training = True # Whether to train or test the trained model (requires the optimized parameters)
training = False # Whether to train or test the trained model (requires the optimized parameters)
apply_sparsity = True # Whether to apply a sparse penalty or not

sfiltnet = main(training, apply_sparsity)

#print('------------- BSS-Eval -------------')
#nnet_helpers.test_eval(sfiltnet, 16, 60, 4096, 10, 2049, 384)
#print('------------- Done -------------')
#print('------------- DNN-Test -------------')
#nnet_helpers.test_nnet(sfiltnet, 60, 10*2, 2049, 4096, 384, 16)
#print('------------- Done -------------')
print('------------- DNN-Test -------------')
nnet_helpers.test_nnet(sfiltnet, 60, 10*2, 2049, 4096, 384, 16)
print('------------- Done -------------')

# EOF
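For reference, the BSS-Eval call that remains commented out above maps onto the documented signature of nnet_helpers.test_eval as follows (keyword-annotated form only, not meant to run on its own):

# Equivalent to the commented line: nnet_helpers.test_eval(sfiltnet, 16, 60, 4096, 10, 2049, 384)
nnet_helpers.test_eval(sfiltnet,
                       B=16,      # batch size
                       T=60,      # length of the time-sequence
                       N=4096,    # FFT size
                       L=10,      # number of context frames
                       wsz=2049,  # window size in samples
                       hop=384)   # hop size in samples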
Binary file added results/test_files/test.wav
