Skip to content

Commit

Permalink
[TTS] Add volume passthrough to fp for riva (NVIDIA#4167)
Browse files Browse the repository at this point in the history
* add volume passthrough to fp for riva

Signed-off-by: Jason <[email protected]>

* fix

Signed-off-by: Jason <[email protected]>

* omg a newline

Signed-off-by: Jason <[email protected]>
  • Loading branch information
blisc authored May 20, 2022
1 parent cfcb5f6 commit eba03a1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
9 changes: 6 additions & 3 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def _prepare_for_export(self, **kwargs):
"text": NeuralType(('B', 'T_text'), TokenIndex()),
"pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
"pace": NeuralType(('B', 'T_text'), optional=True),
"volume": NeuralType(('B', 'T_text')),
"speaker": NeuralType(('B'), Index()),
}
self._output_types = {
Expand All @@ -528,6 +529,7 @@ def _prepare_for_export(self, **kwargs):
"durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
"log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
"pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
"volume_aligned": NeuralType(('B', 'T_spec'), RegressionValuesType()),
}

def _export_teardown(self):
Expand Down Expand Up @@ -562,8 +564,9 @@ def input_example(self, max_batch=1, max_dim=256):
)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp((torch.randn(sz, device=par.device, dtype=torch.float32) + 1) * 0.1, min=0.01)
volume = torch.clamp((torch.randn(sz, device=par.device, dtype=torch.float32) + 1) * 0.1, min=0.01)

inputs = {'text': inp, 'pitch': pitch, 'pace': pace}
inputs = {'text': inp, 'pitch': pitch, 'pace': pace, 'volume': volume}

if self.fastpitch.speaker_emb is not None:
inputs['speaker'] = torch.randint(
Expand All @@ -572,5 +575,5 @@ def input_example(self, max_batch=1, max_dim=256):

return (inputs,)

def forward_for_export(self, text, pitch, pace, speaker=None):
return self.fastpitch.infer(text=text, pitch=pitch, pace=pace, speaker=speaker)
def forward_for_export(self, text, pitch, pace, volume, speaker=None):
return self.fastpitch.infer(text=text, pitch=pitch, pace=pace, volume=volume, speaker=speaker)
15 changes: 13 additions & 2 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def forward(
pitch,
)

def infer(self, *, text, pitch=None, speaker=None, pace=1.0):
def infer(self, *, text, pitch=None, speaker=None, pace=1.0, volume=None):
# Calculate speaker embedding
if self.speaker_emb is None or speaker is None:
spk_emb = 0
Expand All @@ -298,8 +298,19 @@ def infer(self, *, text, pitch=None, speaker=None, pace=1.0):

# Expand to decoder time dimension
len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)
volume_extended = None
if volume is not None:
volume_extended, _ = regulate_len(durs_predicted, volume.unsqueeze(-1), pace)
volume_extended = volume_extended.squeeze(-1).float()

# Output FFT
dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
spect = self.proj(dec_out).transpose(1, 2)
return spect.to(torch.float), dec_lens, durs_predicted, log_durs_predicted, pitch_predicted
return (
spect.to(torch.float),
dec_lens,
durs_predicted,
log_durs_predicted,
pitch_predicted,
volume_extended,
)

0 comments on commit eba03a1

Please sign in to comment.