Skip to content

Commit

Permalink
Generate YESNO dataset on-the-fly for test (pytorch#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jul 16, 2020
1 parent 02b898f commit 102174e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 17 deletions.
Binary file removed test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav
Binary file not shown.
4 changes: 2 additions & 2 deletions test/common_utils/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def supports_mp3(backend):
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
if 'sox_io' in BACKENDS:
be = 'sox_io'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
Expand Down
12 changes: 6 additions & 6 deletions test/common_utils/test_case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None

@property
def base_temp_dir(self):
@classmethod
def get_base_temp_dir(cls):
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
return os.environ[key]
if self.__class__.temp_dir_ is None:
self.__class__.temp_dir_ = tempfile.TemporaryDirectory()
return self.__class__.temp_dir_.name
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name

@classmethod
def tearDownClass(cls):
Expand All @@ -34,7 +34,7 @@ def tearDownClass(cls):
cls.temp_dir_ = None

def get_temp_path(self, *paths):
temp_dir = os.path.join(self.base_temp_dir, self.id())
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
Expand Down
59 changes: 50 additions & 9 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest

from torchaudio.datasets.commonvoice import COMMONVOICE
Expand All @@ -10,16 +11,19 @@
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC

from . import common_utils
from .common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_asset_path,
get_whitenoise,
save_wav,
normalize_wav,
)


class TestDatasets(common_utils.TorchaudioTestCase):
class TestDatasets(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()

def test_yesno(self):
data = YESNO(self.path)
data[0]
path = get_asset_path()

def test_vctk(self):
data = VCTK(self.path)
Expand All @@ -46,9 +50,9 @@ def test_cmuarctic(self):
data[0]


class TestCommonVoice(common_utils.TorchaudioTestCase):
class TestCommonVoice(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()
path = get_asset_path()

def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
Expand All @@ -69,5 +73,42 @@ def test_commonvoice_bg(self):
pass


class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default'

root_dir = None
data = []
labels = [
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
]

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'waves_yesno')
os.makedirs(base_dir, exist_ok=True)
for label in cls.labels:
filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))

def test_yesno(self):
dataset = YESNO(self.root_dir)
samples = list(dataset)
samples.sort(key=lambda s: s[2])
for i, (waveform, sample_rate, label) in enumerate(samples):
expected_label = self.labels[i]
expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 8000
assert label == expected_label


if __name__ == "__main__":
unittest.main()

0 comments on commit 102174e

Please sign in to comment.