Skip to content

Commit

Permalink
Make lint
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Feb 25, 2022
1 parent 146fbfd commit 2194095
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions tests/tts_tests/test_vits.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import os
import unittest
from TTS.utils.logging.tensorboard_logger import TensorboardLogger

import torch

Expand All @@ -11,6 +10,7 @@
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
Expand Down Expand Up @@ -337,7 +337,7 @@ def _check_parameter_changes(model, model_ref):
count += 1

def _create_batch(self, config, batch_size):
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size)
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
batch = {}
batch["text_input"] = input_dummy
batch["text_lengths"] = input_lengths
Expand Down Expand Up @@ -441,22 +441,26 @@ def test_init_from_config(self):
self.assertEqual(model.num_speakers, 2)
self.assertTrue(hasattr(model, "emb_g"))

config = VitsConfig(model_args=VitsArgs(
num_chars=32,
num_speakers=2,
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
))
config = VitsConfig(
model_args=VitsArgs(
num_chars=32,
num_speakers=2,
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
self.assertEqual(model.num_speakers, 10)
self.assertTrue(hasattr(model, "emb_g"))

config = VitsConfig(model_args=VitsArgs(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
))
config = VitsConfig(
model_args=VitsArgs(
num_chars=32,
use_d_vector_file=True,
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))
Expand Down

0 comments on commit 2194095

Please sign in to comment.