Skip to content

Commit

Permalink
Add max steps control in toolbox
Browse files (browse the repository at this point in the history)
  • Loading branch information
babysor committed Nov 6, 2021
1 parent c396792 commit 80aaf32
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
4 changes: 2 additions & 2 deletions synthesizer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def load(self):

def synthesize_spectrograms(self, texts: List[str],
embeddings: Union[np.ndarray, List[np.ndarray]],
- return_alignments=False, style_idx=0, min_stop_token=5):
+ return_alignments=False, style_idx=0, min_stop_token=5, steps=2000):
"""
Synthesizes mel spectrograms from texts and speaker embeddings.
Expand Down Expand Up @@ -125,7 +125,7 @@ def synthesize_spectrograms(self, texts: List[str],
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)

# Inference
- _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token)
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps)
mels = mels.detach().cpu().numpy()
for m in mels:
# Trim silence from end of each spectrogram
Expand Down
2 changes: 1 addition & 1 deletion synthesizer/models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def forward(self, texts, mels, speaker_embedding):

return mel_outputs, linear, attn_scores, stop_outputs

- def generate(self, x, speaker_embedding=None, steps=200, style_idx=0, min_stop_token=5):
+ def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
self.eval()
device = next(self.parameters()).device # use same device as parameters

Expand Down
2 changes: 1 addition & 1 deletion toolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def synthesize(self):
embed = self.ui.selected_utterance.embed
embeds = [embed] * len(texts)
min_token = int(self.ui.token_slider.value())
- specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token)
+ specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
breaks = [spec.shape[1] for spec in specs]
spec = np.concatenate(specs, axis=1)

Expand Down
13 changes: 13 additions & 0 deletions toolbox/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,19 @@ def __init__(self):
layout_seed.addWidget(self.token_value_label, 2, 1)
layout_seed.addWidget(self.token_slider, 2, 3)

self.length_slider = QSlider(Qt.Horizontal)
self.length_slider.setTickInterval(1)
self.length_slider.setFocusPolicy(Qt.NoFocus)
self.length_slider.setSingleStep(1)
self.length_slider.setRange(1, 10)
self.length_value_label = QLabel("2")
self.length_slider.setValue(2)
layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0)

self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
layout_seed.addWidget(self.length_value_label, 3, 1)
layout_seed.addWidget(self.length_slider, 3, 3)

gen_layout.addLayout(layout_seed)

self.loading_bar = QProgressBar()
Expand Down

0 comments on commit 80aaf32

Please sign in to comment.