Skip to content

Commit

Permalink
update acoustic feature extractor of TTS (open-mmlab#28)
Browse files Browse the repository at this point in the history
update acoustic feature extractor of TTS
  • Loading branch information
lmxue authored Dec 14, 2023
1 parent befdee8 commit f029de3
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 19 deletions.
7 changes: 4 additions & 3 deletions bins/tts/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def preprocess(cfg, args):
output_path = cfg.preprocess.processed_dir
os.makedirs(output_path, exist_ok=True)

## Split train and test sets
# Split train and test sets
for dataset in cfg.dataset:
print("Preprocess {}...".format(dataset))

if args.prepare_alignment:
## Prepare alignment with MFA
# Prepare alignment with MFA
print("Prepare alignment {}...".format(dataset))
prepare_align(
dataset, cfg.dataset_path[dataset], cfg.preprocess, output_path
Expand Down Expand Up @@ -160,7 +160,7 @@ def preprocess(cfg, args):
# Dump metadata of datasets (singers, train/test durations, etc.)
cal_metadata(cfg)

## Prepare the acoustic features
# Prepare the acoustic features
for dataset in cfg.dataset:
# Skip augmented datasets which do not need to extract acoustic features
# We will copy acoustic features from the original dataset later
Expand Down Expand Up @@ -226,6 +226,7 @@ def preprocess(cfg, args):
print("Extracting content features for {}...".format(dataset))
extract_content_features(dataset, output_path, cfg, args.num_workers)


# Prepare the phenome squences
if cfg.preprocess.extract_phone:
for dataset in cfg.dataset:
Expand Down
269 changes: 264 additions & 5 deletions models/tts/base/tts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

import json
import os

import torchaudio
import numpy as np
import torch
from utils.data_utils import *
from torch.nn.utils.rnn import pad_sequence
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from processors.acoustic_extractor import cal_normalized_mel

from models.base.base_dataset import (
BaseDataset,
Expand All @@ -25,18 +29,272 @@


class TTSDataset(BaseDataset):
def __init__(self, args, cfg, is_valid=False):
super().__init__(args, cfg, is_valid)

def __init__(self, cfg, dataset, is_valid=False):
"""
Args:
cfg: config
dataset: dataset name
is_valid: whether to use train or valid dataset
"""

assert isinstance(dataset, str)

self.cfg = cfg

processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
self.metafile_path = os.path.join(processed_data_dir, meta_file)
self.metadata = self.get_metadata()



'''
load spk2id and utt2spk from json file
spk2id: {spk1: 0, spk2: 1, ...}
utt2spk: {dataset_uid: spk1, ...}
'''
if cfg.preprocess.use_spkid:
dataset = self.metadata[0]["Dataset"]

spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
with open(spk2id_path, "r") as f:
self.spk2id = json.load(f)

utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
self.utt2spk = dict()
with open(utt2spk_path, "r") as f:
for line in f.readlines():
utt, spk = line.strip().split('\t')
self.utt2spk[utt] = spk


if cfg.preprocess.use_uv:
self.utt2uv_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)
self.utt2uv_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.uv_dir,
uid + ".npy",
)

if cfg.preprocess.use_frame_pitch:
self.utt2frame_pitch_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2frame_pitch_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.pitch_dir,
uid + ".npy",
)

if cfg.preprocess.use_frame_energy:
self.utt2frame_energy_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2frame_energy_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.energy_dir,
uid + ".npy",
)

if cfg.preprocess.use_mel:
self.utt2mel_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2mel_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.mel_dir,
uid + ".npy",
)

if cfg.preprocess.use_linear:
self.utt2linear_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2linear_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.linear_dir,
uid + ".npy",
)

if cfg.preprocess.use_audio:
self.utt2audio_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

if cfg.preprocess.extract_audio:
self.utt2audio_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.audio_dir,
uid + ".wav",
)
else:
self.utt2audio_path[utt] = utt_info["Path"]

# self.utt2audio_path[utt] = os.path.join(
# cfg.preprocess.processed_dir,
# dataset,
# cfg.preprocess.audio_dir,
# uid + ".numpy",
# )

elif cfg.preprocess.use_label:
self.utt2label_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2label_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.label_dir,
uid + ".npy",
)
elif cfg.preprocess.use_one_hot:
self.utt2one_hot_path = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

self.utt2one_hot_path[utt] = os.path.join(
cfg.preprocess.processed_dir,
dataset,
cfg.preprocess.one_hot_dir,
uid + ".npy",
)

if cfg.preprocess.use_text or cfg.preprocess.use_phone:
self.utt2seq = {}
for utt_info in self.metadata:
dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

if cfg.preprocess.use_text:
text = utt_info["Text"]
sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
elif cfg.preprocess.use_phone:
# load phoneme squence from phone file
phone_path = os.path.join(processed_data_dir,
cfg.preprocess.phone_dir,
uid+'.phone'
)
with open(phone_path, 'r') as fin:
phones = fin.readlines()
assert len(phones) == 1
phones = phones[0].strip()
phones_seq = phones.split(' ')

phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

self.utt2seq[utt] = sequence

def __getitem__(self, index):
single_feature = super().__getitem__(index)
utt_info = self.metadata[index]

dataset = utt_info["Dataset"]
uid = utt_info["Uid"]
utt = "{}_{}".format(dataset, uid)

single_feature = dict()

if self.cfg.preprocess.use_spkid:
single_feature["spk_id"] = np.array(
[self.spk2id[self.utt2spk[utt]]], dtype=np.int32
)

if self.cfg.preprocess.use_mel:
mel = np.load(self.utt2mel_path[utt])
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
if self.cfg.preprocess.use_min_max_norm_mel:
# do mel norm
mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)

if "target_len" not in single_feature.keys():
single_feature["target_len"] = mel.shape[1]
single_feature["mel"] = mel.T # [T, n_mels]

if self.cfg.preprocess.use_linear:
linear = np.load(self.utt2linear_path[utt])
if "target_len" not in single_feature.keys():
single_feature["target_len"] = linear.shape[1]
single_feature["linear"] = linear.T # [T, n_linear]

if self.cfg.preprocess.use_frame_pitch:
frame_pitch_path = self.utt2frame_pitch_path[utt]
frame_pitch = np.load(frame_pitch_path)
if "target_len" not in single_feature.keys():
single_feature["target_len"] = len(frame_pitch)
aligned_frame_pitch = align_length(
frame_pitch, single_feature["target_len"]
)
single_feature["frame_pitch"] = aligned_frame_pitch

if self.cfg.preprocess.use_uv:
frame_uv_path = self.utt2uv_path[utt]
frame_uv = np.load(frame_uv_path)
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
aligned_frame_uv = [
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
]
aligned_frame_uv = np.array(aligned_frame_uv)
single_feature["frame_uv"] = aligned_frame_uv

if self.cfg.preprocess.use_frame_energy:
frame_energy_path = self.utt2frame_energy_path[utt]
frame_energy = np.load(frame_energy_path)
if "target_len" not in single_feature.keys():
single_feature["target_len"] = len(frame_energy)
aligned_frame_energy = align_length(
frame_energy, single_feature["target_len"]
)
single_feature["frame_energy"] = aligned_frame_energy

if self.cfg.preprocess.use_audio:
audio, sr = torchaudio.load(self.utt2audio_path[utt])
audio = audio.cpu().numpy().squeeze()
single_feature["audio"] = audio
single_feature["audio_len"] = audio.shape[0]


if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
single_feature["phone_seq"] = np.array(self.utt2seq[utt])
single_feature["phone_len"] = len(self.utt2seq[utt])

return single_feature

def __len__(self):

return super().__len__()

def get_metadata(self):
return super().get_metadata()

class TTSCollator(BaseCollator):
"""Zero-pads model inputs and targets based on number of frames per step"""
Expand Down Expand Up @@ -86,6 +344,7 @@ def __len__(self):
return len(self.metadata)



class TTSTestCollator(BaseTestCollator):
"""Zero-pads model inputs and targets based on number of frames per step"""

Expand Down
Loading

0 comments on commit f029de3

Please sign in to comment.