Add tests for make_splits
PhilippThoelke committed Sep 2, 2021
1 parent 7743556 commit e9082c1
Showing 2 changed files with 76 additions and 8 deletions.
58 changes: 58 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,58 @@
from pytest import mark, raises
import torch
from torchmdnet.utils import make_splits


def sum_lengths(*args):
    return sum(map(len, args))


def test_make_splits_outputs():
    result = make_splits(100, 0.7, 0.2, 0.1, 1234)
    assert len(result) == 3
    assert isinstance(result[0], torch.Tensor)
    assert isinstance(result[1], torch.Tensor)
    assert isinstance(result[2], torch.Tensor)
    assert sum_lengths(*result) == len(torch.unique(torch.cat(result)))
    assert max(map(max, result)) == 99
    assert min(map(min, result)) == 0


@mark.parametrize("dset_len", [5, 1000])
@mark.parametrize("ratio1", [0.0, 0.3])
@mark.parametrize("ratio2", [0.0, 0.3])
@mark.parametrize("ratio3", [0.0, 0.3])
def test_make_splits_ratios(dset_len, ratio1, ratio2, ratio3):
    train, val, test = make_splits(dset_len, ratio1, ratio2, ratio3, 1234)
    assert sum_lengths(train, val, test) <= dset_len
    assert len(train) == round(ratio1 * dset_len)
    assert len(val) == round(ratio2 * dset_len)
    # simply multiplying and rounding ratios can lead to values larger than dset_len,
    # which make_splits should account for by removing one sample from the test set
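    # e.g. dset_len=5 with all ratios at 0.3: each split rounds to 2, and 2 + 2 + 2 = 6 > 5,
    # so the test split is expected to end up one sample short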
    if (
        round(ratio1 * dset_len) + round(ratio2 * dset_len) + round(ratio3 * dset_len)
        > dset_len
    ):
        assert len(test) == round(ratio3 * dset_len) - 1
    else:
        assert len(test) == round(ratio3 * dset_len)


def test_make_splits_sizes():
    assert sum_lengths(*make_splits(100, 70, 20, 10, 1234)) == 100
    assert sum_lengths(*make_splits(100, 70, 20, None, 1234)) == 100
    assert sum_lengths(*make_splits(100, 70, None, 10, 1234)) == 100
    assert sum_lengths(*make_splits(100, None, 20, 10, 1234)) == 100
    assert sum_lengths(*make_splits(100, 70, 20, 0.1, 1234)) == 100
    assert sum_lengths(*make_splits(100, 70, 20, 0.05, 1234)) == 95


def test_make_splits_errors():
    with raises(AssertionError):
        make_splits(100, 0.5, 0.5, 0.5, 1234)
    with raises(AssertionError):
        make_splits(100, 50, 50, 50, 1234)
    with raises(AssertionError):
        make_splits(100, None, None, 5, 1234)
    with raises(AssertionError):
        make_splits(100, 60, 60, None, 1234)
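For context, a minimal usage sketch (not part of this commit) of how the index tensors returned by make_splits might be consumed downstream; the toy TensorDataset and the use of torch.utils.data.Subset are illustrative assumptions, not part of the torchmd-net API shown here:

import torch
from torch.utils.data import Subset, TensorDataset

from torchmdnet.utils import make_splits

# toy dataset with 100 samples (illustrative only)
dataset = TensorDataset(torch.randn(100, 3))

# fractional sizes: 70% train, 20% val; test_size=None takes the remaining 10%
idx_train, idx_val, idx_test = make_splits(len(dataset), 0.7, 0.2, None, 1234)

# the returned index tensors select disjoint subsets of the dataset
train_set = Subset(dataset, idx_train.tolist())
val_set = Subset(dataset, idx_val.tolist())
test_set = Subset(dataset, idx_test.tolist())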
26 changes: 18 additions & 8 deletions torchmdnet/utils.py
@@ -3,28 +3,38 @@
import numpy as np
import torch
from os.path import dirname, join, exists
+from pytorch_lightning.utilities import rank_zero_warn


def train_val_test_split(dset_len, train_size, val_size, test_size, seed, order=None):
    assert (train_size is None) + (val_size is None) + (
        test_size is None
    ) <= 1, "Only one of train_size, val_size, test_size is allowed to be None."

-    train_size = (
-        round(dset_len * train_size) if isinstance(train_size, float) else train_size
-    )
-    val_size = round(dset_len * val_size) if isinstance(val_size, float) else val_size
-    test_size = (
-        round(dset_len * test_size) if isinstance(test_size, float) else test_size
-    )
+    is_float = (
+        isinstance(train_size, float),
+        isinstance(val_size, float),
+        isinstance(test_size, float),
+    )
+
+    train_size = round(dset_len * train_size) if is_float[0] else train_size
+    val_size = round(dset_len * val_size) if is_float[1] else val_size
+    test_size = round(dset_len * test_size) if is_float[2] else test_size

    if train_size is None:
        train_size = dset_len - val_size - test_size
    elif val_size is None:
        val_size = dset_len - train_size - test_size
    elif test_size is None:
        test_size = dset_len - train_size - val_size

+    if train_size + val_size + test_size > dset_len:
+        if is_float[2]:
+            test_size -= 1
+        elif is_float[1]:
+            val_size -= 1
+        elif is_float[0]:
+            train_size -= 1

    assert train_size >= 0 and val_size >= 0 and test_size >= 0, (
        f"One of training ({train_size}), validation ({val_size}) or "
        f"testing ({test_size}) splits ended up with a negative size."
@@ -36,7 +46,7 @@ def train_val_test_split(dset_len, train_size, val_size, test_size, seed, order=None):
f"combined split sizes ({total})."
)
if total < dset_len:
print(f"Warning: {dset_len - total} samples were excluded from the dataset")
rank_zero_warn(f"{dset_len - total} samples were excluded from the dataset")

    idxs = np.arange(dset_len, dtype=np.int)
    if order is None:
