Skip to content

Commit

Permalink
REBASED: Add support for the speaker encoder training using torch spe…
Browse files Browse the repository at this point in the history
…ctrograms (coqui-ai#1348)

* Add support for the speaker encoder training using torch spectrograms

* Remove useless function in speaker encoder dataset class
  • Loading branch information
Edresson authored Mar 10, 2022
1 parent 07d96f7 commit f381e29
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 23 deletions.
1 change: 1 addition & 0 deletions TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
sample_from_storage_p=c.storage["sample_from_storage_p"],
verbose=verbose,
augmentation_config=c.audio_augmentation,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)

# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
Expand Down
26 changes: 8 additions & 18 deletions TTS/speaker_encoder/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
skip_speakers=False,
verbose=False,
augmentation_config=None,
use_torch_spec=None,
):
"""
Args:
Expand All @@ -37,6 +38,7 @@ def __init__(
self.skip_speakers = skip_speakers
self.ap = ap
self.verbose = verbose
self.use_torch_spec = use_torch_spec
self.__parse_items()
storage_max_size = storage_size * num_speakers_in_batch
self.storage = Storage(
Expand Down Expand Up @@ -72,22 +74,6 @@ def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
return audio

def load_data(self, idx):
text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = self.ap.melspectrogram(wav).astype("float32")
# sample seq_len

assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]

sample = {
"mel": mel,
"item_idx": self.items[idx]["audio_file"],
"speaker_name": speaker_name,
}
return sample

def __parse_items(self):
self.speaker_to_utters = {}
for i in self.items:
Expand Down Expand Up @@ -241,8 +227,12 @@ def collate_fn(self, batch):
self.gaussian_augmentation_config["max_amplitude"],
size=len(wav),
)
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))

if not self.use_torch_spec:
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))
else:
feats_.append(torch.FloatTensor(wav))

labels.append(torch.LongTensor(labels_))
feats.extend(feats_)
Expand Down
10 changes: 5 additions & 5 deletions TTS/tts/datasets/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,21 +334,21 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
return items


def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
return items


Expand Down

0 comments on commit f381e29

Please sign in to comment.