Skip to content

Commit

Permalink
linter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Mar 8, 2021
1 parent 55fc50b commit ee71eb4
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 14 deletions.
5 changes: 2 additions & 3 deletions TTS/bin/find_unique_chars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@


def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''

'''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\

'''
Example runs:
python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
''',
formatter_class=RawTextHelpFormatter)
''', formatter_class=RawTextHelpFormatter)

parser.add_argument(
'--dataset',
Expand Down
6 changes: 2 additions & 4 deletions TTS/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
import glob
import os
import re
import json

from TTS.tts.utils.generic_utils import check_config_tts
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import (copy_model_files, load_config,
save_characters_to_config)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.tts.utils.text.symbols import parse_symbols


def parse_arguments(argv):
Expand Down
3 changes: 2 additions & 1 deletion TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def _download_gdrive_file(self, gdrive_idx, output):
"""Download files from GDrive using their file ids"""
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)

def _download_zip_file(self, file_url, output):
@staticmethod
def _download_zip_file(file_url, output):
"""Download the target zip file and extract the files
to a folder with the same name as the zip file."""
r = requests.get(file_url)
Expand Down
11 changes: 6 additions & 5 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak-ng
import torch

from TTS.utils.synthesizer import Synthesizer
from TTS.utils.manage import ModelManager


def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', use_cuda=False):
def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name=None, use_cuda=False):
"""TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text.
Example:
Expand All @@ -15,16 +15,17 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
Args:
model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'.
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'.
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/multiband-melgan'.
pretrained (bool, optional): [description]. Defaults to True.
Returns:
TTS.utils.synthesizer.Synthesizer: Synthesizer object wrapping both vocoder and tts models.
"""
manager = ModelManager()

model_path, config_path = manager.download_model(model_name)
vocoder_path, vocoder_config_path = manager.download_model(vocoder_name)
model_path, config_path, model_item = manager.download_model(model_name)
vocoder_name = model_item['default_vocoder'] if vocoder_name is None else vocoder_name
vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

# create synthesizer
synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, use_cuda)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_demo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _create_random_model(self):
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
model = setup_model(num_chars, 0, config)
output_path = os.path.join(get_tests_output_path())
save_checkpoint(model, None, 10, 10, 1, output_path)
save_checkpoint(model, None, 10, 10, 1, output_path, None)

def test_in_out(self):
self._create_random_model()
Expand Down

0 comments on commit ee71eb4

Please sign in to comment.