Skip to content

Commit

Permalink
Implement bucketed weighted sampling for VITS (coqui-ai#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol authored Aug 15, 2022
1 parent d46fbc2 commit bfc6382
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 28 deletions.
2 changes: 1 addition & 1 deletion TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@

from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update

torch.backends.cudnn.enabled = True
Expand Down
17 changes: 17 additions & 0 deletions TTS/tts/configs/vits_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
weighted_sampler_attrs (dict):
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.
weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
Expand Down Expand Up @@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
return_wav: bool = True
compute_linear_spec: bool = True

# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
Expand Down
64 changes: 43 additions & 21 deletions TTS/tts/datasets/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
"audio_file": audio_path,
"speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
"root_path": root_path,
}
)
if not_found_counter > 0:
Expand All @@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
wav_file = os.path.join(root_path, folder_name, wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand Down Expand Up @@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
)
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
Expand All @@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
)
return items


Expand All @@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand Down Expand Up @@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
)
return items


Expand All @@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
items.append(
{
"text": text,
"audio_file": wav_file,
"speaker_name": f"LTTS_{speaker_name}",
"root_path": root_path,
}
)
for item in items:
assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
return items
Expand All @@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
skipped_files.append(wav_file)
continue
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
return items

Expand All @@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
return items


Expand Down Expand Up @@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
else:
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
if os.path.exists(wav_file):
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
)
else:
print(f" [!] wav files don't exist - {wav_file}")
return items
Expand All @@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
)
return items


Expand All @@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar
if os.path.exists(txt_file) and os.path.exists(wav_file):
with open(txt_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
if ignore_digits_sentences and any(map(str.isdigit, text)):
continue
wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
return items


Expand All @@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
)
return items


Expand Down Expand Up @@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
items.append(
{"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
)
return items


Expand All @@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
return items


Expand All @@ -554,5 +576,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2].replace(" ", "")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
72 changes: 67 additions & 5 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import chain
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torchaudio
Expand All @@ -13,6 +14,8 @@
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
Expand All @@ -29,6 +32,8 @@
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results

Expand Down Expand Up @@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
##############################


def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
# check if all keys are in the multi_dict
for k in multi_dict:
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
# scale weights
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)


class VitsDataset(TTSDataset):
def __init__(self, model_args, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -1510,6 +1539,42 @@ def format_batch_on_device(self, batch):
batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
return batch

def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
weights = None
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")

# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]

if weights is not None:
w_sampler = WeightedRandomSampler(weights, len(weights))
batch_sampler = BucketBatchSampler(
w_sampler,
data=data_items,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
sort_key=lambda x: os.path.getsize(x["audio_file"]),
drop_last=True,
)
else:
batch_sampler = None
# sampler for DDP
if batch_sampler is None:
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
batch_sampler = (
DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
) # TODO: check batch_sampler with multi-gpu
return batch_sampler

def get_data_loader(
self,
config: Coqpit,
Expand Down Expand Up @@ -1551,10 +1616,7 @@ def get_data_loader(

loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False, # shuffle is done in the dataset.
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
Expand Down Expand Up @@ -1615,7 +1677,7 @@ def load_checkpoint(
strict=True,
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# compat band-aid for the pre-trained models to not use the encoder baked into the model
# TODO: consider baking the speaker encoder into the model and call it from there.
# as it is probably easier for model distribution.
Expand Down
Loading

0 comments on commit bfc6382

Please sign in to comment.