Implement ForwardTTSLoss
erogol committed Sep 10, 2021
1 parent 3abc3a1 commit 570d597
Showing 1 changed file with 61 additions and 60 deletions.
121 changes: 61 additions & 60 deletions TTS/tts/layers/losses.py
@@ -236,10 +236,40 @@ def forward(self, x, y, length=None):
            y: B x T
            length: B
        """
-        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
+        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float()
        return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()


+class ForwardSumLoss(nn.Module):
+    def __init__(self, blank_logprob=-1):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
+        self.blank_logprob = blank_logprob
+
+    def forward(self, attn_logprob, in_lens, out_lens):
+        key_lens = in_lens
+        query_lens = out_lens
+        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
+
+        total_loss = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
+
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            loss = self.ctc_loss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            total_loss = total_loss + loss
+
+        total_loss = total_loss / attn_logprob.shape[0]
+        return total_loss
+
+
########################
# MODEL LOSS LAYERS
########################
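A note on the one-line change at the top of this hunk: sequence_mask returns a B x T mask, and the added unsqueeze(2) lifts it to B x T x 1, which broadcasts only if x and y carry a trailing feature axis (apparently B x T x 1 at the call sites; the B x T shapes in the docstring above no longer match). A self-contained sketch under that assumption, not part of the commit, with an inline stand-in for the sequence_mask helper this file imports:

import torch

def sequence_mask(sequence_length, max_len):
    # Stand-in for the TTS helper: True for valid steps, False for padding.
    steps = torch.arange(max_len).unsqueeze(0)   # 1 x T
    return steps < sequence_length.unsqueeze(1)  # B x T

x = torch.randn(2, 5, 1)       # e.g. per-token log-durations, B x T x 1
y = torch.randn(2, 5, 1)       # target, same layout
length = torch.tensor([5, 3])  # valid steps per sample
mask = sequence_mask(length, y.size(1)).unsqueeze(2).float()  # B x T x 1, broadcasts over the last axis
loss = torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()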
@@ -413,25 +443,6 @@ def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x
        return return_dict


-class SpeedySpeechLoss(nn.Module):
-    def __init__(self, c):
-        super().__init__()
-        self.l1 = L1LossMasked(False)
-        self.ssim = SSIMLoss()
-        self.huber = Huber()
-
-        self.ssim_alpha = c.ssim_alpha
-        self.huber_alpha = c.huber_alpha
-        self.l1_alpha = c.l1_alpha
-
-    def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens):
-        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
-        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
-        huber_loss = self.huber(dur_output, dur_target, input_lens)
-        loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
-        return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss}
-
-
def mse_loss_custom(x, y):
"""MSE loss using the torch back-end without reduction.
It uses less VRAM than the raw code"""
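The removed SpeedySpeechLoss is subsumed by the configurable ForwardTTSLoss introduced below: its L1 and SSIM spectrogram terms and Huber duration term become plain config choices. A hypothetical sketch, not part of the commit; the SimpleNamespace stands in for the real config object, and the field names are taken from the ForwardTTSLoss constructor in the next hunk:

from types import SimpleNamespace

# Hypothetical config reproducing the removed SpeedySpeechLoss behaviour.
c = SimpleNamespace(
    spec_loss_type="l1",         # L1LossMasked on the decoder output
    duration_loss_type="huber",  # Huber() on the durations
    use_ssim_loss=True,          # SSIMLoss on the decoder output
    spec_loss_alpha=1.0,         # was l1_alpha
    ssim_loss_alpha=1.0,         # was ssim_alpha
    dur_loss_alpha=1.0,          # was huber_alpha
    binary_align_loss_alpha=0.0,
    model_args=SimpleNamespace(use_aligner=False, use_pitch=False),
)
criterion = ForwardTTSLoss(c)  # pitch and aligner terms are simply never registered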
@@ -660,51 +671,41 @@ def forward(self, scores_disc_real, scores_disc_fake):
        return return_dict


-class ForwardSumLoss(nn.Module):
-    def __init__(self, blank_logprob=-1):
-        super().__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim=3)
-        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
-        self.blank_logprob = blank_logprob
+class ForwardTTSLoss(nn.Module):
+    """Generic configurable ForwardTTS loss."""

-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
-        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
-
-        total_loss = 0.0
-        for bid in range(attn_logprob.shape[0]):
-            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
-            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
-
-            curr_logprob = self.log_softmax(curr_logprob[None])[0]
-            loss = self.ctc_loss(
-                curr_logprob,
-                target_seq,
-                input_lengths=query_lens[bid : bid + 1],
-                target_lengths=key_lens[bid : bid + 1],
-            )
-            total_loss = total_loss + loss
-
-        total_loss = total_loss / attn_logprob.shape[0]
-        return total_loss
-
-
-class FastPitchLoss(nn.Module):
    def __init__(self, c):
        super().__init__()
-        self.spec_loss = MSELossMasked(False)
-        self.ssim = SSIMLoss()
-        self.dur_loss = MSELossMasked(False)
-        self.pitch_loss = MSELossMasked(False)
+        if c.spec_loss_type == "mse":
+            self.spec_loss = MSELossMasked(False)
+        elif c.spec_loss_type == "l1":
+            self.spec_loss = L1LossMasked(False)
+        else:
+            raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type))
+
+        if c.duration_loss_type == "mse":
+            self.dur_loss = MSELossMasked(False)
+        elif c.duration_loss_type == "l1":
+            self.dur_loss = L1LossMasked(False)
+        elif c.duration_loss_type == "huber":
+            self.dur_loss = Huber()
+        else:
+            raise ValueError(" [!] Unknown duration_loss_type {}".format(c.duration_loss_type))
+
        if c.model_args.use_aligner:
            self.aligner_loss = ForwardSumLoss()
+            self.aligner_loss_alpha = c.aligner_loss_alpha

+        if c.model_args.use_pitch:
+            self.pitch_loss = MSELossMasked(False)
+            self.pitch_loss_alpha = c.pitch_loss_alpha
+
+        if c.use_ssim_loss:
+            self.ssim = SSIMLoss() if c.use_ssim_loss else None
+            self.ssim_loss_alpha = c.ssim_loss_alpha
+
        self.spec_loss_alpha = c.spec_loss_alpha
-        self.ssim_loss_alpha = c.ssim_loss_alpha
        self.dur_loss_alpha = c.dur_loss_alpha
-        self.pitch_loss_alpha = c.pitch_loss_alpha
-        self.aligner_loss_alpha = c.aligner_loss_alpha
        self.binary_alignment_loss_alpha = c.binary_align_loss_alpha

@staticmethod
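With use_aligner set, the new ForwardTTSLoss registers the ForwardSumLoss from the first hunk as its aligner term: the attention map is treated as per-frame emission scores, one blank column is padded onto the key axis (pad=(1, 0)), and CTC then scores the best monotonic text-to-frame alignment. A minimal smoke test of that term in isolation, not part of the commit; the shapes are inferred from forward() above and the random tensors are purely illustrative:

import torch

B, T_out, T_in = 2, 50, 12                     # batch, decoder frames, text tokens
attn_logprob = torch.randn(B, 1, T_out, T_in)  # unnormalised alignment scores
in_lens = torch.tensor([12, 9])                # text length per sample
out_lens = torch.tensor([50, 37])              # frame length per sample

aligner_criterion = ForwardSumLoss()
loss = aligner_criterion(attn_logprob, in_lens, out_lens)  # scalar, averaged over the batch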
@@ -731,7 +732,7 @@ def forward(
    ):
        loss = 0
        return_dict = {}
-        if self.ssim_loss_alpha > 0:
+        if hasattr(self, "ssim") and self.ssim_loss_alpha > 0:
            ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
            loss = loss + self.ssim_loss_alpha * ssim_loss
            return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
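The guard added here (the attribute checked is self.ssim, the name registered in __init__ above) is what makes the optional terms safe: an alpha for a disabled term is never set, so the pre-change test (if self.ssim_loss_alpha > 0:) would raise AttributeError whenever the term was off. The next hunk applies the same hasattr pattern to the pitch and aligner terms. A quick check against the hypothetical criterion from the SpeedySpeech-style sketch earlier:

print(hasattr(criterion, "ssim"))          # True  -> the SSIM term is applied
print(hasattr(criterion, "pitch_loss"))    # False -> the pitch term is skipped
print(hasattr(criterion, "aligner_loss"))  # False -> the aligner term is skipped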
@@ -747,12 +748,12 @@ def forward(
            loss = loss + self.dur_loss_alpha * dur_loss
            return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss

-        if self.pitch_loss_alpha > 0:
+        if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0:
            pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
            loss = loss + self.pitch_loss_alpha * pitch_loss
            return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss

-        if self.aligner_loss_alpha > 0:
+        if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
            aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
            loss = loss + self.aligner_loss_alpha * aligner_loss
            return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
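One detail in the pitch term: transpose(1, 2) implies the pitch tensors arrive as B x 1 x T and reshapes them to B x T x 1, the layout the masked losses expect. That layout is an inference from the visible call, not something the diff shows directly:

import torch

pitch_output = torch.randn(2, 1, 50)        # assumed layout: B x 1 x T
print(pitch_output.transpose(1, 2).shape)   # torch.Size([2, 50, 1]) -> B x T x 1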
