Skip to content

Commit

Permalink
masked loss computation, advanced spectrogram plot
Browse files Browse the repository at this point in the history
  • Loading branch information
ming024 committed Jul 5, 2020
1 parent 6ba4875 commit 8f4b946
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 27 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ There are several differences between my implementation and the paper.

## TODO
- Try difference weights for the loss terms.
- My loss computation does not mask out the paddings.
- Evaluate the quality of the synthesized audio over the validation set.
- Find the difference between the F0 & energy predicted by the variance predictors and the F0 & energy of the synthesized utterance measured by PyWorld Vocoder.
- Implement FastSpeech 2s.
Expand Down
2 changes: 1 addition & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(args):

# Cal Loss
mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target)
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, mel_len)

d_l.append(d_loss.item())
f_l.append(f_loss.item())
Expand Down
4 changes: 2 additions & 2 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
### for LJSpeech ###
f0_min = 71.0
f0_max = 795.8
energy_min = 17.76
energy_max = 91.42
energy_min = 0.0
energy_max = 315.0
### for Blizzard2013 ###
#f0_min = 71.0
#f0_max = 786.7
Expand Down
32 changes: 23 additions & 9 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,39 @@
import torch.nn as nn
import hparams as hp

def mse_loss(prediction, target, length):
batch_size = target.shape[0]
loss = 0
for p, t, l in zip(prediction, target, length):
loss += torch.mean((prediction[:l]-target[:l])**2)
loss /= batch_size
return loss

def mae_loss(prediction, target, length):
batch_size = target.shape[0]
loss = 0
for p, t, l in zip(prediction, target, length):
loss += torch.mean(torch.abs(prediction[:l]-target[:l]))
loss /= batch_size
return loss

class FastSpeech2Loss(nn.Module):
""" FastSpeech2 Loss """

def __init__(self, reduction='mean'):
def __init__(self):
super(FastSpeech2Loss, self).__init__()
self.mse_loss = nn.MSELoss(reduction=reduction)
self.mae_loss = nn.L1Loss(reduction=reduction)

def forward(self, d_predicted, d_target, p_predicted, p_target, e_predicted, e_target, mel, mel_postnet, mel_target):
def forward(self, d_predicted, d_target, p_predicted, p_target, e_predicted, e_target, mel, mel_postnet, mel_target, mel_length):
d_target.requires_grad = False
p_target.requires_grad = False
e_target.requires_grad = False
mel_target.requires_grad = False

mel_loss = self.mse_loss(mel, mel_target)
mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)
mel_loss = mse_loss(mel, mel_target, mel_length)
mel_postnet_loss = mse_loss(mel_postnet, mel_target, mel_length)

d_loss = self.mae_loss(d_predicted, d_target.float())
p_loss = self.mae_loss(p_predicted, p_target)
e_loss = self.mae_loss(e_predicted, e_target)
d_loss = mae_loss(d_predicted, d_target.float(), mel_length)
p_loss = mae_loss(p_predicted, p_target, length)
e_loss = mae_loss(e_predicted, e_target, length)

return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ numba == 0.48
matplotlib == 3.2.2
unidecode == 1.1.1
inflect == 4.1.0
g2p-en == 2.1.0
10 changes: 6 additions & 4 deletions synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ def synthesize(model, text, sentence, prefix=''):
src_pos = torch.from_numpy(src_pos).to(device).long()

model.to(device)
mel, mel_postnet, d_prediction, p_prediction, e_prediction = model(text, src_pos)
mel, mel_postnet, duration_output, f0_output, energy_output = model(text, src_pos)
model.to('cpu')

mel_torch = mel.transpose(1, 2).detach()
mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
mel = mel[0].cpu().transpose(0, 1).detach()
mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
f0_output = f0_output[0].detach().cpu().numpy()
energy_output = energy_output[0].detach().cpu().numpy()

if not os.path.exists(hp.test_path):
os.makedirs(hp.test_path)
Expand All @@ -59,7 +61,7 @@ def synthesize(model, text, sentence, prefix=''):
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(
hp.test_path, '{}_waveglow_{}.wav'.format(prefix, sentence)))

utils.plot_data([(mel_postnet.numpy(), None, None)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))


if __name__ == "__main__":
Expand All @@ -69,8 +71,8 @@ def synthesize(model, text, sentence, prefix=''):
args = parser.parse_args()

sentence = "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition"
#sentence = "in being comparatively modern."
#sentence = "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process"
sentence = "in being comparatively modern."
sentence = "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process"
#sentence = "produced the block books, which were the immediate predecessors of the true printed book,"
#sentence = "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing."
#sentence = "And it is worth mention in passing that, as an example of fine typography,"
Expand Down
10 changes: 8 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def main(args):

# Cal Loss
mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target)
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, mel_len)
total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

# Logger
Expand Down Expand Up @@ -181,7 +181,13 @@ def main(args):
waveglow.inference.inference(mel_torch, wave_glow, os.path.join(synth_path, "step_{}_waveglow.wav".format(current_step)))
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(synth_path, "step_{}_postnet_waveglow.wav".format(current_step)))
waveglow.inference.inference(mel_target_torch, wave_glow, os.path.join(synth_path, "step_{}_ground-truth_waveglow.wav".format(current_step)))
utils.plot_data([(mel_postnet.numpy(), None, None), (mel_target.numpy(), None, None)],

f0 = f0[0, :length].detach().cpu().numpy()
energy = energy[0, :length].detach().cpu().numpy()
f0_output = f0_output[0, :length].detach().cpu().numpy()
energy_output = energy_output[0, :length].detach().cpu().numpy()

utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output), (mel_target.numpy(), f0, energy)],
['Synthetized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))

end_time = time.perf_counter()
Expand Down
39 changes: 31 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,41 @@ def get_param_num(model):
return num_param


def plot_data(data, titles=None, figsize=None, filename=None):
if figsize is None:
figsize = (12, 6*len(data))
_, axes = plt.subplots(len(data), 1, squeeze=False, figsize=figsize)

def plot_data(data, titles=None, filename=None):
fig, axes = plt.subplots(len(data), 1, squeeze=False)
if titles is None:
titles = [None for i in range(len(data))]

def add_axis(fig, old_ax, offset=0):
ax = fig.add_axes(old_ax.get_position(), anchor='W')
ax.set_facecolor("None")
return ax

for i in range(len(data)):
spectrogram, pitch, energy = data[i]
axes[i][0].imshow(spectrogram, aspect='auto', origin='bottom', interpolation='none')
axes[i][0].title.set_text(titles[i])
plt.savefig(filename)
axes[i][0].imshow(spectrogram, origin='lower')
axes[i][0].set_aspect(2.5, adjustable='box')
axes[i][0].set_ylim(0, hp.n_mel_channels)
axes[i][0].set_title(titles[i], fontsize='medium')
axes[i][0].tick_params(labelsize='x-small', left=False, labelleft=False)
axes[i][0].set_anchor('W')

ax1 = add_axis(fig, axes[i][0])
ax1.plot(pitch, color='tomato')
ax1.set_xlim(0, spectrogram.shape[1])
ax1.set_ylim(0, hp.f0_max)
ax1.set_ylabel('F0', color='tomato')
ax1.tick_params(labelsize='x-small', colors='tomato', bottom=False, labelbottom=False)

ax2 = add_axis(fig, axes[i][0], 1.2)
ax2.plot(energy, color='darkviolet')
ax2.set_xlim(0, spectrogram.shape[1])
ax2.set_ylim(hp.energy_min, hp.energy_max)
ax2.set_ylabel('Energy', color='darkviolet')
ax2.yaxis.set_label_position('right')
ax2.tick_params(labelsize='x-small', colors='darkviolet', bottom=False, labelbottom=False, left=False, labelleft=False, right=True, labelright=True)

plt.savefig(filename, dpi=200)
plt.clf()

def get_mask_from_lengths(lengths, max_len=None):
Expand Down

0 comments on commit 8f4b946

Please sign in to comment.