Skip to content

Commit

Permalink
run validation while training
Browse files — browse the repository at this point in the history
  • Loading branch information
ming024 committed Jul 6, 2020
1 parent 8f4b946 commit 3ca60d4
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 65 deletions.
3 changes: 2 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def reprocess(self, batch, cut_list):
f0s = [batch[ind]["f0"] for ind in cut_list]
energies = [batch[ind]["energy"] for ind in cut_list]

length_text = np.array([])
length_text = np.array(list())
for text in texts:
length_text = np.append(length_text, text.shape[0])

Expand Down Expand Up @@ -87,6 +87,7 @@ def reprocess(self, batch, cut_list):
"energy": energies,
"mel_pos": mel_pos,
"src_pos": src_pos,
"src_len": length_text,
"mel_len": length_mel}

return out
Expand Down
97 changes: 53 additions & 44 deletions eval.py → evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,16 @@ def get_FastSpeech2(num):
model.eval()
return model

def main(args):
def evaluate(model, step, wave_glow=None):
torch.manual_seed(0)

# Get dataset
dataset = Dataset("val.txt", sort=False)
loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )

# Get model
model = get_FastSpeech2(args.step).to(device)
print("Model Has Been Defined")
num_param = utils.get_param_num(model)
print('Number of FastSpeech2 Parameters:', num_param)

# Init directories
if not os.path.exists(hp.logger_path):
os.makedirs(hp.logger_path)
if not os.path.exists(hp.eval_path):
os.makedirs(hp.eval_path)

# Get loss function
Loss = FastSpeech2Loss().to(device)
print("Loss Function Defined.")

# Load vocoder
wave_glow = utils.get_WaveGlow()

# Evaluation
d_l = []
f_l = []
Expand All @@ -71,6 +55,7 @@ def main(args):
energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
mel_pos = torch.from_numpy(data_of_batch["mel_pos"]).long().to(device)
src_pos = torch.from_numpy(data_of_batch["src_pos"]).long().to(device)
src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
max_len = max(data_of_batch["mel_len"]).astype(np.int16)

Expand All @@ -81,29 +66,36 @@ def main(args):

# Cal Loss
mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, mel_len)
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, src_len, mel_len)

d_l.append(d_loss.item())
f_l.append(f_loss.item())
e_l.append(e_loss.item())
mel_l.append(mel_loss.item())
mel_p_l.append(mel_postnet_loss.item())

for k in range(len(mel_target)):
length = mel_len[k]

mel_target_torch = mel_target[k:k+1, :length].transpose(1, 2).detach()
mel_target_ = mel_target[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_target_torch, wave_glow, os.path.join(hp.eval_path, 'ground-truth_{}_waveglow.wav'.format(idx)))

mel_postnet_torch = mel_postnet_output[k:k+1, :length].transpose(1, 2).detach()
mel_postnet = mel_postnet_output[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(hp.eval_path, 'eval_{}_waveglow.wav'.format(idx)))

utils.plot_data([(mel_postnet.numpy(), None, None), (mel_target_.numpy(), None, None)],
['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(idx)))
idx += 1


if wave_glow is not None:
# Run vocoding and plotting spectrogram only when the vocoder is defined
for k in range(len(mel_target)):
length = mel_len[k]

mel_target_torch = mel_target[k:k+1, :length].transpose(1, 2).detach()
mel_target_ = mel_target[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_target_torch, wave_glow, os.path.join(hp.eval_path, 'ground-truth_{}_waveglow.wav'.format(idx)))

mel_postnet_torch = mel_postnet_output[k:k+1, :length].transpose(1, 2).detach()
mel_postnet = mel_postnet_output[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(hp.eval_path, 'eval_{}_waveglow.wav'.format(idx)))

f0_ = f0[k, :length].detach().cpu().numpy()
energy_ = energy[k, :length].detach().cpu().numpy()
f0_output_ = f0_output[k, :length].detach().cpu().numpy()
energy_output_ = energy_output[k, :length].detach().cpu().numpy()

utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(idx)))
idx += 1

current_step += 1

d_l = sum(d_l) / len(d_l)
Expand All @@ -112,7 +104,7 @@ def main(args):
mel_l = sum(mel_l) / len(mel_l)
mel_p_l = sum(mel_p_l) / len(mel_p_l)

str1 = "FastSpeech2 Step {},".format(args.step)
str1 = "FastSpeech2 Step {},".format(step)
str2 = "Duration Loss: {}".format(d_l)
str3 = "F0 Loss: {}".format(f_l)
str4 = "Energy Loss: {}".format(e_l)
Expand All @@ -126,19 +118,36 @@ def main(args):
print(str5)
print(str6)

with open(os.path.join(hp.logger_path, "eval.txt"), "a") as f_logger:
f_logger.write(str1 + "\n")
f_logger.write(str2 + "\n")
f_logger.write(str3 + "\n")
f_logger.write(str4 + "\n")
f_logger.write(str5 + "\n")
f_logger.write(str6 + "\n")
f_logger.write("\n")
with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
f_log.write(str1 + "\n")
f_log.write(str2 + "\n")
f_log.write(str3 + "\n")
f_log.write(str4 + "\n")
f_log.write(str5 + "\n")
f_log.write(str6 + "\n")
f_log.write("\n")

return d_l, f_l, e_l, mel_l, mel_p_l

if __name__ == "__main__":

    # Parse the checkpoint step to evaluate.
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=30000)
    args = parser.parse_args()

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init directories (race-free: exist_ok avoids a TOCTOU between the
    # existence check and the mkdir when several processes start together)
    os.makedirs(hp.log_path, exist_ok=True)
    os.makedirs(hp.eval_path, exist_ok=True)

    # NOTE: the former entry point `main(args)` was renamed to `evaluate`;
    # the stale call to the now-undefined `main` is removed here.
    evaluate(model, args.step, wave_glow)
3 changes: 2 additions & 1 deletion hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Dataset
dataset = "LJSpeech"
data_path = "./LJSpeech-1.1"
data_path = "/home/ming/Data/Raw/LJSpeech-1.1"
#dataset = "Blizzard2013"
#data_path = "./Blizzard-2013/train/segmented/"

Expand Down Expand Up @@ -87,6 +87,7 @@
# Save, log and synthesis
save_step = 10000
synth_step = 1000
eval_step = 1000
eval_size = 256
log_step = 50
clear_Time = 20
8 changes: 4 additions & 4 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FastSpeech2Loss(nn.Module):
def __init__(self):
super(FastSpeech2Loss, self).__init__()

def forward(self, d_predicted, d_target, p_predicted, p_target, e_predicted, e_target, mel, mel_postnet, mel_target, mel_length):
def forward(self, d_predicted, d_target, p_predicted, p_target, e_predicted, e_target, mel, mel_postnet, mel_target, src_length, mel_length):
d_target.requires_grad = False
p_target.requires_grad = False
e_target.requires_grad = False
Expand All @@ -33,8 +33,8 @@ def forward(self, d_predicted, d_target, p_predicted, p_target, e_predicted, e_t
mel_loss = mse_loss(mel, mel_target, mel_length)
mel_postnet_loss = mse_loss(mel_postnet, mel_target, mel_length)

d_loss = mae_loss(d_predicted, d_target.float(), mel_length)
p_loss = mae_loss(p_predicted, p_target, length)
e_loss = mae_loss(e_predicted, e_target, length)
d_loss = mae_loss(d_predicted, d_target.float(), src_length)
p_loss = mae_loss(p_predicted, p_target, mel_length)
e_loss = mae_loss(e_predicted, e_target, mel_length)

return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss
File renamed without changes.
33 changes: 25 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from loss import FastSpeech2Loss
from dataset import Dataset
from optimizer import ScheduledOptim
from evaluate import evaluate
import hparams as hp
import utils
import audio as Audio
Expand Down Expand Up @@ -94,6 +95,7 @@ def main(args):
energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
mel_pos = torch.from_numpy(data_of_batch["mel_pos"]).long().to(device)
src_pos = torch.from_numpy(data_of_batch["src_pos"]).long().to(device)
src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
max_len = max(data_of_batch["mel_len"]).astype(np.int16)

Expand All @@ -103,7 +105,7 @@ def main(args):

# Cal Loss
mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, mel_len)
duration_output, D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, src_len, mel_len)
total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

# Logger
Expand Down Expand Up @@ -156,12 +158,12 @@ def main(args):
f_log.write(str3 + "\n")
f_log.write("\n")

logger.add_scalar('Loss/total_loss', t_l, current_step)
logger.add_scalar('Loss/mel_loss', m_l, current_step)
logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
logger.add_scalar('Loss/duration_loss', d_l, current_step)
logger.add_scalar('Loss/F0_loss', f_l, current_step)
logger.add_scalar('Loss/energy_loss', e_l, current_step)
logger.add_scalars('Loss/total_loss', {'training': t_l}, current_step)
logger.add_scalars('Loss/mel_loss', {'training': m_l}, current_step)
logger.add_scalars('Loss/mel_postnet_loss', {'training': m_p_l}, current_step)
logger.add_scalars('Loss/duration_loss', {'training': d_l}, current_step)
logger.add_scalars('Loss/F0_loss', {'training': f_l}, current_step)
logger.add_scalars('Loss/energy_loss', {'training': e_l}, current_step)

if current_step % hp.save_step == 0:
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(
Expand All @@ -188,8 +190,23 @@ def main(args):
energy_output = energy_output[0, :length].detach().cpu().numpy()

utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output), (mel_target.numpy(), f0, energy)],
['Synthetized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))
['Synthetized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))

if current_step % hp.eval_step == 0:
model.eval()
with torch.no_grad():
d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
t_l = d_l + f_l + e_l + m_l + m_p_l

logger.add_scalars('Loss/total_loss', {'validation': t_l}, current_step)
logger.add_scalars('Loss/mel_loss', {'validation': m_l}, current_step)
logger.add_scalars('Loss/mel_postnet_loss', {'validation': m_p_l}, current_step)
logger.add_scalars('Loss/duration_loss', {'validation': d_l}, current_step)
logger.add_scalars('Loss/F0_loss', {'validation': f_l}, current_step)
logger.add_scalars('Loss/energy_loss', {'validation': e_l}, current_step)

model.train()

end_time = time.perf_counter()
Time = np.append(Time, end_time - start_time)
if len(Time) == hp.clear_Time:
Expand Down
8 changes: 1 addition & 7 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ def process_meta(meta_path):
text.append(t)
return name, text


def get_param_num(model):
num_param = sum(param.numel() for param in model.parameters())
return num_param


def plot_data(data, titles=None, filename=None):
fig, axes = plt.subplots(len(data), 1, squeeze=False)
if titles is None:
Expand Down Expand Up @@ -90,7 +88,7 @@ def add_axis(fig, old_ax, offset=0):
ax2.set_ylabel('Energy', color='darkviolet')
ax2.yaxis.set_label_position('right')
ax2.tick_params(labelsize='x-small', colors='darkviolet', bottom=False, labelbottom=False, left=False, labelleft=False, right=True, labelright=True)

plt.savefig(filename, dpi=200)
plt.clf()

Expand All @@ -103,7 +101,6 @@ def get_mask_from_lengths(lengths, max_len=None):

return mask


def get_WaveGlow():
waveglow_path = hp.waveglow_path
wave_glow = torch.load(waveglow_path)['model']
Expand All @@ -115,7 +112,6 @@ def get_WaveGlow():

return wave_glow


def pad_1D(inputs, PAD=0):

def pad_data(x, length, PAD):
Expand All @@ -129,7 +125,6 @@ def pad_data(x, length, PAD):

return padded


def pad_2D(inputs, maxlen=None):

def pad(x, max_len):
Expand All @@ -151,7 +146,6 @@ def pad(x, max_len):

return output


def pad(input_ele, mel_max_length=None):
if mel_max_length:
max_len = mel_max_length
Expand Down

0 comments on commit 3ca60d4

Please sign in to comment.