gen_wavernn.py (forked from Tomiinek/WaveRNN)
from wavernn.utils.dsp import *
from wavernn.models.fatchord_version import WaveRNN
from wavernn.utils.paths import Paths
from wavernn.utils.display import simple_table
import numpy as np
import torch
import argparse
from pathlib import Path


def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path):
    k = model.get_step() // 1000
    for i, (m, x) in enumerate(test_set, 1):
        if i > samples:
            break
        print('\n| Generating: %i/%i' % (i, samples))
        x = x[0].numpy()
        bits = 16 if hp.voc_mode == 'MOL' else hp.bits
        # Recover the target waveform from its quantized representation.
        if hp.mu_law and hp.voc_mode != 'MOL':
            x = decode_mu_law(x, 2**bits, from_labels=True)
        else:
            x = label_2_float(x, bits)
        save_wav(x, save_path/f'{k}k_steps_{i}_target.wav')
        batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
        save_str = str(save_path/f'{k}k_steps_{i}_{batch_str}.wav')
        _ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)


def generate(model, spectrogram, batched, target, overlap, save_str=None):
    mel = normalize(spectrogram)
    if mel.ndim != 2 or mel.shape[0] != hp.num_mels:
        raise ValueError(f'Expected a numpy array shaped (n_mels, n_hops), but got {mel.shape}!')
    _max = np.max(mel)
    _min = np.min(mel)
    if _max >= 1.01 or _min <= -0.01:
        raise ValueError(f'Expected spectrogram range in [0, 1] but was instead [{_min}, {_max}]')
    mel = torch.tensor(mel).unsqueeze(0)
    return model.generate(mel, save_str, batched, target, overlap, hp.mu_law)


def gen_from_file(model: WaveRNN, load_path: Path, save_path: Path, batched, target, overlap):
    mel = np.load(load_path)
    generate(model, mel, batched, target, overlap, save_str=save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate WaveRNN Samples')
    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
    parser.add_argument('--samples', '-s', type=int, help='[int] number of utterances to generate')
    parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
    parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
    parser.add_argument('--file', '-f', type=str, required=True, help='[string/path] for testing a wav outside dataset')
    parser.add_argument('--output', '-p', type=str, help='[string/path] output file')
    parser.add_argument('--voc_weights', '-w', type=str, help='[string/path] Load in different WaveRNN weights')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only generation, even in a CUDA-capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
    parser.set_defaults(batched=None)
    args = parser.parse_args()

    hp.configure(args.hp_file)  # Load hparams from file

    # Set defaults for any arguments that depend on hparams
    if args.target is None:
        args.target = hp.voc_target
    if args.overlap is None:
        args.overlap = hp.voc_overlap
    if args.batched is None:
        args.batched = hp.voc_gen_batched
    if args.samples is None:
        args.samples = hp.voc_gen_at_checkpoint

    batched = args.batched
    samples = args.samples
    target = args.target
    overlap = args.overlap
    file = args.file

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).to(device)

    paths = Paths(hp.data_path, hp.voc_model_id)
    voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights
    model.load(voc_weights)

    simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', target if batched else 'N/A'),
                  ('Overlap Samples', overlap if batched else 'N/A')])

    file = Path(file).expanduser()
    gen_from_file(model, file, args.output, batched, target, overlap)

    print('\n\nExiting...\n')
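
# Usage sketch: a typical invocation of this script might look like the
# following (the checkpoint and .npy paths are hypothetical; the mel file is
# expected to hold a dB-scaled spectrogram shaped (hp.num_mels, n_frames),
# as checked in generate() above):
#
#   python gen_wavernn.py --file mels/example_mel.npy --output example.wav \
#       --voc_weights checkpoints/wavernn_latest.pyt --batched
#
# generate() can also be called directly once a WaveRNN model and hparams are
# loaded as in the __main__ block above, e.g.:
#
#   mel = np.load('mels/example_mel.npy')
#   wav = generate(model, mel, batched=True, target=hp.voc_target,
#                  overlap=hp.voc_overlap, save_str='example.wav')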