forked from yuzhou-git/deep-casa
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Yuzhou Liu
committed
Dec 31, 2019
0 parents
commit 5d57679
Showing
7 changed files
with
1,753 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Deep CASA for talker-independent monaural speaker separation | ||
|
||
## Introduction | ||
|
||
This is Tensorflow implementation of "Divide and conquer: A deep CASA approach to talker-independent monaural speaker separation", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 27, pp. 2092-2102. | ||
|
||
## Contents | ||
|
||
* `./feat/exp_prepare_folder.sh`: prepares folders for experiments. | ||
* `./feat/feat_gen.py`: generates STFT featuers for training, validation and test. | ||
* `./feat/stft.py`: defines STFT and iSTFT. | ||
* `./nn/simul_group.py`: training/validation/test of the simultaneous grouping stage. | ||
* `./nn/seq_group.py`: training/validation/test of the sequential grouping stage. | ||
* `./nn/utility.py`: defines various functions in `simul_group.py` and `seq_group.py`. | ||
|
||
## Experimental setup | ||
|
||
This codebase has been tested on AWS EC2 p3.2xlarge nodes with Deep Learning AMI (Ubuntu 16.04) Version 26.0. | ||
|
||
Follow instructions in turn to set up the environment and run experiments. | ||
|
||
1. Requirements: | ||
* Tensorflow 1.15.0. <br /> | ||
Activate the environment on EC2 : | ||
``` | ||
source activate tensorflow_p27 | ||
``` | ||
* gflags | ||
``` | ||
pip install python-gflags | ||
``` | ||
* Please install other necessary python packages if not using AWS deep Learning AMI (Ubuntu 16.04) Version 26.0. | ||
2. Before running experiments, activate the tensorflow environment on EC2 using: | ||
``` | ||
source activate tensorflow_p27 | ||
``` | ||
3. Generate the WSJ0-2mix dataset using `http://www.merl.com/demos/deep-clustering/create-speaker-mixtures.zip`. Copy the generated files to the EC2 instance. | ||
4. Start feature extraction by running the following command in the main directory: | ||
``` | ||
python feat/feat_gen.py | ||
``` | ||
Thre are two arguments in `feat_gen.py`, `data_folder` and `wav_list_folder`. Change them to where your WSJ0-2mix dataset and file list locate. | ||
5. Train the simultaneous grouping stage using: | ||
``` | ||
TIME_STAMP=train_simul_group | ||
python nn/simul_group.py --time_stamp $TIME_STAMP --is_deploy 0 --batch_size 1 | ||
``` | ||
* Due to utterance-level training and limited GPU memory, `batch_size` can be selected as 1 or 2. | ||
* Change `data_folder` and `wav_list_folder` accordingly. | ||
* You can also change other hyperparameters, e.g., the number of epochs and learning rate, using gflags arguments. | ||
6. Run inference of simultaneous grouping (tt set) using: | ||
``` | ||
RESUME_MODEL=exp/deep_casa_wsj/models/train_simul_group/deep_casa_wsj_model.ckpt_step_1 | ||
python nn/simul_group.py --is_deploy 1 --resume_model $RESUME_MODEL | ||
``` | ||
* `$RESUME_MODEL` is the model to be loaded for inference. Change it accordingly. | ||
* Mixtures, clean references and Dense-UNet estimates will be generated and saved in folder `./exp/deep_casa_wsj/output_tt/files/`. | ||
* Please use your own scripts to generate results in different metrics. | ||
7. Generate temporary .npy file for the next stage (sequential grouping): | ||
``` | ||
RESUME_MODEL=exp/deep_casa_wsj/models/train_simul_group/deep_casa_wsj_model.ckpt_step_1 | ||
python nn/simul_group.py --is_deploy 2 --resume_model $RESUME_MODEL | ||
``` | ||
* Setting `is_deploy` to 2 will generate unorganized estimates by Dense-UNet, and save them as .npy files for the sequential grouping stage. | ||
* tr, cv and tt data are generated in turn, and saved in `./exp/deep_casa_wsj/feat/`. | ||
8. Train the sequential grouping stage using: | ||
``` | ||
TIME_STAMP=train_seq_group | ||
python nn/seq_group.py --time_stamp $TIME_STAMP --is_deploy 0 | ||
``` | ||
Change `data_folder` and `wav_list_folder` accordingly. You can also change other hyperparameters, e.g., the number of epochs and learning rate, using gflags arguments. | ||
9. Run inference of sequential grouping (tt set) using: | ||
``` | ||
RESUME_MODEL=exp/deep_casa_wsj/models/train_seq_group/deep_casa_wsj_model.ckpt_step_1 | ||
python nn/seq_group.py --is_deploy 1 --resume_model $RESUME_MODEL | ||
``` | ||
* `$RESUME_MODEL` is the model to be loaded for inference. Change it accordingly. | ||
* Mixtures, clean references and estimates will be saved in folder `./exp/deep_casa_wsj/output_tt/files/`. | ||
* Please use your own scripts to generate results in different metrics. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
expname=deep_casa_wsj | ||
|
||
basedir=$(pwd) | ||
|
||
exp_dir=$basedir/exp/$expname | ||
|
||
infeat_dir_tr=$basedir/exp/$expname/feat/tr | ||
infeat_dir_cv=$basedir/exp/$expname/feat/cv | ||
infeat_dir_tt=$basedir/exp/$expname/feat/tt | ||
|
||
model_dir=$basedir/exp/$expname/models | ||
|
||
output_tt_files=$basedir/exp/$expname/output_tt/files | ||
|
||
mkdir -p $exp_dir $infeat_dir_tr $infeat_dir_tt $infeat_dir_cv $model_dir $output_tt_files |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import numpy as np | ||
import gflags, os, sys, subprocess | ||
from scipy.io.wavfile import read as wav_read | ||
from stft import stft,istft | ||
|
||
gflags.DEFINE_string('data_folder','/home/ubuntu/data/wsj0_2mix','Path to wsj0-2mix data set') | ||
gflags.DEFINE_string('wav_list_folder','/home/ubuntu/data/wsj0_2mix','Folder that stores wsj0-2mix wav list') | ||
FLAGS = gflags.FLAGS | ||
FLAGS(sys.argv) | ||
|
||
# Define folders | ||
base_folder= os.getcwd() | ||
data_folder = FLAGS.data_folder | ||
wav_list_folder = FLAGS.wav_list_folder | ||
|
||
# Create experiment folder | ||
exp_name='deep_casa_wsj' # the experiment name | ||
exp_folder=base_folder +'/exp/'+ exp_name #the path for the experiment | ||
subprocess.call(base_folder + '/feat/exp_prepare_folder.sh '+ exp_name, shell=True) | ||
|
||
wav_list_prefix = wav_list_folder + '/mix_2_spk_min' | ||
wav_path = data_folder + '/2speakers/wav8k/min/' | ||
feat_path = exp_folder+'/feat' #feature path | ||
|
||
# Generate feature and save as .npy | ||
def get_feat(wav_list_prefix, wav_path, feat_path, task, fftsize=256, hopsize=64): | ||
wav_folders = wav_path + task + '/' | ||
wav_list = wav_list_prefix + '_' +task +'_mix' | ||
output_dir = feat_path + '/' + task + '/' | ||
with open(wav_list, 'r') as f: | ||
for file,line in enumerate(f): | ||
print(task + ' file: ' + str(file+1)) | ||
# Load wav files | ||
line = line.split('\n')[0] | ||
sr,clean_audio_1 = wav_read(wav_folders+'s1/'+line+'.wav') | ||
clean_audio_1 = clean_audio_1.astype('float32')/np.power(2,15) | ||
sr,clean_audio_2 = wav_read(wav_folders+'s2/'+line+'.wav') | ||
clean_audio_2 = clean_audio_2.astype('float32')/np.power(2,15) | ||
sr,mix_audio = wav_read(wav_folders+'mix/'+line+'.wav') | ||
mix_audio = mix_audio.astype('float32')/np.power(2,15) | ||
# STFT | ||
Zxx_1 = stft(clean_audio_1) | ||
Zxx_2 = stft(clean_audio_2) | ||
Zxx_mix = stft(mix_audio) | ||
Zxx_1 = Zxx_1[:,0:(fftsize/2+1)] | ||
Zxx_2 = Zxx_2[:,0:(fftsize/2+1)] | ||
Zxx_mix = Zxx_mix[:,0:(fftsize/2+1)] | ||
# Store real and imaginary STFT of speaker1, speaker2 and mixture | ||
Zxx = np.stack((np.real(Zxx_1).astype('float32'),np.imag(Zxx_1).astype('float32'),np.real(Zxx_2).astype('float32'),np.imag(Zxx_2).astype('float32'),np.real(Zxx_mix).astype('float32'),np.imag(Zxx_mix).astype('float32')),axis=0) | ||
# Save features and targets to npy files | ||
np.save(output_dir+line, Zxx) | ||
# Save time-domain waveform to npy file | ||
audio_len = range(0, len(clean_audio_1)-fftsize+1, hopsize)[-1] + fftsize | ||
audio = np.stack((clean_audio_1[:audio_len], clean_audio_2[:audio_len], mix_audio[:audio_len]), axis=0) | ||
np.save(output_dir+line+'_wave', audio) | ||
|
||
# Feature generation for training, cv and test | ||
get_feat(wav_list_prefix, wav_path, feat_path, 'tr', fftsize=256, hopsize=64) | ||
get_feat(wav_list_prefix, wav_path, feat_path, 'cv', fftsize=256, hopsize=64) | ||
get_feat(wav_list_prefix, wav_path, feat_path, 'tt', fftsize=256, hopsize=64) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import scipy | ||
import scipy.signal | ||
import numpy as np | ||
|
||
def stft(x, framesamp=256, hopsamp=64): | ||
w = scipy.signal.hanning(framesamp, False) | ||
w = np.sqrt(w) | ||
X = scipy.array([scipy.fft(w*x[i:i+framesamp]) | ||
for i in range(0, len(x)-framesamp+1, hopsamp)]) | ||
return X | ||
|
||
def istft(X, T, hopsamp=64): | ||
x = scipy.zeros(T) | ||
weights = scipy.zeros(T) | ||
framesamp = X.shape[1] | ||
w = scipy.signal.hanning(framesamp, False) | ||
w = np.sqrt(w) | ||
|
||
for n,i in enumerate(range(0, len(x)-framesamp+1, hopsamp)): | ||
x[i:i+framesamp] += w*scipy.real(scipy.ifft(X[n])) | ||
weights[i:i+framesamp] += w**2 | ||
|
||
weights[weights==0] = 1 | ||
x = x/weights | ||
|
||
return x |
Oops, something went wrong.