Skip to content

Commit

Permalink
add extractor test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Feb 27, 2024
1 parent 513d9d6 commit cd7986f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def test_dataset():
with pytest.raises(AssertionError):
d_real + d1

X1, y1 = tsgm.utils.gen_sine_vs_const_dataset(10, 20, 21, max_value=2, const=1)
y11 = np.concatenate([y1[:, None], y1[:, None]], axis=1)
d1 = tsgm.dataset.Dataset(X1, y11)
assert d1.Xy_concat.shape == (10, 20, 23)



def test_temporally_labeled_ds():
X = np.ones((10, 100, 2))
Expand Down
15 changes: 14 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import os
import tarfile
import shutil
import uuid
import functools
Expand Down Expand Up @@ -261,6 +262,7 @@ def test_reconstruction_loss_by_axis():


def test_get_physionet2012(mocker):
shutil.rmtree("./physionet2012", ignore_errors=True)
train_X, train_y, test_X, test_y, val_X, val_y = tsgm.utils.get_physionet2012()
assert train_X.shape == (1757980, 4)
assert train_y.shape == (4000, 6)
Expand Down Expand Up @@ -334,4 +336,15 @@ def test_get_covid_19():
assert X.shape[0] == len(states)
assert len(X.shape) == 3
assert X.shape[2] == 4
assert X.shape[1] >= 150
assert X.shape[1] >= 150


def test_extract_targz():
resource_folder = "./tmp/test_download/"
os.makedirs(resource_folder, exist_ok=True)
output_filename = "./tmp/dir.gz"
extracted_path = "./tmp/extracted"
with tarfile.open(output_filename, "w:gz") as tar:
tar.add(resource_folder, arcname=os.path.basename(resource_folder))
tsgm.utils.file_utils._extract_targz(output_filename, extracted_path)
assert os.path.isdir(extracted_path)

0 comments on commit cd7986f

Please sign in to comment.