Skip to content

Commit

Permalink
fix glow-tts inference()
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jun 28, 2021
1 parent 8258299 commit 25238e0
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 8 deletions.
5 changes: 4 additions & 1 deletion TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def decoder_inference(
return outputs

@torch.no_grad()
def inference(self, x, x_lengths, cond_input={"d_vectors": None}): # pylint: disable=dangerous-default-value
def inference(
self, x, cond_input={"x_lengths": None, "d_vectors": None}
): # pylint: disable=dangerous-default-value
x_lengths = cond_input["x_lengths"]
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
if g is not None:
if self.d_vector_dim:
Expand Down
31 changes: 29 additions & 2 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from typing import Dict

import numpy as np
import pkg_resources
import torch
from torch import nn

from .text import phoneme_to_sequence, text_to_sequence

Expand Down Expand Up @@ -65,9 +67,34 @@ def compute_style_mel(style_wav, ap, cuda=False):
return style_mel


def run_model_torch(model, inputs, speaker_id=None, style_mel=None, d_vector=None):
def run_model_torch(
model: nn.Module,
inputs: torch.Tensor,
speaker_id: int = None,
style_mel: torch.Tensor = None,
d_vector: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
Args:
model (nn.Module): The model to run inference.
inputs (torch.Tensor): Input tensor with character ids.
speaker_id (int, optional): Input speaker id for multi-speaker models. Defaults to None.
style_mel (torch.Tensor, optional): Spectrograms used for voice styling. Defaults to None.
d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None.
Returns:
Dict: model outputs.
"""
input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
outputs = model.inference(
inputs, cond_input={"speaker_ids": speaker_id, "d_vector": d_vector, "style_mel": style_mel}
inputs,
cond_input={
"x_lengths": input_lengths,
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
},
)
return outputs

Expand Down
3 changes: 2 additions & 1 deletion tests/inference_tests/test_synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import unittest

from tests import get_tests_output_path
from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.synthesizer import Synthesizer

from .. import get_tests_output_path


class SynthesizerTest(unittest.TestCase):
# pylint: disable=R0201
Expand Down
4 changes: 1 addition & 3 deletions tests/tts_tests/test_tacotron2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,7 @@ def test_train_step():
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(
device
)
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
Expand Down
1 change: 0 additions & 1 deletion tests/tts_tests/test_tacotron2_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = Tacotron2Config(
r=5,
batch_size=8,
Expand Down

0 comments on commit 25238e0

Please sign in to comment.