Skip to content

Commit

Permalink
Implement cloning in API
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jan 30, 2023
1 parent 335b8ed commit 7fddabc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
33 changes: 26 additions & 7 deletions TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def __init__(
>>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
Example voice cloning with YourTTS in English, French and Portuguese:
>>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
>>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
Args:
model_name (str, optional): Model name to load. You can list models by `tts.models`. Defaults to None.
model_path (str, optional): Path to the model checkpoint. Defaults to None.
Expand Down Expand Up @@ -144,8 +150,8 @@ def load_model_by_path(
use_cuda=gpu,
)

def _check_arguments(self, speaker: str = None, language: str = None):
if self.is_multi_speaker and speaker is None:
def _check_arguments(self, speaker: str = None, language: str = None, speaker_wav: str = None):
if self.is_multi_speaker and (speaker is None and speaker_wav is None):
raise ValueError("Model is multi-speaker but no speaker is provided.")
if self.is_multi_lingual and language is None:
raise ValueError("Model is multi-lingual but no language is provided.")
Expand All @@ -154,7 +160,7 @@ def _check_arguments(self, speaker: str = None, language: str = None):
if not self.is_multi_lingual and language is not None:
raise ValueError("Model is not multi-lingual but language is provided.")

def tts(self, text: str, speaker: str = None, language: str = None):
def tts(self, text: str, speaker: str = None, language: str = None, speaker_wav: str = None):
"""Convert text to speech.
Args:
Expand All @@ -166,22 +172,32 @@ def tts(self, text: str, speaker: str = None, language: str = None):
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file used for voice cloning with models that support it, such as YourTTS.
Defaults to None.
"""
self._check_arguments(speaker=speaker, language=language)
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav)

wav = self.synthesizer.tts(
text=text,
speaker_name=speaker,
language_name=language,
speaker_wav=None,
speaker_wav=speaker_wav,
reference_wav=None,
style_wav=None,
style_text=None,
reference_speaker_name=None,
)
return wav

def tts_to_file(self, text: str, speaker: str = None, language: str = None, file_path: str = "output.wav"):
def tts_to_file(
self,
text: str,
speaker: str = None,
language: str = None,
speaker_wav: str = None,
file_path: str = "output.wav",
):
"""Convert text to speech.
Args:
Expand All @@ -193,8 +209,11 @@ def tts_to_file(self, text: str, speaker: str = None, language: str = None, file
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file used for voice cloning with models that support it, such as YourTTS.
Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
"""
wav = self.tts(text=text, speaker=speaker, language=language)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
self.synthesizer.save_wav(wav=wav, path=file_path)
2 changes: 1 addition & 1 deletion TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def tts(
text (str): input text.
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
Expand Down
9 changes: 8 additions & 1 deletion tests/inference_tests/test_python_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import unittest

from tests import get_tests_output_path
from tests import get_tests_data_path, get_tests_output_path

from TTS.api import TTS

OUTPUT_PATH = os.path.join(get_tests_output_path(), "test_python_api.wav")
cloning_test_wav_path = os.path.join(get_tests_data_path(), "ljspeech/wavs/LJ001-0028.wav")


class TTSTest(unittest.TestCase):
Expand Down Expand Up @@ -34,3 +36,8 @@ def test_multi_speaker_multi_lingual_model(self):
self.assertTrue(tts.is_multi_lingual)
self.assertGreater(len(tts.speakers), 1)
self.assertGreater(len(tts.languages), 1)

@staticmethod
def test_voice_cloning():
    """Smoke-test voice cloning through the Python API.

    Loads the multilingual YourTTS model by name and synthesizes an English
    sentence to ``OUTPUT_PATH``, passing ``cloning_test_wav_path`` as the
    ``speaker_wav`` reference instead of a named speaker.
    """
    # Declared static because the function takes no ``self``: when a test
    # runner invokes it as a bound ``TestCase`` method, a plain zero-argument
    # function would fail with ``TypeError`` before any synthesis happens.
    # NOTE(review): indentation is not visible in this view; this assumes the
    # test is a member of ``TTSTest`` like the tests preceding it - confirm.
    tts = TTS()
    tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
    tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)

0 comments on commit 7fddabc

Please sign in to comment.