forked from ming024/FastSpeech2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
40 lines (32 loc) · 1.36 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import hparams as hp
def mse_loss(prediction, target, length):
    """Length-masked mean squared error, averaged over the batch.

    The inputs are iterated sample by sample along dim 0. For sample ``i``
    only the first ``length[i]`` entries of ``prediction[i]`` and
    ``target[i]`` are compared, and the squared error over that prefix is
    averaged. The per-sample averages are then summed and divided by
    ``target.shape[0]``, so every sample carries equal weight regardless of
    how long it is.

    Args:
        prediction: Batched predictions, iterated along dim 0.
        target: Batched targets with the same layout as ``prediction``.
        length: Per-sample number of valid positions (ints or integer 0-d
            tensors usable as slice bounds).

    Returns:
        The batch-averaged loss (a 0-d tensor for a non-empty batch).

    Note:
        A zero entry in ``length`` makes ``mean`` run over an empty tensor,
        which yields NaN and propagates into the returned loss.
    """
    per_sample_errors = (
        (pred_i[:n_valid] - tgt_i[:n_valid]).pow(2).mean()
        for pred_i, tgt_i, n_valid in zip(prediction, target, length)
    )
    return sum(per_sample_errors) / target.shape[0]
def mae_loss(prediction, target, length):
    """Length-masked mean absolute error, averaged over the batch.

    Works like ``mse_loss`` but with the absolute difference. For sample
    ``i`` only the first ``length[i]`` entries are compared and their
    absolute error is averaged. The per-sample averages are summed and
    divided by ``target.shape[0]``.

    Args:
        prediction: Batched predictions, iterated along dim 0.
        target: Batched targets with the same layout as ``prediction``.
        length: Per-sample number of valid positions (ints or integer 0-d
            tensors usable as slice bounds).

    Returns:
        The batch-averaged loss (a 0-d tensor for a non-empty batch).

    Note:
        A zero entry in ``length`` averages an empty tensor, producing NaN.
    """
    abs_errors = (
        (pred_i[:n_valid] - tgt_i[:n_valid]).abs().mean()
        for pred_i, tgt_i, n_valid in zip(prediction, target, length)
    )
    return sum(abs_errors) / target.shape[0]
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss.

    ``forward`` computes five separate objectives:

    - mean squared error on the mel-spectrogram, before and after the postnet;
    - mean absolute error on the duration, pitch and energy predictions.

    Every term is restricted to the valid prefix of each sample by
    ``mse_loss`` / ``mae_loss``.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        d_predicted,
        d_target,
        p_predicted,
        p_target,
        e_predicted,
        e_target,
        mel,
        mel_postnet,
        mel_target,
        src_length,
        mel_length,
    ):
        """Compute the individual FastSpeech2 loss terms.

        Args:
            d_predicted: Predicted durations, masked with ``src_length``.
            d_target: Ground-truth durations. They are cast to float before
                the MAE (presumably stored as integer counts — confirm).
            p_predicted: Predicted pitch, masked with ``mel_length``.
            p_target: Ground-truth pitch, masked with ``mel_length``.
                Apparently frame-level rather than per-source-token — confirm.
            e_predicted: Predicted energy, masked with ``mel_length``.
            e_target: Ground-truth energy, masked with ``mel_length``.
            mel: Mel-spectrogram prediction before the postnet.
            mel_postnet: Mel-spectrogram prediction after the postnet.
            mel_target: Ground-truth mel-spectrogram.
            src_length: Per-sample valid lengths of the source sequences
                (presumably phoneme-level).
            mel_length: Per-sample valid lengths of the mel sequences.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss)``.
        """
        # Ground truth must never receive gradients. detach() returns a
        # graph-free view and leaves the caller's tensors untouched.
        # Assigning ``requires_grad = False`` would instead mutate the inputs,
        # and it raises a RuntimeError when a target is a non-leaf tensor.
        d_target = d_target.detach()
        p_target = p_target.detach()
        e_target = e_target.detach()
        mel_target = mel_target.detach()

        # Spectrogram reconstruction, before and after the postnet.
        mel_loss = mse_loss(mel, mel_target, mel_length)
        mel_postnet_loss = mse_loss(mel_postnet, mel_target, mel_length)

        # Variance-adaptor predictions. Duration is masked by the source
        # length; pitch and energy are masked by the mel length.
        d_loss = mae_loss(d_predicted, d_target.float(), src_length)
        p_loss = mae_loss(p_predicted, p_target, mel_length)
        e_loss = mae_loss(e_predicted, e_target, mel_length)

        return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss