
Commit 08415b0
Removed the arguments from the unused residual connections.
Js-Mim committed Nov 21, 2017
1 parent 12bc40e commit 08415b0
Showing 1 changed file with 24 additions and 30 deletions.
54 changes: 24 additions & 30 deletions processes_scripts/main_script.py
@@ -33,24 +33,24 @@ def main(training, apply_sparsity):
     fs = 44100  # Sampling frequency

     # Parameters
-    B = 16  # Batch-size
-    T = 60  # Length of the sequence
-    N = 2049  # Frequency sub-bands to be processed
-    F = 744  # Frequency sub-bands for encoding
-    L = 10  # Context parameter (2*L frames will be removed)
-    epochs = 100  # Epochs
-    init_lr = 1e-4  # Initial learning rate
-    mnorm = 0.5  # L2-based norm clipping
-    mask_loss_threshold = 1.5  # Scalar indicating the threshold for the time-frequency masking module
-    good_loss_threshold = 0.25  # Scalar indicating the threshold for the source enhancement module
+    B = 16  # Batch-size
+    T = 60  # Length of the sequence
+    N = 2049  # Frequency sub-bands to be processed
+    F = 744  # Frequency sub-bands for encoding
+    L = 10  # Context parameter (2*L frames will be removed)
+    epochs = 100  # Epochs
+    init_lr = 1e-4  # Initial learning rate
+    mnorm = 0.5  # L2-based norm clipping
+    mask_loss_threshold = 1.5  # Scalar indicating the threshold for the time-frequency masking module
+    good_loss_threshold = 0.25  # Scalar indicating the threshold for the source enhancement module

     # Data
     totTrainFiles = 116
     numFilesPerTr = 4

     print('------------ Building model ------------')
     encoder = s_s_net.BiGRUEncoder(B, T, N, F, L)
-    decoder = s_s_net.Decoder(B, T, N, F, L)
+    decoder = s_s_net.Decoder(B, T, N, F, L, infr=True)
     sp_decoder = s_s_net.SparseDecoder(B, T, N, F, L)
     source_enhancement = s_s_net.SourceEnhancement(B, T, N, F, L)
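
Note on this hunk: the removed and re-added parameter lines differ only in comment alignment that the page extraction flattened; the one functional change is the new infr=True argument to s_s_net.Decoder. The module itself is not part of this diff, so the following is purely a hypothetical sketch of how such a constructor flag commonly gates an optional residual branch (class body, layer choices, and flag semantics are assumptions, not the repository's code):

    import torch.nn as nn

    class Decoder(nn.Module):
        # Hypothetical sketch -- not the actual s_s_net.Decoder.
        def __init__(self, B, T, N, F, L, infr=False):
            super(Decoder, self).__init__()
            self.gru = nn.GRU(F, F, batch_first=True)
            self.infr = infr
            if not self.infr:
                # Residual branch that inference mode leaves out entirely.
                self.residual = nn.Linear(F, F)

        def forward(self, H_enc):
            H_dec, _ = self.gru(H_enc)
            if not self.infr:
                H_dec = H_dec + self.residual(H_enc)
            return H_dec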

@@ -75,8 +75,6 @@ def main(training, apply_sparsity):
                                  lr=init_lr
                                  )

-    scheduler = RedLR(optimizer, 'min', factor=0.1, patience=5, verbose=True)
-
     if training:
         win_viz, winb_viz = visualize.init_visdom()
         batch_loss = []
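
The deleted scheduler line matches the signature of torch.optim.lr_scheduler.ReduceLROnPlateau, so RedLR is presumably an import alias for that class. For reference, a minimal self-contained sketch of wiring a plateau scheduler into a training loop (toy model and loss, not the script's):

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    # Toy model/optimizer so the snippet runs on its own.
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Same settings as the removed line: cut the learning rate by 10x
    # after 5 epochs without improvement of the monitored (minimized) loss.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

    for epoch in range(20):
        loss = ((model(torch.randn(4, 10)) - torch.zeros(4, 1)) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())  # the scheduler monitors this scalar
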
@@ -100,9 +98,8 @@ def main(training, apply_sparsity):
                 # Mixture to Singing voice
                 H_enc = encoder(ms[batch * B: (batch+1)*B, :, :])
                 # Iterative inference
-                H_j_dec = it_infer.iterative_recurrent_inference(decoder, H_enc, ms[batch * B: (batch+1)*B, :, :],
-                                                                 criterion=None, tol=1e-3, max_iter=10)
-
+                H_j_dec = it_infer.iterative_recurrent_inference(decoder, H_enc,
+                                                                 criterion=None, tol=1e-3, max_iter=10)
                 vs_hat_b = sp_decoder(H_j_dec, ms[batch * B: (batch+1)*B, :, :])[0]
                 vs_hat_b_filt = source_enhancement(vs_hat_b)
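
The call to it_infer.iterative_recurrent_inference now receives only the decoder and the encoded representation; the mixture batch argument is gone. The function body is not shown in this diff, so the following is only a rough sketch of what a refinement loop parameterized by tol and max_iter commonly looks like (every detail is an assumption, not the repository's implementation; criterion is kept for signature parity but unused):

    import torch

    def iterative_recurrent_inference(decoder, H_enc, criterion=None,
                                      tol=1e-3, max_iter=10):
        # Hypothetical sketch: re-apply the decoder to its own estimate
        # until the relative update falls below tol or max_iter is hit.
        # Assumes decoder maps a tensor to one of the same shape.
        H_j = decoder(H_enc)
        for _ in range(max_iter):
            H_new = decoder(H_j)
            if torch.norm(H_new - H_j) <= tol * torch.norm(H_j):
                return H_new
            H_j = H_new
        return H_j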

@@ -147,25 +144,22 @@ def main(training, apply_sparsity):
                 loss.backward()
                 torch.nn.utils.clip_grad_norm(list(encoder.parameters()) +
                                               list(decoder.parameters()) +
-                                              list(sp_decoder.parameters())+
+                                              list(sp_decoder.parameters()) +
                                               list(source_enhancement.parameters()),
-                                              max_norm = mnorm, norm_type=2)
+                                              max_norm=mnorm, norm_type=2)
                 optimizer.step()
                 # Update graphs
                 win_viz = visualize.viz.line(X=np.arange(batch_index, batch_index+1),
                                              Y=np.reshape(batch_loss[batch_index], (1,)),
                                              win=win_viz, update='append')
                 batch_index += 1

-            #if (epoch + 1) >= 5:
-            #    scheduler.step(Variable(torch.from_numpy(np.asarray(np.mean(epoch_loss)))))
-
             if (epoch+1) % 40 == 0:
                 print('------------ Saving model ------------')
-                torch.save(encoder.state_dict(), 'results/results_inference/torch_sps_encoder_' + str(epoch+1)+'.pytorch')
-                torch.save(decoder.state_dict(), 'results/results_inference/torch_sps_decoder_' + str(epoch+1)+'.pytorch')
-                torch.save(sp_decoder.state_dict(), 'results/results_inference/torch_sps_sp_decoder_' + str(epoch+1)+'.pytorch')
-                torch.save(source_enhancement.state_dict(), 'results/results_inference/torch_sps_se_' + str(epoch+1)+'.pytorch')
+                torch.save(encoder.state_dict(), 'results/torch_sps_encoder_' + str(epoch+1)+'.pytorch')
+                torch.save(decoder.state_dict(), 'results/torch_sps_decoder_' + str(epoch+1)+'.pytorch')
+                torch.save(sp_decoder.state_dict(), 'results/torch_sps_sp_decoder_' + str(epoch+1)+'.pytorch')
+                torch.save(source_enhancement.state_dict(), 'results/torch_sps_se_' + str(epoch+1)+'.pytorch')
         print('------------ Done ------------')
     else:
         print('------- Loading pre-trained model -------')
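
Besides spacing fixes, this hunk drops the commented-out scheduler step and shortens the checkpoint paths from results/results_inference/ to results/. The surrounding backward -> clip -> step ordering is standard PyTorch; a self-contained illustration (PyTorch 0.4+ spells the clipping helper clip_grad_norm_, while this 2017-era script uses the older un-suffixed name):

    import itertools
    import torch

    enc = torch.nn.Linear(8, 8)
    dec = torch.nn.Linear(8, 8)
    params = list(itertools.chain(enc.parameters(), dec.parameters()))
    optimizer = torch.optim.Adam(params, lr=1e-4)

    x = torch.randn(4, 8)
    loss = ((dec(enc(x)) - x) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    # Rescale all gradients so their joint L2 norm is at most 0.5,
    # mirroring max_norm=mnorm, norm_type=2 above.
    torch.nn.utils.clip_grad_norm_(params, max_norm=0.5, norm_type=2)
    optimizer.step()
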
@@ -174,22 +168,22 @@ def main(training, apply_sparsity):
         decoder.load_state_dict(torch.load('results/results_inference/torch_sps_decoder_40_m3_i10.pytorch'))
         sp_decoder.load_state_dict(torch.load('results/results_inference/torch_sps_sp_decoder_40_m3_i10.pytorch'))
         source_enhancement.load_state_dict(torch.load('results/results_inference/torch_sps_se_40_m3_i10.pytorch'))

         print('------------- Done -------------')

     return encoder, decoder, sp_decoder, source_enhancement


 if __name__ == '__main__':
-    training = False  # Whether to train or test the trained model (requires the optimized parameters)
+    training = True  # Whether to train or test the trained model (requires the optimized parameters)
     apply_sparsity = True  # Whether to apply a sparse penalty or not

     sfiltnet = main(training, apply_sparsity)

     #print('------------- BSS-Eval -------------')
     #nnet_helpers.test_eval(sfiltnet, 16, 60, 4096, 10, 2049, 384)
     #print('------------- Done -------------')
-    print('------------- DNN-Test -------------')
-    nnet_helpers.test_nnet(sfiltnet, 60, 10*2, 2049, 4096, 384, 16)
-    print('------------- Done -------------')
+    #print('------------- DNN-Test -------------')
+    #nnet_helpers.test_nnet(sfiltnet, 60, 10*2, 2049, 4096, 384, 16)
+    #print('------------- Done -------------')

 # EOF
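
Worth flagging: checkpoints are now written under results/, while the loading branch above still reads the torch_sps_*_40_m3_i10.pytorch files from results/results_inference/, so freshly trained weights will not be picked up by the else path without moving or renaming them. The convention itself is the usual state_dict round trip; a minimal sketch with a toy module and a hypothetical file name:

    import os
    import torch

    os.makedirs('results', exist_ok=True)
    model = torch.nn.Linear(8, 8)
    # Save parameters only (a state_dict), matching the script's convention.
    torch.save(model.state_dict(), 'results/toy_encoder_40.pytorch')

    restored = torch.nn.Linear(8, 8)  # must be rebuilt with the same shapes
    restored.load_state_dict(torch.load('results/toy_encoder_40.pytorch'))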
