add MelGAN vocoder
ming024 committed Jul 11, 2020
1 parent e3bf52c commit b8f5d02
Showing 11 changed files with 88 additions and 291 deletions.
18 changes: 11 additions & 7 deletions README.md
@@ -9,7 +9,7 @@ This repository contains only FastSpeech 2 but not FastSpeech 2s so far. I will upda
# Audio Samples
Audio samples generated by this implementation can be found [here](https://ming024.github.io/FastSpeech2/).
- The model used to generate these samples is trained for 300k steps on [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset.
- Audio samples are converted from mel-spectrogram to raw waveform via [NVIDIA's pretrained WaveGlow](https://github.com/NVIDIA/waveglow).
- Audio samples are converted from mel-spectrogram to raw waveform via [NVIDIA's pretrained WaveGlow](https://github.com/NVIDIA/waveglow) and [seungwonpark's pretrained MelGAN](https://github.com/seungwonpark/melgan).

# Quickstart

@@ -27,7 +27,7 @@ Since PyTorch 1.6 is still unstable, I suggest that Python virtual environment s

## Synthesis

You have to download [NVIDIA's pretrained WaveGlow](https://github.com/NVIDIA/waveglow) and put the checkpoint in the ``waveglow/pretrained_model/`` directory, and download our [FastSpeech2 pretrained model](https://drive.google.com/file/d/1jXNDPMt1ybTN97_MztoTFyrPIthoQuSO/view?usp=sharing) then put it in the ``ckpt/LJSpeech/`` directory.
You have to download our [FastSpeech2 pretrained model](https://drive.google.com/file/d/1jXNDPMt1ybTN97_MztoTFyrPIthoQuSO/view?usp=sharing) and then put it in the ``ckpt/LJSpeech/`` directory.

You can run
```
@@ -36,7 +36,7 @@ python3 synthesize.py --step 300000
to generate any utterances you wish. The generated utterances will be put in the ``results/`` directory.

Here is a generated spectrogram of the sentence "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition"
![](./synth/LJSpeech/step_300000_0.png)
![](./synth/LJSpeech/step_300000.png)

# Training

@@ -49,7 +49,7 @@ After downloading the dataset and extracting the compressed files, you have to modify

## Preprocessing

As described in the paper, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/)(MFA) is used to obtain the alignment between utterance and phoneme sequence. Alignments for the LJSpeech dataset is provided [here](https://drive.google.com/file/d/1ukb8o-SnqhXCxq7drI3zye3tZdrGvQDA/view?usp=sharing). You have to put the ``TextGrid.zip`` file in your ``hp.preprocessed_path/`` and extract the files before you continue.
As described in the paper, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences. Alignments for the LJSpeech dataset are provided [here](https://drive.google.com/file/d/1ukb8o-SnqhXCxq7drI3zye3tZdrGvQDA/view?usp=sharing). You have to put the ``TextGrid.zip`` file in your ``hp.preprocessed_path/`` and extract the files before you continue.
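For orientation, here is a minimal sketch of inspecting one extracted alignment; it assumes the third-party ``textgrid`` package and an illustrative file path, and the repo's own preprocessing may parse the files differently:
```
# Sketch: read one MFA alignment from the extracted TextGrid files.
# Assumes the third-party `textgrid` package (pip install textgrid) and an
# illustrative path; the repo's preprocessing may handle this differently.
import textgrid

tg = textgrid.TextGrid.fromFile('preprocessed/LJSpeech/TextGrid/LJ001-0001.TextGrid')
for tier in tg.tiers:
    if tier.name == 'phones':
        for interval in tier.intervals:
            # Each interval is a phoneme with start/end times in seconds,
            # from which per-phoneme durations in frames can be derived.
            print(interval.mark, interval.minTime, interval.maxTime)
```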

Then run the preprocessing script by
```
@@ -82,7 +82,7 @@ Remember to run the preprocessing script.
python3 preprocess.py
```

After preprocessing, you will get a ``stat.txt`` file in your ``hp.preprocessed_path/``, recording the maximum and minimum values of the fundamental frequency and energy values in the entire corpus. You have to modify the f0 and energy parameters in the ``hparams.py`` according to your ``stat.txt`` file.
After preprocessing, you will get a ``stat.txt`` file in your ``hp.preprocessed_path/``, recording the maximum and minimum values of the fundamental frequency and energy in the entire corpus. You have to modify the f0 and energy parameters in ``hparams.py`` according to the content of ``stat.txt``, for example as sketched below.
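The parameter names and numbers in this sketch are assumptions, not values from this repo; copy the actual numbers from your own ``stat.txt``:
```
# Hypothetical excerpt of hparams.py after preprocessing; the names and
# values below are illustrative. Use the numbers reported in your stat.txt.
f0_min = 71.0       # smallest fundamental frequency (Hz) in the corpus
f0_max = 795.8      # largest fundamental frequency (Hz)
energy_min = 0.0    # smallest frame energy
energy_max = 315.0  # largest frame energy
```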

## Training

@@ -116,6 +116,7 @@ There are several differences between my implementation and the paper.
- Following [xcmyz's implementation](https://github.com/xcmyz/FastSpeech), I use an additional Tacotron-2-styled postnet after the FastSpeech decoder, which is not used in the original paper.
- The [transformer paper](https://arxiv.org/abs/1706.03762) suggests using dropout after the input and positional embeddings. I haven't tried it yet.
- The paper suggests using L1 loss for the mel loss and L2 loss for the variance predictor losses, but I find it easier to train the model with an L2 mel loss and L1 variance adaptor losses, for reasons unknown.
- The paper suggests predicting the duration in the logarithmic domain, while in my implementation the duration prediction and loss are computed in the linear domain (see the sketch after this list).
- I use gradient clipping and weight decay in the training.
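A small sketch of the duration-domain difference, using hypothetical tensors rather than this repo's actual module code:
```
# Hypothetical tensors: predicted and ground-truth per-phoneme durations.
import torch
import torch.nn.functional as F

duration_pred = torch.tensor([2.7, 4.1, 1.2])    # predictor output
duration_target = torch.tensor([3.0, 4.0, 1.0])  # frame counts from MFA

# Paper: predict durations in the logarithmic domain and compare there.
log_domain_loss = F.mse_loss(torch.log(duration_pred), torch.log(duration_target))

# This implementation: predict and penalize in the linear domain (L1 loss,
# matching the variance adaptor losses mentioned above).
linear_domain_loss = F.l1_loss(duration_pred, duration_target)
```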

## TODO
@@ -127,5 +128,8 @@ There are several differences between my implementation and the paper.
# References
- [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558), Y. Ren, *et al*.
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263), Y. Ren, *et al*.
- [xcmyz's FastSpeech implementation](https://github.com/xcmyz/FastSpeech)
- [NVIDIA's WaveGlow implementation](https://github.com/NVIDIA/waveglow)
- [xcmyz's PyTorch FastSpeech implementation](https://github.com/xcmyz/FastSpeech)
- [rishikksh20's PyTorch FastSpeech2 implementation](https://github.com/rishikksh20/FastSpeech2)
- [TensorSpeech's TensorFlow FastSpeech2 implementation](https://github.com/TensorSpeech/TensorflowTTS)
- [NVIDIA's PyTorch WaveGlow implementation](https://github.com/NVIDIA/waveglow)
- [seungwonpark's PyTorch MelGAN implementation](https://github.com/seungwonpark/melgan)
22 changes: 15 additions & 7 deletions evaluate.py
@@ -15,7 +15,6 @@
import hparams as hp
import utils
import audio as Audio
import waveglow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -27,7 +26,7 @@ def get_FastSpeech2(num):
model.eval()
return model

def evaluate(model, step, wave_glow=None):
def evaluate(model, step, vocoder=None):
torch.manual_seed(0)

# Get dataset
@@ -74,18 +73,23 @@ def evaluate(model, step, wave_glow=None):
mel_l.append(mel_loss.item())
mel_p_l.append(mel_postnet_loss.item())

if wave_glow is not None:
if vocoder is not None:
# Run vocoding and plotting spectrogram only when the vocoder is defined
for k in range(len(mel_target)):
length = mel_len[k]

mel_target_torch = mel_target[k:k+1, :length].transpose(1, 2).detach()
mel_target_ = mel_target[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_target_torch, wave_glow, os.path.join(hp.eval_path, 'ground-truth_{}_waveglow.wav'.format(idx)))

mel_postnet_torch = mel_postnet_output[k:k+1, :length].transpose(1, 2).detach()
mel_postnet = mel_postnet_output[k, :length].cpu().transpose(0, 1).detach()
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(hp.eval_path, 'eval_{}_waveglow.wav'.format(idx)))

if hp.vocoder == 'melgan':
utils.melgan_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(idx, hp.vocoder)))
utils.melgan_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(idx, hp.vocoder)))
elif hp.vocoder == 'waveglow':
utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(hp.eval_path, 'ground-truth_{}_{}.wav'.format(idx, hp.vocoder)))
utils.waveglow_infer(mel_postnet_torch, vocoder, os.path.join(hp.eval_path, 'eval_{}_{}.wav'.format(idx, hp.vocoder)))

f0_ = f0[k, :length].detach().cpu().numpy()
energy_ = energy[k, :length].detach().cpu().numpy()
@@ -142,12 +146,16 @@ def evaluate(model, step, wave_glow=None):
print('Number of FastSpeech2 Parameters:', num_param)

# Load vocoder
wave_glow = utils.get_WaveGlow()
if hp.vocoder == 'melgan':
vocoder = utils.get_melgan()
elif hp.vocoder == 'waveglow':
vocoder = utils.get_waveglow()
vocoder.to(device)

# Init directories
if not os.path.exists(hp.log_path):
os.makedirs(hp.log_path)
if not os.path.exists(hp.eval_path):
os.makedirs(hp.eval_path)

evaluate(model, args.step, wave_glow)
evaluate(model, args.step, vocoder)
4 changes: 4 additions & 0 deletions hparams.py
@@ -84,6 +84,10 @@
weight_decay = 1e-6


# Vocoder
vocoder = 'melgan' # 'waveglow' or 'melgan'


# Save, log and synthesis
save_step = 10000
synth_step = 1000
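The new flag is consumed downstream roughly as in this sketch, which mirrors the branches added to evaluate.py and train.py (the final guard clause is an addition of this sketch, not in the repo):
```
import hparams as hp
import utils

# Load the vocoder selected by the new hparams flag.
if hp.vocoder == 'melgan':
    vocoder = utils.get_melgan()
elif hp.vocoder == 'waveglow':
    vocoder = utils.get_waveglow()
else:
    raise ValueError("hp.vocoder must be 'melgan' or 'waveglow'")
```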
Binary file added results/step_300000.png
31 changes: 18 additions & 13 deletions synthesize.py
@@ -11,7 +11,6 @@
import hparams as hp
import utils
import audio as Audio
import waveglow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -57,10 +56,16 @@ def synthesize(model, text, sentence, prefix=''):
os.makedirs(hp.test_path)

Audio.tools.inv_mel_spec(mel_postnet, os.path.join(hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
wave_glow = utils.get_WaveGlow()
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(
hp.test_path, '{}_waveglow_{}.wav'.format(prefix, sentence)))

if hp.vocoder == 'melgan':
melgan = utils.get_melgan()
melgan.to(device)
utils.melgan_infer(mel_postnet_torch, melgan, os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
elif hp.vocoder == 'waveglow':
waveglow = utils.get_waveglow()
waveglow.to(device)
utils.waveglow_infer(mel_postnet_torch, waveglow, os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))


@@ -72,15 +77,15 @@ def synthesize(model, text, sentence, prefix=''):

sentence = "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition"
sentence = "in being comparatively modern."
sentence = "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process"
sentence = "produced the block books, which were the immediate predecessors of the true printed book,"
sentence = "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing."
sentence = "And it is worth mention in passing that, as an example of fine typography,"
sentence = "the earliest book printed with movable types, the Gutenberg, or \"forty-two line Bible\" of about 1455,"
sentence = "has never been surpassed."
sentence = "Printing, then, for our purpose, may be considered as the art of making books by means of movable types."
sentence = "Now, as all books not primarily intended as picture-books consist principally of types composed to form letterpress,"

#sentence = "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process"
#sentence = "produced the block books, which were the immediate predecessors of the true printed book,"
#sentence = "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing."
#sentence = "And it is worth mention in passing that, as an example of fine typography,"
#sentence = "the earliest book printed with movable types, the Gutenberg, or \"forty-two line Bible\" of about 1455,"
#sentence = "has never been surpassed."
#sentence = "Printing, then, for our purpose, may be considered as the art of making books by means of movable types."
#sentence = "Now, as all books not primarily intended as picture-books consist principally of types composed to form letterpress,"
sentence = "The nation's tourism minister has also encouraged Australian's to take their holidays within the country this year."
text = preprocess(sentence)
model = get_FastSpeech2(args.step)
synthesize(model, text, sentence, prefix='step_{}'.format(args.step))
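With the vocoder switch in place, a programmatic call would look roughly like the sketch below, assuming ``preprocess``, ``get_FastSpeech2`` and ``synthesize`` are importable from synthesize.py and a matching checkpoint exists under ``ckpt/LJSpeech/``:
```
# Sketch: drive synthesis from another script. Assumes the helpers above are
# importable from synthesize.py and that a checkpoint for the given step
# exists under ckpt/LJSpeech/.
from synthesize import preprocess, get_FastSpeech2, synthesize

sentence = "The quick brown fox jumps over the lazy dog."
text = preprocess(sentence)
model = get_FastSpeech2(300000)
synthesize(model, text, sentence, prefix='step_300000')
# Wavs are written to hp.test_path as '<prefix>_<vocoder>_<sentence>.wav'.
```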
20 changes: 15 additions & 5 deletions train.py
@@ -16,7 +16,6 @@
import hparams as hp
import utils
import audio as Audio
import waveglow

def main(args):
torch.manual_seed(0)
@@ -55,7 +54,12 @@ def main(args):
os.makedirs(checkpoint_path)

# Load vocoder
wave_glow = utils.get_WaveGlow()
if hp.vocoder == 'melgan':
melgan = utils.get_melgan()
melgan.to(device)
elif hp.vocoder == 'waveglow':
waveglow = utils.get_waveglow()
waveglow.to(device)

# Init logger
log_path = hp.log_path
@@ -180,9 +184,15 @@ def main(args):
mel_postnet = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)
Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))
waveglow.inference.inference(mel_torch, wave_glow, os.path.join(synth_path, "step_{}_waveglow.wav".format(current_step)))
waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(synth_path, "step_{}_postnet_waveglow.wav".format(current_step)))
waveglow.inference.inference(mel_target_torch, wave_glow, os.path.join(synth_path, "step_{}_ground-truth_waveglow.wav".format(current_step)))

if hp.vocoder == 'melgan':
utils.melgan_infer(mel_torch, melgan, os.path.join(synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
utils.melgan_infer(mel_postnet_torch, melgan, os.path.join(synth_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
utils.melgan_infer(mel_target_torch, melgan, os.path.join(synth_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))
elif hp.vocoder == 'waveglow':
utils.waveglow_infer(mel_torch, waveglow, os.path.join(synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
utils.waveglow_infer(mel_postnet_torch, waveglow, os.path.join(synth_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
utils.waveglow_infer(mel_target_torch, waveglow, os.path.join(synth_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))

f0 = f0[0, :length].detach().cpu().numpy()
energy = energy[0, :length].detach().cpu().numpy()
32 changes: 25 additions & 7 deletions utils.py
@@ -6,6 +6,7 @@
import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot as plt
from scipy.io import wavfile
import os

import text
@@ -101,16 +102,33 @@ def get_mask_from_lengths(lengths, max_len=None):

return mask

def get_WaveGlow():
waveglow_path = hp.waveglow_path
wave_glow = torch.load(waveglow_path)['model']
wave_glow = wave_glow.remove_weightnorm(wave_glow)
wave_glow.cuda().eval()
for m in wave_glow.modules():
def get_waveglow():
waveglow = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_waveglow')
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.eval()
for m in waveglow.modules():
if 'Conv' in str(type(m)):
setattr(m, 'padding_mode', 'zeros')

return wave_glow
return waveglow

def waveglow_infer(mel, waveglow, path):
with torch.no_grad():
wav = waveglow.infer(mel, sigma=1.0) * hp.max_wav_value
wav = wav.squeeze().cpu().numpy()
wav = wav.astype('int16')
wavfile.write(path, hp.sampling_rate, wav)

def melgan_infer(mel, melgan, path):
with torch.no_grad():
wav = melgan.inference(mel).cpu().numpy()
wav = wav.astype('int16')
wavfile.write(path, hp.sampling_rate, wav)

def get_melgan():
melgan = torch.hub.load('seungwonpark/melgan', 'melgan')
melgan.eval()
return melgan

def pad_1D(inputs, PAD=0):

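Taken together, the new helpers can be exercised as in this sketch; the (1, n_mel_channels, n_frames) input layout is inferred from the transposes in evaluate.py and synthesize.py rather than documented, and a random mel will of course vocode to noise:
```
# Sketch: exercise the new vocoder helpers in utils.py. The (1, 80, frames)
# layout is an inference from how evaluate.py/synthesize.py transpose mels.
import torch
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

melgan = utils.get_melgan()  # loads seungwonpark/melgan via torch.hub
melgan.to(device)

mel = torch.randn(1, 80, 200).to(device)  # dummy mel; real input is a log-mel
utils.melgan_infer(mel, melgan, 'demo_melgan.wav')  # int16 wav at hp.sampling_rate
```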
2 changes: 0 additions & 2 deletions waveglow/__init__.py

This file was deleted.

46 changes: 0 additions & 46 deletions waveglow/convert_model.py

This file was deleted.
