-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathloss.py
203 lines (169 loc) · 8.15 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
from torch.nn import functional as F
from common import get_mask_from_lengths
def compute_flow_loss(z, log_det_W_list, log_s_list, n_elements, n_dims, mask,
                      sigma=1.0):
    """Compute the normalized negative log-likelihood of a flow's latents.

    The returned loss is::

        (sum((z * mask) ** 2) / (2 * sigma ** 2)
         - sum_i sum(log_s_i * mask)
         - n_elements * sum_i log_det_W_i) / (n_elements * n_dims)

    Args:
        z: latent tensor produced by the flow; multiplied by ``mask``.
            Callers in this file read ``n_dims`` from ``z.size(1)`` and pass a
            mask built as ``mask[:, None]``.
        log_det_W_list: per-step log-determinant terms, indexed in lockstep
            with ``log_s_list``. May be empty, in which case the term is
            skipped.
        log_s_list: per-step log-scale tensors, each multiplied by ``mask``
            before summing. May be empty.
        n_elements: count used to scale the summed log-det term and to
            normalize the loss.
        n_dims: number of latent channels, used for normalization.
        mask: multiplicative mask broadcastable against ``z`` and each
            ``log_s``.
        sigma: standard deviation of the zero-mean Gaussian prior on ``z``.

    Returns:
        Tuple ``(loss, loss_prior)``: the full normalized loss and the prior
        (Gaussian) term alone, both divided by ``n_elements * n_dims``.
    """
    use_log_det_W = len(log_det_W_list) > 0

    # Start from 0.0 so an empty log_s_list gives a zero term instead of an
    # UnboundLocalError.
    log_s_total = 0.0
    log_det_W_total = 0.0
    for i, log_s in enumerate(log_s_list):
        # Accumulate out-of-place: an in-place ``+=`` / ``*=`` on the first
        # list element would write into the caller's log_det_W tensor.
        log_s_total = log_s_total + torch.sum(log_s * mask)
        if use_log_det_W:
            log_det_W_total = log_det_W_total + log_det_W_list[i]

    if use_log_det_W:
        # Each log_det_W entry is treated as a per-element quantity and
        # scaled by the number of elements.
        log_det_W_total = log_det_W_total * n_elements

    z = z * mask
    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)

    loss = prior_NLL - log_s_total - log_det_W_total

    denom = n_elements * n_dims
    loss = loss / denom
    loss_prior = prior_NLL / denom
    return loss, loss_prior
def compute_regression_loss(x_hat, x, mask, name=False):
    """Compute a masked loss between a prediction and its target.

    Uses binary cross-entropy with logits when ``name == 'vpred'`` and
    summed squared error otherwise. In both cases the summed loss is divided
    by ``mask.sum()``.

    Args:
        x_hat: prediction (logits for ``'vpred'``). Must broadcast with the
            possibly unsqueezed ``mask``.
        x: target. A 2-D tensor gets a singleton axis inserted at dim 1.
        mask: multiplicative mask. A 2-D tensor gets a singleton axis
            inserted at dim 1; afterwards it must have as many dims as ``x``.
        name: loss name. It selects the criterion and forms the output key.

    Returns:
        ``{"loss_<name>": loss}``, where ``loss`` is a scalar tensor.

    Raises:
        AssertionError: if ``x`` and ``mask`` end up with different ranks.
    """
    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
    assert len(x.shape) == len(mask.shape)

    x = x * mask
    x_hat = x_hat * mask

    if name == 'vpred':
        # BCE-with-logits at (logit=0, target=0) is log(2), not 0, so zeroing
        # the inputs does not remove padded positions. Mask the per-element
        # loss instead.
        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction='none')
        loss = (loss * mask).sum()
    else:
        # Zeroed pairs give zero squared error, so input masking suffices.
        loss = F.mse_loss(x_hat, x, reduction='sum')
    loss = loss / mask.sum()

    loss_dict = {"loss_{}".format(name): loss}
    return loss_dict
class AttributePredictionLoss(torch.nn.Module):
    """Loss for a single attribute predictor.

    Dispatches on the keys of the predictor's output dict. A flow output
    (``'z'``) is scored with ``compute_flow_loss``; a regression output
    (``'x_hat'``) is scored with ``compute_regression_loss``. Every value in
    the returned dict is a ``(loss, weight)`` tuple.
    """

    def __init__(self, name, model_config, loss_weight, sigma=1.0):
        super().__init__()
        self.name = name
        self.sigma = sigma
        self.model_name = model_config['name']
        self.loss_weight = loss_weight
        hparams = model_config['hparams']
        # Optional temporal grouping factor; defaults to 1 (no grouping).
        self.n_group_size = (
            hparams['n_group_size'] if 'n_group_size' in hparams else 1)

    def forward(self, model_output, lens):
        """Return ``{loss_name: (loss, weight)}`` for this attribute.

        Args:
            model_output: predictor outputs. Either contains ``'z'``,
                ``'log_det_W_list'`` and ``'log_s_list'`` (flow), or
                ``'x_hat'`` and ``'x'`` (regression).
            lens: per-item lengths. They are divided by ``n_group_size``
                before building the mask.

        Raises:
            Exception: if neither output kind is present.
        """
        group = self.n_group_size
        mask = get_mask_from_lengths(lens // group)[:, None].float()

        if 'z' in model_output:
            z = model_output['z']
            # NOTE(review): lens.sum() // group may exceed sum(lens // group),
            # which is presumably the number of unmasked positions, when
            # group > 1.
            n_elements = lens.sum() // group
            loss, loss_prior = compute_flow_loss(
                z, model_output['log_det_W_list'], model_output['log_s_list'],
                n_elements, z.size(1), mask, self.sigma)
            # The prior term is returned with weight 0.0.
            return {
                "loss_{}".format(self.name): (loss, self.loss_weight),
                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
            }

        if 'x_hat' in model_output:
            raw_losses = compute_regression_loss(
                model_output['x_hat'], model_output['x'], mask, self.name)
            weighted = {key: (value, self.loss_weight)
                        for key, value in raw_losses.items()}
            if weighted:
                return weighted

        raise Exception("loss not supported")
class AttentionCTCLoss(torch.nn.Module):
    """CTC loss over an attention log-probability map, averaged over the batch.

    For each batch item, every query step is a CTC frame. The target sequence
    is the key indices ``1..key_len``, and an extra class 0 (the blank),
    filled with ``blank_logprob``, is prepended along the key axis.
    """

    def __init__(self, blank_logprob=-1):
        super().__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.blank_logprob = blank_logprob
        self.CTCLoss = nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        """Return the batch-mean CTC cost.

        Args:
            attn_logprob: 4-D tensor. Its per-item slice is indexed as
                ``(1, query, key)`` and permuted to ``(query, 1, key)``, the
                ``(T, N, C)`` layout ``nn.CTCLoss`` expects.
            in_lens: per-item key lengths, used as CTC target lengths.
            out_lens: per-item query lengths, used as CTC input lengths.
        """
        batch_size = attn_logprob.shape[0]
        # Prepend one blank column (class 0) along the last (key) axis.
        padded = F.pad(input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0),
                       value=self.blank_logprob)

        total = 0.0
        for b in range(batch_size):
            n_keys = in_lens[b]
            n_queries = out_lens[b]
            # Targets are key indices 1..n_keys; index 0 is the blank.
            targets = torch.arange(1, n_keys + 1).unsqueeze(0)
            # (1, query, key+1) -> (query, 1, key+1), cropped to valid region.
            frame_logits = padded[b].permute(1, 0, 2)[
                :n_queries, :, :n_keys + 1]
            log_probs = self.log_softmax(frame_logits[None])[0]
            total += self.CTCLoss(log_probs, targets,
                                  input_lengths=out_lens[b:b + 1],
                                  target_lengths=in_lens[b:b + 1])
        return total / batch_size
class AttentionBinarizationLoss(torch.nn.Module):
    """Negative mean log soft-attention probability at hard-alignment cells.

    Minimizing this loss raises ``soft_attention`` at the positions where
    ``hard_attention == 1``.
    """

    def __init__(self):
        super(AttentionBinarizationLoss, self).__init__()

    def forward(self, hard_attention, soft_attention, eps=1e-12):
        """Compute ``-sum(log p[hard == 1]) / sum(hard)``.

        Args:
            hard_attention: tensor whose entries equal to 1 select the
                contributing cells. It is also summed to normalize the loss,
                so it is presumably a 0/1 mask.
            soft_attention: probabilities, indexable by the boolean mask
                ``hard_attention == 1``.
            eps: lower bound applied before ``log``. A zero probability at a
                selected cell then gives a large finite penalty instead of
                ``inf``.

        Returns:
            Scalar tensor.
        """
        selected = soft_attention[hard_attention == 1]
        log_sum = torch.log(torch.clamp(selected, min=eps)).sum()
        return -log_sum / hard_attention.sum()
class RADTTSLoss(torch.nn.Module):
    """Collect training losses into a ``{name: (loss, weight)}`` dict.

    The attention CTC loss is always included. The mel flow loss is added
    when ``model_output['z_mel']`` is non-empty. One AttributePredictionLoss
    is added per attribute model config that is not None (duration, f0,
    energy, vpred).
    """

    def __init__(self, sigma=1.0, n_group_size=1, dur_model_config=None,
                 f0_model_config=None, energy_model_config=None,
                 vpred_model_config=None, loss_weights=None):
        super().__init__()
        self.sigma = sigma
        self.n_group_size = n_group_size
        # NOTE(review): despite the None default, loss_weights is required.
        # It is dereferenced unconditionally here and in forward().
        self.loss_weights = loss_weights
        self.attn_ctc_loss = AttentionCTCLoss(
            blank_logprob=loss_weights.get('blank_logprob', -1))

        # (model_output key, loss name, model config, loss_weights key,
        #  extra AttributePredictionLoss kwargs)
        attribute_specs = (
            ('duration_model_outputs', 'duration', dur_model_config,
             'dur_loss_weight', {}),
            ('f0_model_outputs', 'f0', f0_model_config,
             'f0_loss_weight', {'sigma': 1.0}),
            ('energy_model_outputs', 'energy', energy_model_config,
             'energy_loss_weight', {}),
            ('vpred_model_outputs', 'vpred', vpred_model_config,
             'vpred_loss_weight', {}),
        )
        # Plain dict, so these modules are not registered as submodules.
        self.loss_fns = {}
        for output_key, loss_name, config, weight_key, extra in attribute_specs:
            if config is None:
                continue
            self.loss_fns[output_key] = AttributePredictionLoss(
                loss_name, config, loss_weights[weight_key], **extra)

    def forward(self, model_output, in_lens, out_lens):
        """Compute every applicable loss.

        Args:
            model_output: dict of model outputs. Reads ``'z_mel'``,
                ``'log_det_W_list'``, ``'log_s_list'``, ``'attn_logprob'``
                and any key registered in ``self.loss_fns``.
            in_lens: per-item lengths used as CTC key lengths and for
                output keys containing ``'dur'``.
            out_lens: per-item lengths used for the mel mask, as CTC query
                lengths, and for the remaining attribute outputs.

        Returns:
            dict mapping loss names to ``(loss, weight)`` tuples.
        """
        losses = {}

        z_mel = model_output['z_mel']
        if len(z_mel):
            group = self.n_group_size
            n_elements = out_lens.sum() // group
            mel_mask = get_mask_from_lengths(out_lens // group)[:, None].float()
            loss_mel, loss_prior_mel = compute_flow_loss(
                z_mel, model_output['log_det_W_list'],
                model_output['log_s_list'], n_elements, z_mel.size(1),
                mel_mask, self.sigma)
            losses['loss_mel'] = (loss_mel, 1.0)
            # Weight 0.0: the prior term is returned but not weighted in.
            losses['loss_prior_mel'] = (loss_prior_mel, 0.0)

        losses['loss_ctc'] = (
            self.attn_ctc_loss(model_output['attn_logprob'], in_lens, out_lens),
            self.loss_weights['ctc_loss_weight'])

        for key in model_output:
            if key not in self.loss_fns:
                continue
            output = model_output[key]
            if output is None or len(output) == 0:
                continue
            # Keys containing 'dur' use in_lens; all others use out_lens.
            lens = in_lens if 'dur' in key else out_lens
            losses.update(self.loss_fns[key](output, lens))
        return losses