From 102174e968e6365a527e34ebffc31d5c674eff0c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 16 Jul 2020 18:54:22 -0400 Subject: [PATCH] Generate YESNO dataset on-the-fly for test (#792) --- test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav | Bin 84 -> 0 bytes test/common_utils/backend_utils.py | 4 +- test/common_utils/test_case_utils.py | 12 ++-- test/test_datasets.py | 59 +++++++++++++++++--- 4 files changed, 58 insertions(+), 17 deletions(-) delete mode 100644 test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav diff --git a/test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav b/test/assets/waves_yesno/0_1_0_1_0_1_1_0.wav deleted file mode 100644 index 66ee46737e6fb4a50e41780fea3c0b01b45830b6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 84 zcmWIYbaV4zU|L!fNF0&}04a|ce*gdg diff --git a/test/common_utils/backend_utils.py b/test/common_utils/backend_utils.py index 158fde87ed..beceb6cafb 100644 --- a/test/common_utils/backend_utils.py +++ b/test/common_utils/backend_utils.py @@ -29,8 +29,8 @@ def supports_mp3(backend): def set_audio_backend(backend): """Allow additional backend value, 'default'""" if backend == 'default': - if 'sox' in BACKENDS: - be = 'sox' + if 'sox_io' in BACKENDS: + be = 'sox_io' elif 'soundfile' in BACKENDS: be = 'soundfile' else: diff --git a/test/common_utils/test_case_utils.py b/test/common_utils/test_case_utils.py index 253e2166fb..9c89ef6891 100644 --- a/test/common_utils/test_case_utils.py +++ b/test/common_utils/test_case_utils.py @@ -15,16 +15,16 @@ class TempDirMixin: """Mixin to provide easy access to temp dir""" temp_dir_ = None - @property - def base_temp_dir(self): + @classmethod + def get_base_temp_dir(cls): # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. # this is handy for debugging. key = 'TORCHAUDIO_TEST_TEMP_DIR' if key in os.environ: return os.environ[key] - if self.__class__.temp_dir_ is None: - self.__class__.temp_dir_ = tempfile.TemporaryDirectory() - return self.__class__.temp_dir_.name + if cls.temp_dir_ is None: + cls.temp_dir_ = tempfile.TemporaryDirectory() + return cls.temp_dir_.name @classmethod def tearDownClass(cls): @@ -34,7 +34,7 @@ def tearDownClass(cls): cls.temp_dir_ = None def get_temp_path(self, *paths): - temp_dir = os.path.join(self.base_temp_dir, self.id()) + temp_dir = os.path.join(self.get_base_temp_dir(), self.id()) path = os.path.join(temp_dir, *paths) os.makedirs(os.path.dirname(path), exist_ok=True) return path diff --git a/test/test_datasets.py b/test/test_datasets.py index c3b0c917da..4efc816257 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,3 +1,4 @@ +import os import unittest from torchaudio.datasets.commonvoice import COMMONVOICE @@ -10,16 +11,19 @@ from torchaudio.datasets.gtzan import GTZAN from torchaudio.datasets.cmuarctic import CMUARCTIC -from . import common_utils +from .common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_asset_path, + get_whitenoise, + save_wav, + normalize_wav, +) -class TestDatasets(common_utils.TorchaudioTestCase): +class TestDatasets(TorchaudioTestCase): backend = 'default' - path = common_utils.get_asset_path() - - def test_yesno(self): - data = YESNO(self.path) - data[0] + path = get_asset_path() def test_vctk(self): data = VCTK(self.path) @@ -46,9 +50,9 @@ def test_cmuarctic(self): data[0] -class TestCommonVoice(common_utils.TorchaudioTestCase): +class TestCommonVoice(TorchaudioTestCase): backend = 'default' - path = common_utils.get_asset_path() + path = get_asset_path() def test_commonvoice(self): data = COMMONVOICE(self.path, url="tatar") @@ -69,5 +73,42 @@ def test_commonvoice_bg(self): pass +class TestYesNo(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + data = [] + labels = [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + ] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + base_dir = os.path.join(cls.root_dir, 'waves_yesno') + os.makedirs(base_dir, exist_ok=True) + for label in cls.labels: + filename = f'{"_".join(str(l) for l in label)}.wav' + path = os.path.join(base_dir, filename) + data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16') + save_wav(path, data, 8000) + cls.data.append(normalize_wav(data)) + + def test_yesno(self): + dataset = YESNO(self.root_dir) + samples = list(dataset) + samples.sort(key=lambda s: s[2]) + for i, (waveform, sample_rate, label) in enumerate(samples): + expected_label = self.labels[i] + expected_data = self.data[i] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == 8000 + assert label == expected_label + + if __name__ == "__main__": unittest.main()