Skip to content

Commit

Permalink
Config for training
Browse files (browse the repository at this point in the history)
  • Loading branch information
osoblanco committed Dec 24, 2021
1 parent df4c7c7 commit d3d8062
Show file tree
Hide file tree
Showing 9 changed files with 8,348 additions and 13,011 deletions.
2 changes: 1 addition & 1 deletion config/LJSpeech/preprocess.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dataset: "LJSpeech"

path:
corpus_path: "/home/ming/Data/LJSpeech-1.1"
corpus_path: "./data/LJSpeech-1.1"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LJSpeech"
preprocessed_path: "./preprocessed_data/LJSpeech"
Expand Down
4 changes: 2 additions & 2 deletions config/LJSpeech/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ optimizer:
eps: 0.000000001
weight_decay: 0.0
grad_clip_thresh: 1.0
grad_acc_step: 1
grad_acc_step: 10
warm_up_step: 4000
anneal_steps: [300000, 400000, 500000]
anneal_rate: 0.3
step:
total_step: 900000
log_step: 100
synth_step: 1000
synth_step: 100
val_step: 1000
save_step: 100000
2 changes: 1 addition & 1 deletion config/LJSpeech_paper/preprocess.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dataset: "LJSpeech_paper"

path:
corpus_path: "/home/ming/Data/LJSpeech-1.1"
corpus_path: "./data/LJSpeech-1.1"
lexicon_path: "lexicon/librispeech-lexicon.txt"
raw_path: "./raw_data/LJSpeech"
preprocessed_path: "./preprocessed_data/LJSpeech_paper"
Expand Down
2 changes: 1 addition & 1 deletion preprocessed_data/LJSpeech/stats.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"pitch": [-2.917079304729967, 11.391254536985784, 207.6309860026605, 46.77559025098988], "energy": [-1.431044578552246, 8.184337615966797, 37.32621679053821, 26.044180782835863]}
{"pitch": [-2.879620383246412, 10.701647317596422, 207.4988362093985, 46.75837075025294], "energy": [-1.4311870336532593, 8.16907787322998, 37.39676954370247, 26.08419376725949]}
20,295 changes: 7,815 additions & 12,480 deletions preprocessed_data/LJSpeech/train.txt

Large diffs are not rendered by default.

1,006 changes: 503 additions & 503 deletions preprocessed_data/LJSpeech/val.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def build_from_path(self):
speakers = {}
for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
speakers[speaker] = i
for wav_name in os.listdir(os.path.join(self.in_dir, speaker)):
for wav_name in tqdm(os.listdir(os.path.join(self.in_dir, speaker))):
if ".wav" not in wav_name:
continue

Expand Down
46 changes: 24 additions & 22 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num
Expand All @@ -21,15 +21,18 @@
import PIL

import matplotlib.pyplot as plt
import plotly
import plotly.plotly as py
import plotly.tools as tls
# import plotly
# import plotly.plotly as py
# import plotly.tools as tls

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fig_to_img(fig):
return PIL.Image.frombytes('RGB',fig.canvas.get_width_height(),fig.canvas.tostring_rgb())

import io
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
return PIL.Image.open(buf)

def main(args, configs):
print("Prepare training ...")
Expand Down Expand Up @@ -67,8 +70,8 @@ def main(args, configs):
val_log_path = os.path.join(train_config["path"]["log_path"], "val")
os.makedirs(train_log_path, exist_ok=True)
os.makedirs(val_log_path, exist_ok=True)
train_logger = SummaryWriter(train_log_path)
val_logger = SummaryWriter(val_log_path)
train_logger = None #SummaryWriter(train_log_path)
val_logger = None #SummaryWriter(val_log_path)

# Training
step = args.restore_step + 1
Expand All @@ -86,7 +89,7 @@ def main(args, configs):
outer_bar.update()


aim_run = aim.Run(experiemnt = "FS2")
aim_run = aim.Run(experiment = "FS2")
aim_run["train_config"] = train_config
aim_run["preprocess_config"] = preprocess_config

Expand All @@ -107,7 +110,6 @@ def main(args, configs):
total_loss = total_loss / grad_acc_step



total_loss.backward()
if step % grad_acc_step == 0:
# Clipping gradients to avoid gradient explosion
Expand All @@ -121,11 +123,11 @@ def main(args, configs):

total_loss,mel_loss, postnet_mel_loss,pitch_loss,energy_loss,duration_loss = losses

aim.track(total_loss.item() , name = "Loss", context = {'type':'total_loss'})
aim.track(mel_loss.item() , name = "Loss", context = {'type':'mel_loss'})
aim.track(postnet_mel_loss.item() , name = "Loss", context = {'type':'postnet_mel_loss'})
aim.track(energy_loss.item() , name = "Loss", context = {'type':'pitch_loss'})
aim.track(duration_loss.item() , name = "Loss", context = {'type':'duration_loss'})
aim_run.track(total_loss.item() , name = "Loss", context = {'type':'total_loss'})
aim_run.track(mel_loss.item() , name = "Loss", context = {'type':'mel_loss'})
aim_run.track(postnet_mel_loss.item() , name = "Loss", context = {'type':'postnet_mel_loss'})
aim_run.track(energy_loss.item() , name = "Loss", context = {'type':'pitch_loss'})
aim_run.track(duration_loss.item() , name = "Loss", context = {'type':'duration_loss'})

losses = [l.item() for l in losses]

Expand All @@ -134,14 +136,14 @@ def main(args, configs):
*losses
)

aim.track(aim.Text(message1 + message2 + "\n"), name = 'log_out')
aim_run.track(aim.Text(message1 + message2 + "\n"), name = 'log_out')

with open(os.path.join(train_log_path, "log.txt"), "a") as f:
f.write(message1 + message2 + "\n")

outer_bar.write(message1 + message2)

log(train_logger, step, losses=losses)
# log(train_logger, step, losses=losses)

if step % synth_step == 0:
fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
Expand All @@ -152,13 +154,13 @@ def main(args, configs):
preprocess_config,
)

aim.track(aim.Audio(wav_reconstruction, format='wav'), name = 'waves', context = {'type':'wav_reconstruction'})
aim.track(aim.Audio(wav_prediction, format='wav'), name = 'waves', context = {'type':'wav_prediction'})
aim_run.track(aim.Audio(wav_reconstruction, format='wav'), name = 'waves', context = {'type':'wav_reconstruction'})
aim_run.track(aim.Audio(wav_prediction, format='wav'), name = 'waves', context = {'type':'wav_prediction'})

plotly_fig = tls.mpl_to_plotly(fig)
# plotly_fig = tls.mpl_to_plotly(fig)

aim.track(aim.Image(fig_to_img(fig)), name = 'Sepctrograms', context = {'type':'MEL'})
aim.track(aim.Figure(plotly_fig), name = 'Sepctrograms', context = {'type':'MEL Interactive'})
aim_run.track(aim.Image(fig_to_img(fig)), name = 'Sepctrograms', context = {'type':'MEL'})
# aim.track(aim.Figure(plotly_fig), name = 'Sepctrograms', context = {'type':'MEL Interactive'})

# log(
# train_logger,
Expand Down
Empty file added utils/__init__.py
Empty file.

0 comments on commit d3d8062

Please sign in to comment.