Skip to content

Commit

Permalink
Update to PyTorch 1.6, add inference-time duration/pitch/energy control
Browse files Browse the repository at this point in the history
  • Loading branch information
ming024 committed Dec 8, 2020
1 parent 41baf3b commit 5501a42
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 41 deletions.
65 changes: 36 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# FastSpeech 2 - Pytorch Implementation
# FastSpeech 2 - PyTorch Implementation

This is a Pytorch implementation of Microsoft's text-to-speech system [**FastSpeech 2: Fast and High-Quality End-to-End Text to Speech**](https://arxiv.org/abs/2006.04558). This project is based on [xcmyz's implementation](https://github.com/xcmyz/FastSpeech) of FastSpeech. Feel free to use/modify the code. Any suggestion for improvement is appreciated.
This is a PyTorch implementation of Microsoft's text-to-speech system [**FastSpeech 2: Fast and High-Quality End-to-End Text to Speech**](https://arxiv.org/abs/2006.04558).
This project is based on [xcmyz's implementation](https://github.com/xcmyz/FastSpeech) of FastSpeech. Feel free to use/modify the code.
Any suggestion for improvement is appreciated.

This repository contains only FastSpeech 2 but FastSpeech 2s so far. I will update it once I reproduce FastSpeech 2s, the end-to-end version of FastSpeech2, successfully.
This repository contains only FastSpeech 2 but FastSpeech 2s so far.
I will update it once I reproduce FastSpeech 2s, the end-to-end version of FastSpeech2, successfully.

![](./model.png)

# Audio Samples
Audio samples generated by this implementation can be found [here](https://ming024.github.io/FastSpeech2/).
- The model used to generate these samples is trained for 300k steps on [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset.
- The model used to generate these samples is trained for 300k steps on the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset.
- Audio samples are converted from mel-spectrogram to raw waveform via [NVIDIA's pretrained WaveGlow](https://github.com/NVIDIA/waveglow) and [seungwonpark's pretrained MelGAN](https://github.com/seungwonpark/melgan).

# Quickstart
Expand All @@ -18,12 +21,6 @@ You can install the python dependencies with
```
pip3 install -r requirements.txt
```
Noticeably, because I use a new functionality ``torch.bucketize``, which is only supported in PyTorch 1.6, you have to install the nightly build by
```
pip3 install --pre torch==1.6.0.dev20200428 -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
```

Since PyTorch 1.6 is still unstable, it is suggested that Python virtual environment should be used.

## Synthesis

Expand All @@ -33,7 +30,8 @@ Your can run
```
python3 synthesis.py --step 300000
```
to generate any utterances you wish to. The generated utterances will be put in the ``results/`` directory.
to generate any desired utterances.
The generated utterances will be put in the ``results/`` directory.

Here is a generated spectrogram of the sentence "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition"
![](./synth/LJSpeech/step_300000.png)
Expand All @@ -43,22 +41,24 @@ For CPU inference please refer to this [colab tutorial](https://colab.research.g
# Training

## Datasets
This project supports two datasets:
- [LJSpeech](https://keithito.com/LJ-Speech-Dataset/): consisting of 13100 short audio clips of a single female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
- [Blizzard2013](http://www.cstr.ed.ac.uk/projects/blizzard/2013/lessac_blizzard2013/): a female speaker reading 10 audio books. The prosody variance are greater than the LJSpeech dataset. Only the 9741 segmented utterances are used in this project.
We use the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) English dataset, which consists of 13100 short audio clips of a single female speaker reading passages from 7 non-fiction books, approximately 24 hours in total, to train the entire model end-to-end.

After downloading the dataset, extract the compressed files, you have to modify the ``hp.data_path`` and some other parameters in ``hparams.py``. Default parameters are for the LJSpeech dataset.
After downloading the dataset and extracting the compressed files, you have to modify the ``hp.data_path`` and some other parameters in ``hparams.py``.
Default parameters are for the LJSpeech dataset.

## Preprocessing

As described in the paper, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/)(MFA) is used to obtain the alignments between the utterances and the phoneme sequences. Alignments for the LJSpeech dataset is provided [here](https://drive.google.com/file/d/1ukb8o-SnqhXCxq7drI3zye3tZdrGvQDA/view?usp=sharing). You have to put the ``TextGrid.zip`` file in your ``hp.preprocessed_path/`` and extract the files before you continue.
As described in the paper, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/)(MFA) is used to obtain the alignments between the utterances and the phoneme sequences.
Alignments for the LJSpeech dataset is provided [here](https://drive.google.com/file/d/1ukb8o-SnqhXCxq7drI3zye3tZdrGvQDA/view?usp=sharing).
You have to put the ``TextGrid.zip`` file in your ``hp.preprocessed_path/`` and extract the files before you continue.

After that, run the preprocessing script by
```
python3 preprocess.py
```

Alternately, you can align the corpus by yourself. First, download the MFA package and the pretrained lexicon file. (We use LibriSpeech lexicon instead of the G2p\_en python package proposed in the paper)
Alternately, you can align the corpus by yourself.
First download the MFA package and the pretrained lexicon file. (We use LibriSpeech lexicon instead of the G2p\_en python package proposed in the paper)

```
wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.1.0-beta.2/montreal-forced-aligner_linux.tar.gz
Expand All @@ -73,18 +73,19 @@ Then prepare some necessary files required by the MFA.
python3 prepare_align.py
```

Running MFA and put the .TextGrid files in your ``hp.preprocessed_path``.
Run the MFA and put the .TextGrid files in your ``hp.preprocessed_path``.
```
# Replace $DATA_PATH and $PREPROCESSED_PATH with ./LJSpeech-1.1/wavs and ./preprocessed/LJSpeech/TextGrid, for example
./montreal-forced-aligner/bin/mfa_align $YOUR_DATA_PATH montreal-forced-aligner/pretrained_models/librispeech-lexicon.txt english $YOUR_PREPROCESSED_PATH -j 8
```

Remember to run the preprocessing script.
And remember to run the preprocessing script.
```
python3 preprocess.py
```

After preprocessing, you will get a ``stat.txt`` file in your ``hp.preprocessed_path/``, recording the maximum and minimum values of the fundamental frequency and energy values throughout the entire corpus. You have to modify the f0 and energy parameters in the ``hparams.py`` according to the content of ``stat.txt``.
After preprocessing, you will get a ``stat.txt`` file in your ``hp.preprocessed_path/``, recording the maximum and minimum values of the fundamental frequency and energy values throughout the entire corpus.
You have to modify the f0 and energy parameters in the ``hparams.py`` according to the content of ``stat.txt``.

## Training

Expand All @@ -95,15 +96,17 @@ python3 train.py

The model takes less than 10k steps (less than 1 hour on my GTX1080 GPU) of training to generate audio samples with acceptable quality, which is much more efficient than the autoregressive models such as Tacotron2.

There might be some room for improvement for this repository. For example, I just simply add up the duration loss, f0 loss, energy loss and mel loss without any weighting.
There might be some room for improvement for this repository.
For example, I just simply add up the duration loss, f0 loss, energy loss and mel loss without any weighting.

# TensorBoard

The TensorBoard loggers are stored in the ``log/hp.dataset/`` directory. Use
```
tensorboard --logdir log/hp.dataset/
```
to serve the TensorBoard on your localhost. Here is an example training the model on LJSpeech for 400k steps.
to serve the TensorBoard on your localhost.
Here is an example training the model on LJSpeech for 400k steps.

![](./tensorboard.png)

Expand All @@ -112,22 +115,26 @@ to serve the TensorBoard on your localhost. Here is an example training the mode
## Implementation Issues

There are several differences between my implementation and the paper.
- The paper includes punctuations in the transcripts. However, MFA discards puntuations by default and I haven't found a way to solve it. During inference, I replace all puntuations with the ``sp`` (short-pause) phone labels.
- The paper includes punctuations in the transcripts.
However, MFA discards punctuations by default and I haven't found a way to solve it.
During inference, I replace all punctuations with the ``sp`` (short-pause) phone labels.
- Following [xcmyz's implementation](https://github.com/xcmyz/FastSpeech), I use an additional Tacotron-2-styled postnet after the FastSpeech decoder, which is not used in the original paper.
- The [transformer paper](https://arxiv.org/abs/1706.03762) suggests to use dropout after the input and positional embedding. I find that this trick does not make any observable difference so I do not use dropout for potitional embedding.
- The paper suggest to use L1 loss for mel loss and L2 loss for variance predictor losses. But I find it easier to train the model with L2 mel loss and L1 variance adaptor losses, for unknown reason.
- I use gradient clipping in the training.
- The [transformer paper](https://arxiv.org/abs/1706.03762) suggests to use dropout after the input and positional embedding.
I find that this trick does not make any observable difference so I do not use dropout for positional embedding.
- The paper suggest to use L1 loss for mel loss and L2 loss for variance predictor losses.
But I find it easier to train the model with L2 mel loss and L1 variance adaptor losses.
- Gradient clipping is used in the training.

Some tips for training this model.
- You can set the ``hp.acc_steps`` paremeter if you wish to train with a large batchsize on a GPU with limited memory.
- In my experience, carefully masking out the padded parts in loss computation and in model forward parts can largely improve the performance.
- You can set the ``hp.acc_steps`` parameter if you wish to train with a large batchsize on a GPU with limited memory.
- In my experience, carefully masking out the padded parts in loss computation and in model forward parts largely improves the performance.

Please inform me if you find any mistake in this repo, or any useful tip to train the FastSpeech2 model.

## TODO
- Try difference weights for the loss terms.
- Evaluate the quality of the synthesized audio over the validation set.
- Multi-speaker or transfer learning experiment.
- Multi-speaker, voice cloning, or transfer learning experiment.
- Implement FastSpeech 2s.

# References
Expand Down
6 changes: 3 additions & 3 deletions fastspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ def __init__(self, use_postnet=True):
if self.use_postnet:
self.postnet = PostNet()

def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None):
def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None, d_control=1.0, p_control=1.0, e_control=1.0):
src_mask = get_mask_from_lengths(src_len, max_src_len)
mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

encoder_output = self.encoder(src_seq, src_mask)
if d_target is not None:
variance_adaptor_output, d_prediction, p_prediction, e_prediction, _, _ = self.variance_adaptor(
encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)
encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len, d_control, p_control, e_control)
else:
variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)
encoder_output, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len, d_control, p_control, e_control)

decoder_output = self.decoder(variance_adaptor_output, mel_mask)
mel_output = self.mel_linear(decoder_output)
Expand Down
10 changes: 6 additions & 4 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,33 @@ def __init__(self):
self.pitch_predictor = VariancePredictor()
self.energy_predictor = VariancePredictor()

self.pitch_bins = nn.Parameter(torch.exp(torch.linspace(np.log(hp.f0_min), np.log(hp.f0_max), hp.n_bins-1)))
self.energy_bins = nn.Parameter(torch.linspace(hp.energy_min, hp.energy_max, hp.n_bins-1))
self.pitch_bins = nn.Parameter(torch.exp(torch.linspace(np.log(hp.f0_min), np.log(hp.f0_max), hp.n_bins-1)), requires_grad=False)
self.energy_bins = nn.Parameter(torch.linspace(hp.energy_min, hp.energy_max, hp.n_bins-1), requires_grad=False)
self.pitch_embedding = nn.Embedding(hp.n_bins, hp.encoder_hidden)
self.energy_embedding = nn.Embedding(hp.n_bins, hp.encoder_hidden)

def forward(self, x, src_mask, mel_mask=None, duration_target=None, pitch_target=None, energy_target=None, max_len=None):
def forward(self, x, src_mask, mel_mask=None, duration_target=None, pitch_target=None, energy_target=None, max_len=None, d_control=1.0, p_control=1.0, e_control=1.0):

log_duration_prediction = self.duration_predictor(x, src_mask)
if duration_target is not None:
x, mel_len = self.length_regulator(x, duration_target, max_len)
else:
duration_rounded = torch.clamp(torch.round(torch.exp(log_duration_prediction)-hp.log_offset), min=0)
duration_rounded = torch.clamp((torch.round(torch.exp(log_duration_prediction)-hp.log_offset)*d_control), min=0)
x, mel_len = self.length_regulator(x, duration_rounded, max_len)
mel_mask = utils.get_mask_from_lengths(mel_len)

pitch_prediction = self.pitch_predictor(x, mel_mask)
if pitch_target is not None:
pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_target, self.pitch_bins))
else:
pitch_prediction = pitch_prediction*p_control
pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_prediction, self.pitch_bins))

energy_prediction = self.energy_predictor(x, mel_mask)
if energy_target is not None:
energy_embedding = self.energy_embedding(torch.bucketize(energy_target, self.energy_bins))
else:
energy_prediction = energy_prediction*e_control
energy_embedding = self.energy_embedding(torch.bucketize(energy_prediction, self.energy_bins))

x = x + pitch_embedding + energy_embedding
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy == 1.19.0
torch == 1.6.0
tgt == 1.4.4
scipy == 1.5.0
pyworld == 0.2.10
Expand Down
14 changes: 9 additions & 5 deletions synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def get_FastSpeech2(num):
model.eval()
return model

def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
def synthesize(model, waveglow, melgan, text, sentence, prefix='', duration_control=1.0, pitch_control=1.0, energy_control=1.0):
sentence = sentence[:200] # long filename will result in OS Error

src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(text, src_len)
mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(text, src_len, d_control=duration_control, p_control=pitch_control, e_control=energy_control)

mel_torch = mel.transpose(1, 2).detach()
mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
Expand All @@ -69,6 +69,9 @@ def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
# Test
parser = argparse.ArgumentParser()
parser.add_argument('--step', type=int, default=30000)
parser.add_argument('--duration_control', type=float, default=1.0)
parser.add_argument('--pitch_control', type=float, default=1.0)
parser.add_argument('--energy_control', type=float, default=1.0)
args = parser.parse_args()

sentences = [
Expand All @@ -92,6 +95,7 @@ def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
elif hp.vocoder == 'waveglow':
waveglow = utils.get_waveglow()

for sentence in sentences:
text = preprocess(sentence)
synthesize(model, waveglow, melgan, text, sentence, prefix='step_{}'.format(args.step))
with torch.no_grad():
for sentence in sentences:
text = preprocess(sentence)
synthesize(model, waveglow, melgan, text, sentence, 'step_{}'.format(args.step), args.duration_control, args.pitch_control, args.energy_control)

0 comments on commit 5501a42

Please sign in to comment.