Skip to content

Commit

Permalink
Merge pull request coqui-ai#800 from coqui-ai/forward_tts
Browse files Browse the repository at this point in the history
Forward TTS implementation
  • Loading branch information
erogol authored Sep 13, 2021
2 parents b0b96b4 + 69dd36e commit aed9a32
Show file tree
Hide file tree
Showing 54 changed files with 24,517 additions and 1,050 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.DEFAULT_GOAL := help
.PHONY: test system-deps dev-deps deps style lint install help
.PHONY: test system-deps dev-deps deps style lint install help docs

help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
Expand Down Expand Up @@ -45,3 +45,6 @@ deps: ## install 🐸 requirements.

install: ## install 🐸 TTS for development.
pip install -e .[all]

docs: ## build the docs
$(MAKE) -C docs clean && $(MAKE) -C docs html
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
- FastPitch: [paper](https://arxiv.org/pdf/2006.06873.pdf)
- FastSpeech: [paper](https://arxiv.org/abs/1905.09263)

### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
Expand Down
9 changes: 0 additions & 9 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@
"license": "MPL",
"contact": "[email protected]"
},
"speedy-speech-wn": {
"description": "Speedy Speech model with wavenet decoder.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
"default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
"commit": "77b6145",
"author": "Eren Gölge @erogol",
"license": "MPL",
"contact": "[email protected]"
},
"vits": {
"description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.",
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.2.0/tts_models--en--ljspeech--vits.zip",
Expand Down
26 changes: 14 additions & 12 deletions TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_fsspec

use_cuda = torch.cuda.is_available()

Expand Down Expand Up @@ -77,14 +76,14 @@ def set_filename(wav_path, out_path):

def format_data(data):
# setup input data
text_input = data['text']
text_lengths = data['text_lengths']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
item_idx = data['item_idxs']
d_vectors = data['d_vectors']
speaker_ids = data['speaker_ids']
attn_mask = data['attns']
text_input = data["text"]
text_lengths = data["text_lengths"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
item_idx = data["item_idxs"]
d_vectors = data["d_vectors"]
speaker_ids = data["speaker_ids"]
attn_mask = data["attns"]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())

Expand Down Expand Up @@ -133,7 +132,11 @@ def inference(
elif d_vectors is not None:
speaker_c = d_vectors
outputs = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
)
model_output = outputs["model_outputs"]
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
Expand Down Expand Up @@ -239,8 +242,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_model(c)

# restore model
checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.load_checkpoint(c, args.checkpoint_path, eval=True)

if use_cuda:
model.cuda()
Expand Down
9 changes: 5 additions & 4 deletions TTS/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def __init__(
if self.args.continue_path:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
scheduler.last_epoch = self.restore_step
if scheduler is not None:
scheduler.last_epoch = self.restore_step
else:
self.scheduler.last_epoch = self.restore_step

Expand Down Expand Up @@ -662,6 +663,7 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
lrs = {"current_lr": current_lr}

# log run-time stats
loss_dict.update(lrs)
loss_dict.update(
{
"step_time": round(step_time, 4),
Expand Down Expand Up @@ -1125,7 +1127,7 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]:
last_model_num = model_num
last_model = file_name

# if there is not checkpoint found above
# if there is no checkpoint found above
# find the checkpoint with the latest
# modification date.
key_file_names = [fn for fn in file_names if key in fn]
Expand All @@ -1144,7 +1146,7 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]:
last_models["checkpoint"] = last_models["best_model"]
elif "best_model" not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case
last_models["best_model"] = None
last_models["best_model"] = last_models["checkpoint"]
# finally check if last best model is more recent than checkpoint
elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
last_models["checkpoint"] = last_models["best_model"]
Expand Down Expand Up @@ -1180,7 +1182,6 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model

# init config if not already defined
if config is None:
if args.config_path:
Expand Down
47 changes: 38 additions & 9 deletions TTS/tts/configs/fast_pitch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.fast_pitch import FastPitchArgs
from TTS.tts.models.forward_tts import ForwardTTSArgs


@dataclass
class FastPitchConfig(BaseTTSConfig):
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
"""Configure `ForwardTTS` as FastPitch model.
Example:
Expand All @@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
Expand All @@ -36,22 +40,43 @@ class FastPitchConfig(BaseTTSConfig):
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-3`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
huber_loss_alpha (float):
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Expand All @@ -74,8 +99,10 @@ class FastPitchConfig(BaseTTSConfig):
"""

model: str = "fast_pitch"
base_model: str = "forward_tts"

# model specific params
model_args: FastPitchArgs = field(default_factory=FastPitchArgs)
model_args: ForwardTTSArgs = ForwardTTSArgs()

# multi-speaker settings
use_speaker_embedding: bool = False
Expand All @@ -92,11 +119,13 @@ class FastPitchConfig(BaseTTSConfig):
grad_clip: float = 5.0

# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
Expand Down
151 changes: 151 additions & 0 deletions TTS/tts/configs/fast_speech_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs


@dataclass
class FastSpeechConfig(BaseTTSConfig):
"""Configure `ForwardTTS` as FastSpeech model.
Example:
>>> from TTS.tts.configs import FastSpeechConfig
>>> config = FastSpeechConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `FastSpeechArgs` for more details. Defaults to `FastSpeechArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
for the rest. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-3`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""

model: str = "fast_speech"
base_model: str = "forward_tts"

# model specific params
model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)

# multi-speaker settings
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = False
d_vector_dim: int = 0

# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0

# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 0.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000

# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE

# dataset configs
compute_f0: bool = True
f0_cache_path: str = None

# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
Loading

0 comments on commit aed9a32

Please sign in to comment.