# fastspeech2.py
# Forked from ming024/FastSpeech2.
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformer import Encoder, Decoder, PostNet
from .modules import VarianceAdaptor
from utils.tools import get_mask_from_lengths
class FastSpeech2(nn.Module):
    """FastSpeech2 model.

    Chains an encoder, an optional speaker embedding, a variance adaptor,
    a decoder, a linear projection to mel channels, and a post-net whose
    output is added back onto the projected mel output.
    """

    def __init__(self, preprocess_config, model_config):
        """Build the sub-modules from the two configuration mappings.

        Args:
            preprocess_config: Mapping read for
                ``["preprocessing"]["mel"]["n_mel_channels"]`` and, when
                multi-speaker mode is on, ``["path"]["preprocessed_path"]``.
                It is also passed to ``VarianceAdaptor``.
            model_config: Mapping read for ``["multi_speaker"]``,
                ``["transformer"]["encoder_hidden"]`` and
                ``["transformer"]["decoder_hidden"]``. It is also passed to
                ``Encoder``, ``VarianceAdaptor`` and ``Decoder``.

        Raises:
            OSError: If multi-speaker mode is on and ``speakers.json``
                cannot be opened.
        """
        super().__init__()
        self.model_config = model_config

        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        # NOTE(review): PostNet is built with its own defaults while
        # mel_linear uses the configured n_mel_channels -- confirm that
        # PostNet's default channel count matches the configuration.
        self.postnet = PostNet()

        # Stays None in single-speaker mode; forward() checks for that.
        self.speaker_emb = None
        if model_config["multi_speaker"]:
            speakers_path = os.path.join(
                preprocess_config["path"]["preprocessed_path"], "speakers.json"
            )
            # Explicit encoding so decoding does not depend on the
            # platform's locale default.
            with open(speakers_path, "r", encoding="utf-8") as f:
                # One embedding row per top-level entry in speakers.json.
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                model_config["transformer"]["encoder_hidden"],
            )

    def forward(
        self,
        speakers,
        texts,
        src_lens,
        max_src_len,
        mels=None,
        mel_lens=None,
        max_mel_len=None,
        p_targets=None,
        e_targets=None,
        d_targets=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Run the full model on one batch.

        Args:
            speakers: Indices looked up in the speaker embedding; only used
                when the model was built in multi-speaker mode.
                (presumably one index per batch item -- TODO confirm)
            texts: Input passed to the encoder.
            src_lens: Source lengths used to build the source mask.
            max_src_len: Maximum source length used to build the source mask.
            mels: Accepted for interface compatibility; not used here.
            mel_lens: Mel lengths used to build the mel mask. When ``None``,
                no mel mask is built and ``None`` is handed to the variance
                adaptor instead.
            max_mel_len: Maximum mel length, used for the mel mask and
                passed to the variance adaptor.
            p_targets: Forwarded to the variance adaptor.
            e_targets: Forwarded to the variance adaptor.
            d_targets: Forwarded to the variance adaptor.
            p_control: Forwarded to the variance adaptor.
            e_control: Forwarded to the variance adaptor.
            d_control: Forwarded to the variance adaptor.

        Returns:
            A 10-tuple ``(output, postnet_output, p_predictions,
            e_predictions, log_d_predictions, d_rounded, src_masks,
            mel_masks, src_lens, mel_lens)``. ``output`` is the result of
            the mel projection, ``postnet_output`` is that plus the
            post-net's output, and ``mel_lens`` / ``mel_masks`` are the
            values returned by the variance adaptor and decoder, not the
            ones passed in.
        """
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )

        output = self.encoder(texts, src_masks)

        if self.speaker_emb is not None:
            # Insert a length-1 axis at dim 1 and let broadcasting repeat
            # the embedding along it. This matches an explicit expand to
            # max_src_len whenever that equals output's dim 1, without
            # tying the add to the max_src_len argument.
            output = output + self.speaker_emb(speakers).unsqueeze(1)

        # The adaptor hands back its own mel_lens / mel_masks, which
        # replace the ones computed from the arguments above.
        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            d_targets,
            p_control,
            e_control,
            d_control,
        )

        output, mel_masks = self.decoder(output, mel_masks)
        output = self.mel_linear(output)

        # The post-net output is added onto the projected mel output.
        postnet_output = self.postnet(output) + output

        return (
            output,
            postnet_output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            src_masks,
            mel_masks,
            src_lens,
            mel_lens,
        )