Skip to content

Commit

Permalink
Add config file for pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
babysor committed Feb 23, 2022
1 parent 4529479 commit 0536874
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
13 changes: 13 additions & 0 deletions synthesizer/hparams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import pprint
import json

class HParams(object):
def __init__(self, **kwargs): self.__dict__.update(kwargs)
Expand All @@ -18,6 +19,18 @@ def parse(self, string):
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
return self

def loadJson(self, dict):
print("\Loading the json with %s\n", dict)
for k in dict.keys():
self.__dict__[k] = dict[k]
return self

def dumpJson(self, fp):
print("\Saving the json with %s\n", fp)
with fp.open("w", encoding="utf-8") as f:
json.dump(self.__dict__, f)
return self

hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000,
Expand Down
6 changes: 6 additions & 0 deletions synthesizer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import librosa
from utils import logmmse
import json
from pypinyin import lazy_pinyin, Style

class Synthesizer:
Expand Down Expand Up @@ -44,6 +45,11 @@ def is_loaded(self):
return self._model is not None

def load(self):
# Try to scan config file
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
"""
Instantiates and loads the model given the weights file that was passed in the constructor.
"""
Expand Down
8 changes: 8 additions & 0 deletions synthesizer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
from datetime import datetime
import json
import numpy as np
from pathlib import Path
import sys
Expand Down Expand Up @@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
if num_chars != loaded_shape[0]:
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
num_chars != loaded_shape[0]
# Try to scan config file
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
else: # save a config
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))


model = Tacotron(embed_dims=hparams.tts_embed_dims,
Expand Down

0 comments on commit 0536874

Please sign in to comment.