Skip to content

Commit

Permalink
Fix train_tts.py and uncomment code (coqui-ai#1051)
Browse files Browse the repository at this point in the history
* Fix SE loading and language embedding logic

* Remove trailing whitespace

* Uncomment resampling code for SCL
  • Loading branch information
WeberJulian authored Jan 3, 2022
1 parent 58c38de commit e1accb6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
16 changes: 12 additions & 4 deletions TTS/bin/train_tts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch

from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
from TTS.trainer import Trainer, TrainingArgs
Expand Down Expand Up @@ -53,15 +54,22 @@ def main():
else:
config.num_speakers = speaker_manager.num_speakers
elif check_config_and_model_args(config, "use_d_vector_file", True):
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
speaker_manager = SpeakerManager(
d_vectors_file_path=config.model_args.d_vector_file,
encoder_model_path=config.model_args.speaker_encoder_model_path,
encoder_config_path=config.model_args.speaker_encoder_config_path,
use_cuda=torch.cuda.is_available(),
)
else:
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
config.num_speakers = speaker_manager.num_speakers
if hasattr(config, "model_args"):
config.model_args.num_speakers = speaker_manager.num_speakers
else:
config.num_speakers = speaker_manager.num_speakers
else:
speaker_manager = None

if hasattr(config, "use_language_embedding") and config.use_language_embedding:
if check_config_and_model_args(config, "use_language_embedding", True):
language_manager = LanguageManager(config=config)
if hasattr(config, "model_args"):
config.model_args.num_languages = language_manager.num_languages
Expand Down
22 changes: 7 additions & 15 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

# import torchaudio
import torchaudio
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
Expand Down Expand Up @@ -419,21 +419,12 @@ def init_multispeaker(self, config: Coqpit):
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
):
# TODO: change this with torchaudio Resample
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
self.config.audio["sample_rate"],
self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
)
# pylint: disable=W0101,W0105
""" self.audio_transform = torchaudio.transforms.Resample(
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
else:
self.audio_transform = None
"""
else:
self.audio_transform = None

def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
Expand All @@ -458,6 +449,7 @@ def init_multilingual(self, config: Coqpit):
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

if self.args.use_language_embedding and self.language_manager:
print(" > initialization of language-embedding layers.")
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
Expand Down Expand Up @@ -643,8 +635,8 @@ def forward(

# resample audio to speaker encoder sample_rate
# pylint: disable=W0105
"""if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)"""
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)

pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

Expand Down

0 comments on commit e1accb6

Please sign in to comment.