# export_onnx.py
# Source: repository forked from p0p4k/vits2_pytorch.
import argparse
from pathlib import Path
from typing import Optional
import torch
import utils
from models import SynthesizerTrn
from text.symbols import symbols
# ONNX opset version handed to torch.onnx.export() as `opset_version`.
OPSET_VERSION = 15
def main() -> None:
    """Export a trained synthesizer checkpoint to an ONNX file.

    Command-line arguments:
        --model-path: path to the model weights (.pth).
        --config-path: path to the model config (.json).
        --output: path of the .onnx file to write; missing parent
            directories are created.

    The posterior encoder channel count is chosen from the config: 80 when
    ``use_mel_posterior_encoder`` is enabled in the model section,
    otherwise ``filter_length // 2 + 1``.
    """
    # Seed torch's RNG first so the dummy input below is repeatable.
    torch.manual_seed(1234)

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model-path", required=True, help="Path to model weights (.pth)"
    )
    arg_parser.add_argument(
        "--config-path", required=True, help="Path to model config (.json)"
    )
    arg_parser.add_argument(
        "--output", required=True, help="Path to output model (.onnx)"
    )
    cli = arg_parser.parse_args()

    model_path = Path(cli.model_path)
    config_path = Path(cli.config_path)
    output_path = Path(cli.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    hps = utils.get_hparams_from_file(config_path)

    # The explicit comparison is kept on purpose: only a value equal to
    # True (not merely truthy) selects the mel posterior encoder.
    wants_mel = (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder == True  # noqa: E712
    )
    if not wants_mel:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False
    else:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True

    segment_frames = hps.train.segment_size // hps.data.hop_length
    generator = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        segment_frames,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    generator.eval()
    utils.load_checkpoint(model_path, generator, None)

    def infer_forward(text, text_lengths, scales, sid=None):
        """Run ``infer`` with the three scales packed in one tensor.

        ``scales`` is indexed as [noise_scale, length_scale, noise_scale_w];
        only the first element of the ``infer`` result is returned.
        """
        return generator.infer(
            text,
            text_lengths,
            noise_scale=scales[0],
            length_scale=scales[1],
            noise_scale_w=scales[2],
            sid=sid,
        )[0]

    # Route the module's forward call to the inference path so that this is
    # what the exporter sees.
    generator.forward = infer_forward

    # Dummy inputs for the export: one random sequence of symbol ids.
    dummy_input_length = 50
    token_ids = torch.randint(
        low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long
    )
    token_lengths = torch.LongTensor([token_ids.size(1)])

    # A speaker id is only supplied when the config declares several speakers.
    speaker_id: Optional[torch.LongTensor] = (
        torch.LongTensor([0]) if hps.data.n_speakers > 1 else None
    )

    # noise, length, noise_w
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    axis_names = {
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        "output": {0: "batch_size", 1: "time1", 2: "time2"},
    }

    # Export
    torch.onnx.export(
        model=generator,
        args=(token_ids, token_lengths, scales, speaker_id),
        f=str(output_path),
        verbose=False,
        opset_version=OPSET_VERSION,
        input_names=["input", "input_lengths", "scales", "sid"],
        output_names=["output"],
        dynamic_axes=axis_names,
    )

    print(f"Exported model to {output_path}")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()