Commit e58247c: Code reformat

ming024 committed Dec 11, 2020
1 parent 991665e
Showing 20 changed files with 317 additions and 193 deletions.
1 change: 1 addition & 0 deletions audio/audio_processing.py
@@ -4,6 +4,7 @@
import librosa.util as librosa_util
import hparams as hp


def window_sumsquare(window, n_frames, hop_length=hp.hop_length, win_length=hp.win_length,
n_fft=hp.filter_length, dtype=np.float32, norm=None):
"""
…
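For context, window_sumsquare mirrors librosa's helper of the same name: it accumulates the squared, overlapped analysis window so that inverse-STFT overlap-add can be renormalized. A simplified sketch of what it computes, with illustrative settings in place of the real hparams values and without the norm handling of the full helper:

import numpy as np
import librosa.util as librosa_util
from scipy.signal import get_window

# Illustrative settings; the real values come from hparams.py.
hop_length, win_length, n_fft, n_frames = 256, 1024, 1024, 40

# Squared Hann window, padded to the FFT size.
win_sq = get_window('hann', win_length, fftbins=True) ** 2
win_sq = librosa_util.pad_center(win_sq, size=n_fft)

# Accumulate the squared window at every hop position.
env = np.zeros(n_fft + hop_length * (n_frames - 1), dtype=np.float32)
for i in range(n_frames):
    start = i * hop_length
    env[start:start + n_fft] += win_sq

# Overlap-added frames are divided by env (where it is non-tiny) to undo
# the windowing gain.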
2 changes: 1 addition & 1 deletion audio/stft.py
@@ -156,5 +156,5 @@ def mel_spectrogram(self, y):
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
energy = torch.norm(magnitudes, dim=1)

return mel_output, energy
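The second return value added here is a per-frame energy, taken as the L2 norm of the linear-spectrogram magnitudes across the frequency axis. A toy demonstration with a dummy magnitude tensor (shapes are illustrative):

import torch

magnitudes = torch.rand(1, 513, 100)    # (batch, n_freq_bins, n_frames), dummy data
energy = torch.norm(magnitudes, dim=1)  # L2 norm over frequency -> (batch, n_frames)
print(energy.shape)                     # torch.Size([1, 100])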
45 changes: 28 additions & 17 deletions data/ljspeech.py
@@ -9,17 +9,19 @@
from text import _clean_text
import hparams as hp


def prepare_align(in_dir):
with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f:
for line in f:
parts = line.strip().split('|')
basename = parts[0]
text = parts[2]
text = _clean_text(text, hp.text_cleaners)

with open(os.path.join(in_dir, 'wavs', '{}.txt'.format(basename)), 'w') as f1:
f1.write(text)
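prepare_align assumes LJSpeech's metadata.csv layout, in which each line reads ID|raw transcription|normalized transcription, so parts[2] picks up the normalized text. A small illustration of the parsing (entry abridged):

line = 'LJ001-0001|Printing, in the only sense...|Printing, in the only sense...'
parts = line.strip().split('|')
basename, text = parts[0], parts[2]  # 'LJ001-0001' and the normalized text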


def build_from_path(in_dir, out_dir):
index = 1
train = list()
@@ -32,13 +34,13 @@ def build_from_path(in_dir, out_dir):
parts = line.strip().split('|')
basename = parts[0]
text = parts[2]

ret = process_utterance(in_dir, out_dir, basename)
if ret is None:
continue
else:
info, f_max, f_min, e_max, e_min, n = ret

if basename[:5] in ['LJ001', 'LJ002', 'LJ003']:
val.append(info)
else:
@@ -53,7 +55,7 @@ def build_from_path(in_dir, out_dir):
energy_max = max(energy_max, e_max)
energy_min = min(energy_min, e_min)
n_frames += n

with open(os.path.join(out_dir, 'stat.txt'), 'w', encoding='utf-8') as f:
strs = ['Total time: {} hours'.format(n_frames*hp.hop_length/hp.sampling_rate/3600),
'Total frames: {}'.format(n_frames),
@@ -64,17 +66,20 @@ def build_from_path(in_dir, out_dir):
for s in strs:
print(s)
f.write(s+'\n')

return [r for r in train if r is not None], [r for r in val if r is not None]
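The 'Total time' statistic converts frame counts into hours: each spectrogram frame advances hop_length samples, so the audio spans n_frames * hop_length / sampling_rate seconds. A worked example with commonly used values (illustrative; the real numbers come from hparams.py):

hop_length, sampling_rate = 256, 22050  # assumed hparams values
n_frames = 2_000_000                    # hypothetical corpus total
hours = n_frames * hop_length / sampling_rate / 3600
print(round(hours, 2))                  # 6.45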


def process_utterance(in_dir, out_dir, basename):
wav_path = os.path.join(in_dir, 'wavs', '{}.wav'.format(basename))
tg_path = os.path.join(out_dir, 'TextGrid', '{}.TextGrid'.format(basename))
tg_path = os.path.join(out_dir, 'TextGrid', '{}.TextGrid'.format(basename))

# Get alignments
textgrid = tgt.io.read_textgrid(tg_path)
phone, duration, start, end = get_alignment(textgrid.get_tier_by_name('phones'))
text = '{'+ '}{'.join(phone) + '}' # '{A}{B}{$}{C}', $ represents silent phones
phone, duration, start, end = get_alignment(
textgrid.get_tier_by_name('phones'))
# '{A}{B}{$}{C}', $ represents silent phones
text = '{' + '}{'.join(phone) + '}'
text = text.replace('{$}', ' ') # '{A}{B} {C}'
text = text.replace('}{', ' ') # '{A B} {C}'
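The reformatted lines build a brace-delimited phone string and then split it on silences. A quick standalone trace of the two replace calls on a toy input:

phone = ['A', 'B', '$', 'C']         # '$' marks a silent phone
text = '{' + '}{'.join(phone) + '}'  # '{A}{B}{$}{C}'
text = text.replace('{$}', ' ')      # '{A}{B} {C}'
text = text.replace('}{', ' ')       # '{A B} {C}'
print(text)                          # {A B} {C}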

@@ -84,32 +89,38 @@ def process_utterance(in_dir, out_dir, basename):
# Read and trim wav files
_, wav = read(wav_path)
wav = wav[int(hp.sampling_rate*start):int(hp.sampling_rate*end)].astype(np.float32)

# Compute fundamental frequency
f0, _ = pw.dio(wav.astype(np.float64), hp.sampling_rate, frame_period=hp.hop_length/hp.sampling_rate*1000)
f0, _ = pw.dio(wav.astype(np.float64), hp.sampling_rate,
frame_period=hp.hop_length/hp.sampling_rate*1000)
f0 = f0[:sum(duration)]
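pw.dio is PyWORLD's F0 estimator; its frame_period argument is in milliseconds, hence hop_length / sampling_rate * 1000 to keep the pitch track aligned with the spectrogram frames, and unvoiced frames come back as 0.0. A minimal self-contained sketch (synthetic input; the stonemask refinement is an optional step not used in this file):

import numpy as np
import pyworld as pw

sr, hop_length = 22050, 256                   # assumed hparams values
wav = np.random.randn(sr).astype(np.float64)  # 1 s of noise as a stand-in

f0, timeaxis = pw.dio(wav, sr, frame_period=hop_length / sr * 1000)
f0 = pw.stonemask(wav, f0, timeaxis, sr)      # optional refinement step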

# Compute mel-scale spectrogram and energy
mel_spectrogram, energy = Audio.tools.get_mel_from_wav(torch.FloatTensor(wav))
mel_spectrogram = mel_spectrogram.numpy().astype(np.float32)[:, :sum(duration)]
mel_spectrogram, energy = Audio.tools.get_mel_from_wav(
torch.FloatTensor(wav))
mel_spectrogram = mel_spectrogram.numpy().astype(np.float32)[
:, :sum(duration)]
energy = energy.numpy().astype(np.float32)[:sum(duration)]
if mel_spectrogram.shape[1] >= hp.max_seq_len:
return None

# Save alignment
ali_filename = '{}-ali-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'alignment', ali_filename), duration, allow_pickle=False)
np.save(os.path.join(out_dir, 'alignment', ali_filename),
duration, allow_pickle=False)

# Save fundamental frequency
f0_filename = '{}-f0-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'f0', f0_filename), f0, allow_pickle=False)

# Save energy
energy_filename = '{}-energy-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'energy', energy_filename), energy, allow_pickle=False)
np.save(os.path.join(out_dir, 'energy', energy_filename),
energy, allow_pickle=False)

# Save spectrogram
mel_filename = '{}-mel-{}.npy'.format(hp.dataset, basename)
np.save(os.path.join(out_dir, 'mel', mel_filename), mel_spectrogram.T, allow_pickle=False)

np.save(os.path.join(out_dir, 'mel', mel_filename),
mel_spectrogram.T, allow_pickle=False)

return '|'.join([basename, text]), max(f0), min([f for f in f0 if f != 0]), max(energy), min(energy), mel_spectrogram.shape[1]
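Note the min over nonzero F0 values in the return statement: pw.dio reports 0.0 for unvoiced frames, so a plain min would always be 0. For example:

import numpy as np

f0 = np.array([0.0, 110.5, 0.0, 98.2, 121.0])  # toy pitch track
f0_min = min([f for f in f0 if f != 0])        # 98.2, unvoiced frames ignored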
20 changes: 12 additions & 8 deletions dataset.py
@@ -15,7 +15,8 @@

class Dataset(Dataset):
def __init__(self, filename="train.txt", sort=True):
self.basename, self.text = process_meta(os.path.join(hparams.preprocessed_path, filename))
self.basename, self.text = process_meta(
os.path.join(hparams.preprocessed_path, filename))
self.sort = sort

def __len__(self):
@@ -36,7 +37,7 @@ def __getitem__(self, idx):
energy_path = os.path.join(
hparams.preprocessed_path, "energy", "{}-energy-{}.npy".format(hparams.dataset, basename))
energy = np.load(energy_path)

sample = {"id": basename,
"text": phone,
"mel_target": mel_target,
@@ -63,7 +64,7 @@ def reprocess(self, batch, cut_list):
length_mel = np.array(list())
for mel in mel_targets:
length_mel = np.append(length_mel, mel.shape[0])

texts = pad_1D(texts)
Ds = pad_1D(Ds)
mel_targets = pad_2D(mel_targets)
@@ -80,7 +81,7 @@ def reprocess(self, batch, cut_list):
"energy": energies,
"src_len": length_text,
"mel_len": length_mel}

return out

def collate_fn(self, batch):
@@ -92,21 +93,24 @@
cut_list = list()
for i in range(real_batchsize):
if self.sort:
cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])
cut_list.append(
index_arr[i*real_batchsize:(i+1)*real_batchsize])
else:
cut_list.append(np.arange(i*real_batchsize, (i+1)*real_batchsize))

cut_list.append(
np.arange(i*real_batchsize, (i+1)*real_batchsize))

output = list()
for i in range(real_batchsize):
output.append(self.reprocess(batch, cut_list[i]))

return output
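collate_fn is built to receive an oversized batch (the training and evaluation loaders pass batch_size**2) and, when sort is enabled, orders it by text length before cutting it into real_batchsize length-homogeneous sub-batches, which reduces padding inside each sub-batch. A self-contained sketch of the cutting logic, with real_batchsize assumed rather than derived as in the actual code:

import numpy as np

lengths = [7, 3, 9, 2, 5, 8, 1, 4, 6]       # toy text lengths, batch of 9
real_batchsize = 3                          # sqrt of the oversized batch

index_arr = np.argsort(-np.array(lengths))  # longest first
cut_list = [index_arr[i * real_batchsize:(i + 1) * real_batchsize]
            for i in range(real_batchsize)]
print([[lengths[j] for j in cut] for cut in cut_list])
# [[9, 8, 7], [6, 5, 4], [3, 2, 1]]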


if __name__ == "__main__":
# Test
dataset = Dataset('val.txt')
training_loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=0)
drop_last=True, num_workers=0)
total_step = hparams.epochs * len(training_loader) * hparams.batch_size

cnt = 0
…
102 changes: 61 additions & 41 deletions evaluate.py
@@ -18,21 +18,25 @@

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_FastSpeech2(num):
checkpoint_path = os.path.join(hp.checkpoint_path, "checkpoint_{}.pth.tar".format(num))
checkpoint_path = os.path.join(
hp.checkpoint_path, "checkpoint_{}.pth.tar".format(num))
model = nn.DataParallel(FastSpeech2())
model.load_state_dict(torch.load(checkpoint_path)['model'])
model.requires_grad = False
model.eval()
return model
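One caveat: model.requires_grad = False sets a plain Python attribute on the nn.DataParallel wrapper and does not freeze any parameters. It is harmless here because evaluation runs under torch.no_grad(), but the usual freezing pattern would be:

# Hypothetical equivalent that actually disables gradients per parameter.
for param in model.parameters():
    param.requires_grad = False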


def evaluate(model, step, vocoder=None):
torch.manual_seed(0)

# Get dataset
dataset = Dataset("val.txt", sort=False)
loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )

loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False,
collate_fn=dataset.collate_fn, drop_last=False, num_workers=0, )

# Get loss function
Loss = FastSpeech2Loss().to(device)

@@ -49,70 +53,85 @@ def evaluate(model, step, vocoder=None):
# Get Data
id_ = data_of_batch["id"]
text = torch.from_numpy(data_of_batch["text"]).long().to(device)
mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
mel_target = torch.from_numpy(
data_of_batch["mel_target"]).float().to(device)
D = torch.from_numpy(data_of_batch["D"]).int().to(device)
log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
energy = torch.from_numpy(
data_of_batch["energy"]).float().to(device)
src_len = torch.from_numpy(
data_of_batch["src_len"]).long().to(device)
mel_len = torch.from_numpy(
data_of_batch["mel_len"]).long().to(device)
max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

with torch.no_grad():
# Forward
mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
text, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
text, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)

# Calculate Loss
mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask)
log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask)

d_l.append(d_loss.item())
f_l.append(f_loss.item())
e_l.append(e_loss.item())
mel_l.append(mel_loss.item())
mel_p_l.append(mel_postnet_loss.item())

if vocoder is not None:
# Run vocoding and plot spectrograms only when a vocoder is given
for k in range(len(mel_target)):
basename = id_[k]
gt_length = mel_len[k]
out_length = out_mel_len[k]

mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()

mel_postnet_torch = mel_postnet_output[k:k+1, :out_length].transpose(1, 2).detach()
mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()


mel_target_torch = mel_target[k:k+1,
:gt_length].transpose(1, 2).detach()
mel_target_ = mel_target[k, :gt_length].cpu(
).transpose(0, 1).detach()

mel_postnet_torch = mel_postnet_output[k:k +
1, :out_length].transpose(1, 2).detach()
mel_postnet = mel_postnet_output[k, :out_length].cpu(
).transpose(0, 1).detach()

if hp.vocoder == 'melgan':
utils.melgan_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
utils.melgan_infer(mel_target_torch, vocoder, os.path.join(
hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(
hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
elif hp.vocoder == 'waveglow':
utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))

np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(basename)), mel_postnet.numpy())

utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(
hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(
hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))

np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(
basename)), mel_postnet.numpy())

f0_ = f0[k, :gt_length].detach().cpu().numpy()
energy_ = energy[k, :gt_length].detach().cpu().numpy()
f0_output_ = f0_output[k, :out_length].detach().cpu().numpy()
energy_output_ = energy_output[k, :out_length].detach().cpu().numpy()

utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
f0_output_ = f0_output[k,
:out_length].detach().cpu().numpy()
energy_output_ = energy_output[k, :out_length].detach(
).cpu().numpy()

utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
idx += 1
current_step += 1

current_step += 1

d_l = sum(d_l) / len(d_l)
f_l = sum(f_l) / len(f_l)
e_l = sum(e_l) / len(e_l)
mel_l = sum(mel_l) / len(mel_l)
mel_p_l = sum(mel_p_l) / len(mel_p_l)
mel_p_l = sum(mel_p_l) / len(mel_p_l)

str1 = "FastSpeech2 Step {},".format(step)
str2 = "Duration Loss: {}".format(d_l)
str3 = "F0 Loss: {}".format(f_l)
@@ -138,28 +157,29 @@ def evaluate(model, step, vocoder=None):

return d_l, f_l, e_l, mel_l, mel_p_l
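For reference, the src_mask and mel_mask returned by the model are True at padded positions (the usual convention in this repo), so the Loss call above inverts them with ~ to mark valid frames instead. A toy illustration of selecting valid entries with an inverted mask:

import torch

mel_mask = torch.tensor([[False, False, True]])  # True = padding
values = torch.tensor([[1.0, 2.0, 3.0]])
valid = values.masked_select(~mel_mask)          # keeps the two real frames
print(valid)                                     # tensor([1., 2.])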


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument('--step', type=int, default=30000)
args = parser.parse_args()

# Get model
model = get_FastSpeech2(args.step).to(device)
print("Model Has Been Defined")
num_param = utils.get_param_num(model)
print('Number of FastSpeech2 Parameters:', num_param)

# Load vocoder
if hp.vocoder == 'melgan':
vocoder = utils.get_melgan()
elif hp.vocoder == 'waveglow':
vocoder = utils.get_waveglow()

# Init directories
if not os.path.exists(hp.log_path):
os.makedirs(hp.log_path)
if not os.path.exists(hp.eval_path):
os.makedirs(hp.eval_path)

evaluate(model, args.step, vocoder)