Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ruizhecao96 authored Jul 6, 2022
1 parent ece8f9c commit ec19329
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def test_step(self, batch):

time_loss = torch.mean(torch.abs(est_audio - clean))
length = est_audio.size(-1)
loss = 0.1 * loss_ri + 0.9 * loss_mag + 0.2 * time_loss + 0.05 * gen_loss_GAN
loss = args.loss_weights[0] * loss_ri + args.loss_weights[1] * loss_mag + args.loss_weights[2] * time_loss \
+ args.loss_weights[3] * gen_loss_GAN

est_audio_list = list(est_audio.detach().cpu().numpy())
clean_audio_list = list(clean.cpu().numpy()[:, :length])
Expand Down

0 comments on commit ec19329

Please sign in to comment.