Skip to content

Commit

Permalink
feat: Dataset from H5 file for SE
Browse files Browse the repository at this point in the history
  • Loading branch information
santi-pdp committed Nov 12, 2018
1 parent 9d4c91a commit ad6ab7a
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions segan/datasets/se_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import librosa
from ahoproc_tools.io import *
from ahoproc_tools.interpolate import *
import h5py


def collate_fn(batch):
Expand Down Expand Up @@ -523,17 +524,60 @@ def __getitem__(self, index):
def __len__(self):
return len(self.samples)

class SEH5Dataset(Dataset):
""" Speech enhancement dataset from H5 data file.
The pairs must be named (data, label), being each
one a dataset containing wav chunks (already chunked
to fixed size).
"""
def __init__(self, data_root, split, preemph,
max_samples=None, verbose=False,
preemph_norm=False,
random_scale=[1]):
super().__init__()
self.data_root = data_root
self.split = split
self.preemph = preemph
self.max_samples = max_samples
self.verbose = verbose
self.random_scale = random_scale
h5_file = os.path.join(data_root, split + '.h5')
if not os.path.exists(h5_file):
raise FileNotFoundError(h5_file)
f = h5py.File(h5_file, 'r')
ks = list(f.keys())
assert 'data' in ks, ks
assert 'label' in ks, ks
if verbose:
print('Found H5 file {} with {} samples'.format(h5_file,
f['data'].shape[0]))
self.f = f

def __getitem__(self, index):
c_slice = self.f['data'][index]
n_slice = self.f['label'][index]
rscale = random.choice(self.random_scale)
if rscale != 1:
c_slice = rscale * c_slice
n_slice = rscale * n_slice
# uttname not known with H5
returns = ['N/A', torch.FloatTensor(c_slice),
torch.FloatTensor(n_slice), 0]
return returns

def __len__(self):
return self.f['data'].shape[0]

if __name__ == '__main__':
#dset = SEDataset('../../data/clean_trainset', '../../data/noisy_trainset', 0.95,
# cache_dir=None, max_samples=100, verbose=True)
#sample_0 = dset.__getitem__(0)
#print('sample_0: ', sample_0)

dset = RandomChunkSEF0Dataset('../../data/silent/clean_trainset',
'../../data/silent/lf0_trainset', 0.)
sample_0 = dset.__getitem__(0)
print('len sample_0: ', len(sample_0))
for data in sample_0:
if isinstance(data, str):
continue
print(data.size())
#dset = RandomChunkSEF0Dataset('../../data/silent/clean_trainset',
# '../../data/silent/lf0_trainset', 0.)
#sample_0 = dset.__getitem__(0)
#print('len sample_0: ', len(sample_0))
dset = SEH5Dataset('../../data/widebandnet_h5/speaker1', 'train',
0.95, verbose=True)
print(len(dset))

0 comments on commit ad6ab7a

Please sign in to comment.