forked from p0p4k/vits2_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path infer_onnx.py
58 lines (45 loc) · 1.6 KB
/
infer_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import onnxruntime
import numpy as np
import argparse
import commons
import utils
from text import text_to_sequence
from scipy.io.wavfile import write
def get_text(text, hps):
    """Turn ``text`` into a ``torch.LongTensor`` of symbol ids using the config in ``hps``.

    The string is converted by ``text_to_sequence`` with the cleaners listed in
    ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy, the id list
    is additionally passed through ``commons.intersperse(..., 0)`` before being
    wrapped in a ``torch.LongTensor``.

    Args:
        text: the raw input string to synthesize.
        hps: hyper-parameter object exposing ``data.text_cleaners`` and
            ``data.add_blank``.

    Returns:
        A ``torch.LongTensor`` holding the (possibly interspersed) symbol ids.
    """
    data_cfg = hps.data
    symbol_ids = text_to_sequence(text, data_cfg.text_cleaners)
    # NOTE(review): 0 is presumably the blank symbol's id, inserted between
    # symbols by intersperse — confirm against the symbol table.
    symbol_ids = (
        commons.intersperse(symbol_ids, 0) if data_cfg.add_blank else symbol_ids
    )
    return torch.LongTensor(symbol_ids)
def main() -> None:
    """Synthesize ``--text`` with an ONNX-exported model and write the result as WAV.

    Loads the hyper-parameters with ``utils.get_hparams_from_file``, converts the
    text to symbol ids with ``get_text``, runs the ONNX session once, and writes
    the first model output to ``--output-wav-path`` at ``hps.data.sampling_rate``.

    Optional flags (the previous CLI keeps working unchanged):
        --sid            speaker id. It is fed to the model's "sid" input only
                         when given, so multi-speaker models can be driven from
                         the CLI while single-speaker models behave as before.
        --noise-scale, --length-scale, --noise-scale-w
                         the three values packed, in that order, into the
                         "scales" input. Defaults are the previously hard-coded
                         0.667, 1.0 and 0.8.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to model (.onnx)")
    parser.add_argument(
        "--config-path", required=True, help="Path to model config (.json)"
    )
    parser.add_argument(
        "--output-wav-path", required=True, help="Path to write WAV file"
    )
    parser.add_argument("--text", required=True, type=str, help="Text to synthesize")
    parser.add_argument(
        "--sid",
        type=int,
        default=None,
        help="Speaker id (multi-speaker models only)",
    )
    # NOTE(review): the [noise, length, noise_w] meaning of the "scales" slots is
    # assumed from the conventional VITS defaults (0.667, 1.0, 0.8); confirm
    # against the exporter that produced the .onnx file.
    parser.add_argument(
        "--noise-scale", type=float, default=0.667, help="scales[0] (default 0.667)"
    )
    parser.add_argument(
        "--length-scale", type=float, default=1.0, help="scales[1] (default 1.0)"
    )
    parser.add_argument(
        "--noise-scale-w", type=float, default=0.8, help="scales[2] (default 0.8)"
    )
    args = parser.parse_args()

    sess_options = onnxruntime.SessionOptions()
    model = onnxruntime.InferenceSession(str(args.model), sess_options=sess_options)

    hps = utils.get_hparams_from_file(args.config_path)
    phoneme_ids = get_text(args.text, hps)
    # Add a leading batch axis: a single (1, T) int64 id sequence plus its length.
    text_ids = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
    text_lengths = np.array([text_ids.shape[1]], dtype=np.int64)
    scales = np.array(
        [args.noise_scale, args.length_scale, args.noise_scale_w], dtype=np.float32
    )

    feeds = {
        "input": text_ids,
        "input_lengths": text_lengths,
        "scales": scales,
    }
    # Feed "sid" only when a speaker was requested. Previously it was always
    # passed as None, which onnxruntime's Python API skips, so omitting the key
    # preserves the old single-speaker behavior.
    if args.sid is not None:
        # int64 assumed for the speaker-id input — TODO confirm model input dtype.
        feeds["sid"] = np.array([args.sid], dtype=np.int64)

    # squeeze((0, 1)) drops the two leading axes and raises if either is not size 1.
    audio = model.run(None, feeds)[0].squeeze((0, 1))
    write(data=audio, rate=hps.data.sampling_rate, filename=args.output_wav_path)
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()