forked from legoodmanner/jukedrummer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
111 lines (92 loc) · 3.67 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import pickle
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
def compute_mean_std(mel_dir, pkl=None):
    """Compute per-bin mean and standard deviation of Mel spectrograms.

    Statistics are accumulated file by file with
    ``StandardScaler.partial_fit``. Each array is transposed before fitting,
    so the statistics are taken per entry of the stored array's first axis.

    Args:
        mel_dir: Directory that holds the Mel spectrogram ``.npy`` files.
        pkl: Optional pre-loaded file list. When given, ``pkl[0]`` is used as
            the list of filenames (relative to ``mel_dir``) instead of
            scanning the directory.

    Returns:
        A tuple ``(mean, std, non_nan)``. ``mean`` and ``std`` are float
        tensors of shape ``(1, n_bins, 1)``; ``non_nan`` is the list of
        filenames that passed the NaN check.

    Raises:
        ValueError: If no file passed the NaN check, so there is nothing to
            compute statistics from.
    """
    if pkl is not None:
        # Filenames supplied by the caller are used as given.
        in_fns = pkl[0]
    else:
        in_fns = os.listdir(mel_dir)
        in_fns = [fn for fn in in_fns if '.npy' in fn]

    scaler = StandardScaler()
    pbar = tqdm(in_fns, dynamic_ncols=True,)
    non_nan = []
    print('computing mean and std ...')
    for fn in pbar:
        # assumes each file stores a 2-D array (n_bins, frames) — TODO confirm
        data = np.load(os.path.join(mel_dir, fn)).T
        if np.isnan(data).any():
            # Report and skip files that would poison the running statistics.
            print(fn)
            continue
        non_nan.append(fn)
        scaler.partial_fit(data)
        if np.isnan(scaler.scale_).any():
            # The running statistics became NaN; further fitting is pointless.
            break

    if not non_nan:
        # Without this guard the failure is an AttributeError on scaler.mean_.
        raise ValueError(
            'no usable Mel spectrogram found in {!r}'.format(mel_dir)
        )

    # Infer the number of bins from the data instead of assuming 80.
    mean = torch.FloatTensor(scaler.mean_).view(1, -1, 1)
    std = torch.FloatTensor(scaler.scale_).view(1, -1, 1)
    return mean, std, non_nan
class BeatInfoPairedDataset(Dataset):
    """Dataset yielding (target tokens, others tokens, beat info) per file.

    For a filename ``f`` in ``fl`` the following arrays are loaded:

    * ``<root>/token/target/<vq_name>/f``
    * ``<root>/token/others/<vq_name>/f``
    * ``<root>/beats/<binfo_type>/f`` (``'low'`` when ``binfo_type`` is None)

    Args:
        fl: Sequence of filenames to serve.
        hps: Object exposing ``path``, ``binfo_type`` and ``vq_name``.
        return_fn: When true, the filename is appended to each sample.
    """

    def __init__(self, fl, hps, return_fn=False):
        super().__init__()
        self.fl = fl
        self.root = hps.path
        self.binfo_type = hps.binfo_type
        self.vq_name = hps.vq_name
        self.return_fn = return_fn

    def __len__(self):
        """Return the number of files in the list."""
        return len(self.fl)

    def __getitem__(self, idx):
        """Load the token pair and beat info of the ``idx``-th file.

        Returns:
            ``(target, others, beat_info)`` with both token arrays squeezed,
            followed by the filename when ``return_fn`` is set.
        """
        name = self.fl[idx]
        token_root = os.path.join(self.root, 'token')
        target = np.load(os.path.join(token_root, 'target', self.vq_name, name))
        others = np.load(os.path.join(token_root, 'others', self.vq_name, name))

        # Fall back to the 'low' beat information when no type is configured.
        beat_kind = 'low' if self.binfo_type is None else self.binfo_type
        beats = np.load(os.path.join(self.root, 'beats', beat_kind, name))

        sample = (target.squeeze(), others.squeeze(), beats)
        if self.return_fn:
            return sample + (name,)
        return sample
class MelDataset(Dataset):
    """Dataset serving arrays stored under ``<root>/mel/<data_type>/``.

    Args:
        fl: Sequence of filenames, with or without the ``.npy`` suffix.
        hps: Object exposing ``path``, the dataset root directory.
        data_type: Name of the sub-directory of ``mel`` to read from.
    """

    def __init__(self, fl, hps, data_type):
        super().__init__()
        self.fl = fl
        self.root = hps.path
        self.data_type = data_type

    def __len__(self):
        """Return the number of files in the list."""
        return len(self.fl)

    def __getitem__(self, idx):
        """Load and return the array of the ``idx``-th file."""
        name = self.fl[idx]
        # Entries may be listed without their extension; add it when missing.
        suffix = '' if name.endswith('.npy') else '.npy'
        mel_path = os.path.join(self.root, 'mel', self.data_type, name + suffix)
        return np.load(mel_path)
from utils.functions import mel2token, wav2mel
class End2EndWrapper(Dataset):
    """Dataset turning every ``.wav`` file of a directory into model inputs.

    Each item is produced on the fly: beat information comes from
    ``beat_extractor``, and the token sequence comes from ``wav2mel`` followed
    by ``mel2token``.

    Args:
        input_dir: Directory scanned (non-recursively) for ``.wav`` files.
        vqvae: Model handed to ``mel2token``.
        beat_extractor: Callable taking a wav path and returning a NumPy array.
        mel_extractor: Object handed to ``wav2mel``.
        others_mean: Normalisation mean handed to ``mel2token``.
        others_std: Normalisation std handed to ``mel2token``.
        device: Device the returned tensors are moved to.
    """

    def __init__(self, input_dir, vqvae, beat_extractor, mel_extractor, others_mean, others_std, device):
        super().__init__()
        self.mel_extractor = mel_extractor
        self.beat_extractor = beat_extractor
        self.others_mean, self.others_std = others_mean, others_std
        self.vqvae = vqvae
        self.device = device
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same file across runs and machines.
        self.dpaths = [
            os.path.join(input_dir, fn)
            for fn in sorted(os.listdir(input_dir))
            if fn.endswith('.wav')
        ]

    def __getitem__(self, index):
        """Return ``(tokens, beat_info, filename)`` for the ``index``-th wav.

        ``tokens`` is a long tensor and ``beat_info`` a tensor, both with a
        leading axis of size 1 and moved to ``self.device``. ``beat_info`` is
        ``None`` when the extracted beat information contains NaN.
        """
        wav_path = self.dpaths[index]

        beat_info = self.beat_extractor(wav_path)
        if np.isnan(beat_info).any():
            # Signal unusable beat information to the caller.
            beat_info = None
        else:
            beat_info = torch.from_numpy(beat_info).unsqueeze(0).to(self.device)

        mel = wav2mel(wav_path, self.mel_extractor)
        tokens = mel2token(mel, self.vqvae, self.others_mean, self.others_std, self.device)
        tokens = torch.from_numpy(tokens).long().unsqueeze(0).to(self.device)

        # basename handles the platform's separator, unlike split('/').
        return tokens, beat_info, os.path.basename(wav_path)

    def __len__(self,):
        """Return the number of ``.wav`` files found in ``input_dir``."""
        return len(self.dpaths)