forked from ming024/FastSpeech2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
92 lines (79 loc) · 3.37 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss.

    The total loss is the unweighted sum of five terms:

    * ``mel_loss`` / ``postnet_mel_loss``: L1 between the (pre-/post-net) mel
      predictions and the mel targets, over valid frames only;
    * ``pitch_loss`` / ``energy_loss``: MSE over valid phonemes or valid frames,
      depending on the configured feature level;
    * ``duration_loss``: MSE between predicted log-durations and
      ``log(duration + 1)`` of the target durations, over valid phonemes.
    """

    # Feature levels understood by this loss; anything else would leave padded
    # positions unmasked, so it is rejected up front.
    _FEATURE_LEVELS = ("phoneme_level", "frame_level")

    def __init__(self, preprocess_config, model_config):
        """Read the pitch/energy feature levels from ``preprocess_config``.

        Args:
            preprocess_config: mapping with
                ``["preprocessing"]["pitch"]["feature"]`` and
                ``["preprocessing"]["energy"]["feature"]``, each either
                ``"phoneme_level"`` or ``"frame_level"``.
            model_config: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if a feature level is not one of the supported values.
        """
        super().__init__()
        preprocessing = preprocess_config["preprocessing"]
        self.pitch_feature_level = preprocessing["pitch"]["feature"]
        self.energy_feature_level = preprocessing["energy"]["feature"]
        for feature, level in (
            ("pitch", self.pitch_feature_level),
            ("energy", self.energy_feature_level),
        ):
            if level not in self._FEATURE_LEVELS:
                raise ValueError(
                    f"Unsupported {feature} feature level {level!r}; "
                    f"expected one of {self._FEATURE_LEVELS}"
                )
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    @staticmethod
    def _mask_variance(level, predictions, targets, src_valid, mel_valid):
        """Keep only valid positions of a pitch/energy prediction/target pair.

        Phoneme-level features are masked with ``src_valid``. Frame-level
        features are masked with ``mel_valid``, after truncating the targets
        along dim 1 to the mask length (mirroring the mel-target truncation in
        ``forward``). This assumes frame-level targets are padded along the
        same frame axis as the mel targets — TODO confirm against the dataset.

        Returns:
            ``(predictions, targets)`` as flattened 1-D tensors of kept values.
        """
        if level == "phoneme_level":
            mask = src_valid
        else:  # "frame_level" — validated in __init__
            mask = mel_valid
            targets = targets[:, : mel_valid.shape[1]]
        return predictions.masked_select(mask), targets.masked_select(mask)

    def forward(self, inputs, predictions):
        """Compute the FastSpeech2 loss terms.

        Args:
            inputs: batch sequence whose items ``[6:12]`` are
                ``(mel_targets, _, _, pitch_targets, energy_targets,
                duration_targets)``. ``mel_targets`` is indexed as 3-D
                (presumably ``(batch, frames, n_mels)`` — verify against the
                dataloader).
            predictions: 10-item sequence ``(mel_predictions,
                postnet_mel_predictions, pitch_predictions, energy_predictions,
                log_duration_predictions, _, src_masks, mel_masks, _, _)``.
                The masks are inverted below before selecting, so they
                presumably mark *padding* with ``True``.

        Returns:
            ``(total_loss, mel_loss, postnet_mel_loss, pitch_loss,
            energy_loss, duration_loss)`` as scalar tensors.
        """
        (
            mel_targets,
            _,
            _,
            pitch_targets,
            energy_targets,
            duration_targets,
        ) = inputs[6:12]
        (
            mel_predictions,
            postnet_mel_predictions,
            pitch_predictions,
            energy_predictions,
            log_duration_predictions,
            _,
            src_masks,
            mel_masks,
            _,
            _,
        ) = predictions

        # After inversion, True marks positions that contribute to the loss.
        src_valid = ~src_masks
        mel_valid = ~mel_masks

        # Targets are ground truth: detach them instead of flipping
        # requires_grad in place. This does not mutate the caller's tensors and
        # also works for non-leaf tensors, where setting the flag would raise.
        # mel_targets may be padded longer than the mask; crop to match.
        mel_targets = mel_targets[:, : mel_valid.shape[1], :].detach()
        pitch_targets = pitch_targets.detach()
        energy_targets = energy_targets.detach()
        # Durations are compared in log space; +1 keeps zero durations finite.
        log_duration_targets = torch.log(duration_targets.float() + 1).detach()

        pitch_predictions, pitch_targets = self._mask_variance(
            self.pitch_feature_level,
            pitch_predictions,
            pitch_targets,
            src_valid,
            mel_valid,
        )
        energy_predictions, energy_targets = self._mask_variance(
            self.energy_feature_level,
            energy_predictions,
            energy_targets,
            src_valid,
            mel_valid,
        )

        log_duration_predictions = log_duration_predictions.masked_select(src_valid)
        log_duration_targets = log_duration_targets.masked_select(src_valid)

        # Broadcast the frame mask across the last (mel-bin) dimension.
        mel_valid_bins = mel_valid.unsqueeze(-1)
        mel_predictions = mel_predictions.masked_select(mel_valid_bins)
        postnet_mel_predictions = postnet_mel_predictions.masked_select(
            mel_valid_bins
        )
        mel_targets = mel_targets.masked_select(mel_valid_bins)

        mel_loss = self.mae_loss(mel_predictions, mel_targets)
        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
        energy_loss = self.mse_loss(energy_predictions, energy_targets)
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)

        total_loss = (
            mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
        )

        return (
            total_loss,
            mel_loss,
            postnet_mel_loss,
            pitch_loss,
            energy_loss,
            duration_loss,
        )