Skip to content

Commit

Permalink
Fix glow_tts imports
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Sep 10, 2021
1 parent 570d597 commit a89eb12
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.utils.helpers import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
Expand Down Expand Up @@ -133,7 +132,7 @@ def compute_outputs(attn, o_mean, o_log_scale, x_mask):
return y_mean, y_log_scale, o_attn_dur

def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
Expand Down Expand Up @@ -185,7 +184,7 @@ def forward(

@torch.no_grad()
def inference_with_MAS(
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
It's similar to the teacher forcing in Tacotron.
Expand Down Expand Up @@ -246,7 +245,7 @@ def inference_with_MAS(

@torch.no_grad()
def decoder_inference(
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
Expand Down Expand Up @@ -278,7 +277,9 @@ def decoder_inference(
return outputs

@torch.no_grad()
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None

Expand Down Expand Up @@ -331,7 +332,13 @@ def train_step(self, batch: dict, criterion: nn.Module):
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]

outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)

loss_dict = criterion(
outputs["model_outputs"],
Expand Down

0 comments on commit a89eb12

Please sign in to comment.