Skip to content

Commit

Permalink
[FastPitch/PyT] Fix handling of heteronyms when training with a lexicon
Browse files Browse the repository at this point in the history
  • Loading branch information
alancucki authored and nv-kkudrynski committed Feb 17, 2022
1 parent bd4cd21 commit 37a5e77
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
15 changes: 8 additions & 7 deletions PyTorch/SpeechSynthesis/FastPitch/common/text/cmudict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True)
if file_or_path is None:
self._entries = {}
else:
self.initialize(file_or_path, keep_ambiguous)
self.initialize(file_or_path, heteronyms_path, keep_ambiguous)

if heteronyms_path is None:
self.heteronyms = []
else:
self.heteronyms = set(lines_to_list(heteronyms_path))

def initialize(self, file_or_path, keep_ambiguous=True):
def initialize(self, file_or_path, heteronyms_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
try:
with open(file_or_path, encoding='latin-1') as f:
Expand All @@ -55,6 +50,12 @@ def initialize(self, file_or_path, keep_ambiguous=True):
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
self._entries = entries

if heteronyms_path is None:
self.heteronyms = []
else:
self.heteronyms = set(lines_to_list(heteronyms_path))


def __len__(self):
if len(self._entries) == 0:
raise ValueError("CMUDict not initialized")
Expand Down
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def main():
args, unk_args = parser.parse_known_args()

if args.p_arpabet > 0.0:
cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
cmudict.initialize(args.cmudict_path, args.heteronyms_path)

torch.backends.cudnn.benchmark = args.cudnn_benchmark

Expand Down
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def main():
args, _ = parser.parse_known_args()

if args.p_arpabet > 0.0:
cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
cmudict.initialize(args.cmudict_path, args.heteronyms_path)

distributed_run = args.world_size > 1

Expand Down
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/triton/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def main():

if args.p_arpabet > 0.0:
from common.text import cmudict
cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
cmudict.initialize(args.cmudict_path, args.heteronyms_path)

get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)
Expand Down
2 changes: 1 addition & 1 deletion PyTorch/SpeechSynthesis/FastPitch/triton/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_dataloader_fn(batch_size: int = 8,
mel_fmax: float = 8000.0):

if p_arpabet > 0.0:
cmudict.initialize(cmudict_path, keep_ambiguous=True)
cmudict.initialize(cmudict_path, heteronyms_path)

dataset = TTSDataset(dataset_path=dataset_path,
audiopaths_and_text=filelist,
Expand Down

0 comments on commit 37a5e77

Please sign in to comment.