Skip to content

Commit

Permalink
modify model architecture, add tensorboard loss traces
Browse files Browse the repository at this point in the history
  • Loading branch information
ming024 committed Jul 4, 2020
1 parent bd4c341 commit 6ba4875
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 38 deletions.
4 changes: 3 additions & 1 deletion audio/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,6 @@ def mel_spectrogram(self, y):
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
energy = torch.norm(magnitudes, dim=1)

return mel_output, energy
10 changes: 6 additions & 4 deletions audio/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ def get_mel(filename):
audio_norm = audio / hparams.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = _stft.mel_spectrogram(audio_norm)
melspec, energy = _stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
energy = torch.squeeze(energy, 0)
# melspec = torch.from_numpy(_normalize(melspec.numpy()))

return melspec
return melspec, energy


def get_mel_from_wav(audio):
Expand All @@ -41,10 +42,11 @@ def get_mel_from_wav(audio):
audio_norm = audio / hparams.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = _stft.mel_spectrogram(audio_norm)
melspec, energy = _stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
energy = torch.squeeze(energy, 0)

return melspec
return melspec, energy


def inv_mel_spec(mel, out_filename, griffin_iters=60):
Expand Down
16 changes: 8 additions & 8 deletions data/blizzard2013.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ def process_utterance(in_dir, out_dir, basename):
# Get alignments
textgrid = tgt.io.read_textgrid(tg_path)
phone, duration, start, end = get_alignment(textgrid.get_tier_by_name('phones'))
text = '{'+ ' '.join(phone) + '}'
text = text.replace(' $ ', '} {') # $ represents silent phones
text = '{'+ '}{'.join(phone) + '}' # '{A}{B}{$}{C}', $ represents silent phones
text = text.replace('{$}', ' ') # '{A}{B} {C}'
text = text.replace('}{', ' ') # '{A B} {C}'

if start >= end:
return None

Expand All @@ -90,15 +92,13 @@ def process_utterance(in_dir, out_dir, basename):
f0, _ = pw.dio(wav.astype(np.float64), hp.sampling_rate, frame_period=hp.hop_length/hp.sampling_rate*1000)
f0 = f0[:sum(duration)]

# Compute mel-scale spectrogram
mel_spectrogram = Audio.tools.get_mel_from_wav(torch.FloatTensor(wav)).numpy().astype(np.float32)
mel_spectrogram = mel_spectrogram[:, :sum(duration)]
# Compute mel-scale spectrogram and energy
mel_spectrogram, energy = Audio.tools.get_mel_from_wav(torch.FloatTensor(wav))
mel_spectrogram = mel_spectrogram.numpy().astype(np.float32)[:, :sum(duration)]
energy = energy.numpy().astype(np.float32)[:sum(duration)]
if mel_spectrogram.shape[1] >= hp.max_seq_len:
return None

# Compute energy
energy = np.linalg.norm(mel_spectrogram, axis=0)

# Save alignment
ali_filename = '{}-ali-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'alignment', ali_filename), duration, allow_pickle=False)
Expand Down
10 changes: 4 additions & 6 deletions data/ljspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,13 @@ def process_utterance(in_dir, out_dir, basename):
f0, _ = pw.dio(wav.astype(np.float64), hp.sampling_rate, frame_period=hp.hop_length/hp.sampling_rate*1000)
f0 = f0[:sum(duration)]

# Compute mel-scale spectrogram
mel_spectrogram = Audio.tools.get_mel_from_wav(torch.FloatTensor(wav)).numpy().astype(np.float32)
mel_spectrogram = mel_spectrogram[:, :sum(duration)]
# Compute mel-scale spectrogram and energy
mel_spectrogram, energy = Audio.tools.get_mel_from_wav(torch.FloatTensor(wav))
mel_spectrogram = mel_spectrogram.numpy().astype(np.float32)[:, :sum(duration)]
energy = energy.numpy().astype(np.float32)[:sum(duration)]
if mel_spectrogram.shape[1] >= hp.max_seq_len:
return None

# Compute energy
energy = np.linalg.norm(mel_spectrogram, axis=0)

# Save alignment
ali_filename = '{}-ali-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'alignment', ali_filename), duration, allow_pickle=False)
Expand Down
2 changes: 1 addition & 1 deletion hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
checkpoint_path = os.path.join("./ckpt/", dataset)
synth_path = os.path.join("./synth/", dataset)
eval_path = os.path.join("./eval/", dataset)
logger_path = os.path.join("./log/", dataset)
log_path = os.path.join("./log/", dataset)
test_path = "./results"
waveglow_path = "./waveglow/pretrained_model/waveglow_256channels.pt"

Expand Down
4 changes: 2 additions & 2 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def forward(self, x, duration_target=None, pitch_target=None, energy_target=None
pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_target, self.pitch_bins))
else:
pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_prediction, self.pitch_bins))
x = x + pitch_embedding

energy_prediction = self.energy_predictor(x)
if energy_target is not None:
energy_embedding = self.energy_embedding(torch.bucketize(energy_target, self.energy_bins))
else:
energy_embedding = self.energy_embedding(torch.bucketize(energy_prediction, self.energy_bins))
x = x + energy_embedding

x = x + pitch_embedding + energy_embedding

return x, duration_prediction, pitch_prediction, energy_prediction, mel_pos

Expand Down
Binary file removed synth/LJSpeech/step_300000_0.png
Binary file not shown.
39 changes: 24 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import argparse
Expand Down Expand Up @@ -56,9 +57,10 @@ def main(args):
wave_glow = utils.get_WaveGlow()

# Init logger
logger_path = hp.logger_path
if not os.path.exists(logger_path):
os.makedirs(logger_path)
log_path = hp.log_path
if not os.path.exists(log_path):
os.makedirs(log_path)
logger = SummaryWriter(log_path)

# Init synthesis directory
synth_path = hp.synth_path
Expand Down Expand Up @@ -111,17 +113,17 @@ def main(args):
d_l = d_loss.item()
f_l = f_loss.item()
e_l = e_loss.item()
with open(os.path.join(logger_path, "total_loss.txt"), "a") as f_total_loss:
with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
f_total_loss.write(str(t_l)+"\n")
with open(os.path.join(logger_path, "mel_loss.txt"), "a") as f_mel_loss:
with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
f_mel_loss.write(str(m_l)+"\n")
with open(os.path.join(logger_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
f_mel_postnet_loss.write(str(m_p_l)+"\n")
with open(os.path.join(logger_path, "duration_loss.txt"), "a") as f_d_loss:
with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
f_d_loss.write(str(d_l)+"\n")
with open(os.path.join(logger_path, "f0_loss.txt"), "a") as f_f_loss:
with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
f_f_loss.write(str(f_l)+"\n")
with open(os.path.join(logger_path, "energy_loss.txt"), "a") as f_e_loss:
with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
f_e_loss.write(str(e_l)+"\n")

# Backward
Expand All @@ -148,12 +150,19 @@ def main(args):
print(str2)
print(str3)

with open(os.path.join(logger_path, "logger.txt"), "a") as f_logger:
f_logger.write(str1 + "\n")
f_logger.write(str2 + "\n")
f_logger.write(str3 + "\n")
f_logger.write("\n")

with open(os.path.join(log_path, "log.txt"), "a") as f_log:
f_log.write(str1 + "\n")
f_log.write(str2 + "\n")
f_log.write(str3 + "\n")
f_log.write("\n")

logger.add_scalar('Loss/total_loss', t_l, current_step)
logger.add_scalar('Loss/mel_loss', m_l, current_step)
logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
logger.add_scalar('Loss/duration_loss', d_l, current_step)
logger.add_scalar('Loss/F0_loss', f_l, current_step)
logger.add_scalar('Loss/energy_loss', e_l, current_step)

if current_step % hp.save_step == 0:
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(
)}, os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step)))
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot_data(data, titles=None, figsize=None, filename=None):
axes[i][0].imshow(spectrogram, aspect='auto', origin='bottom', interpolation='none')
axes[i][0].title.set_text(titles[i])
plt.savefig(filename)

plt.clf()

def get_mask_from_lengths(lengths, max_len=None):
if max_len == None:
Expand Down

0 comments on commit 6ba4875

Please sign in to comment.