Skip to content

Commit

Permalink
[MNT] Testing fixes (#2531)
Browse files Browse the repository at this point in the history
* adjust test for non numpy output

* test list output

* test dataframe output

* change pickle test

* equal nans

* test scalar output

* fix lists output

* allow arrays of objects

* allow arrays of objects

* test for boolean elements (MERLIN)

* switch to deep equals

* switch to deep equals

* switch to deep equals

* message

* testing fixes

---------

Co-authored-by: Tony Bagnall <[email protected]>
  • Loading branch information
MatthewMiddlehurst and TonyBagnall authored Feb 8, 2025
1 parent a3cbff0 commit dfc3aed
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
14 changes: 6 additions & 8 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,9 @@ def check_persistence_via_pickle(estimator, datatype):
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} after serialisation parameters gives "
f"different results. "
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
f"Running {type(estimator)} {method} with test parameters after "
f"serialisation gives different results. "
f"Check equivalence message: {msg}"
)
i += 1

Expand All @@ -657,9 +656,8 @@ def check_fit_deterministic(estimator, datatype):
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} with test parameters after two calls to fit "
f"gives different results."
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
f"Running {type(estimator)} {method} with test parameters after "
f"two calls to fit gives different results."
f"Check equivalence message: {msg}"
)
i += 1
7 changes: 5 additions & 2 deletions aeon/testing/testing_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
NUMBA_DISABLED = os.environ.get("NUMBA_DISABLE_JIT") == "1"

# exclude estimators here for short term fixes
# Hydra excluded because it returns a pytorch Tensor
EXCLUDE_ESTIMATORS = ["REDCOMETS", "HydraTransformer"]
EXCLUDE_ESTIMATORS = [
"REDCOMETS",
"HydraTransformer", # returns a pytorch Tensor
]

# Exclude specific tests for estimators here
EXCLUDED_TESTS = {
Expand All @@ -50,6 +52,7 @@
"RSASTClassifier": ["check_fit_deterministic"],
"SAST": ["check_fit_deterministic"],
"RSAST": ["check_fit_deterministic"],
"MatrixProfile": ["check_persistence_via_pickle"],
# missed in legacy testing, changes state in predict/transform
"FLUSSSegmenter": ["check_non_state_changing_method"],
"InformationGainSegmenter": ["check_non_state_changing_method"],
Expand Down
14 changes: 10 additions & 4 deletions aeon/testing/utils/deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _deep_equals(x, y, depth, ignore_index):
elif isinstance(x, pd.DataFrame):
return _dataframe_equals(x, y, depth, ignore_index)
elif isinstance(x, np.ndarray):
return _numpy_equals(x, y, depth)
return _numpy_equals(x, y, depth, ignore_index)
elif isinstance(x, (list, tuple)):
return _list_equals(x, y, depth, ignore_index)
elif isinstance(x, dict):
Expand Down Expand Up @@ -128,15 +128,21 @@ def _dataframe_equals(x, y, depth, ignore_index):
return eq, msg


def _numpy_equals(x, y, depth):
def _numpy_equals(x, y, depth, ignore_index):
if x.dtype != y.dtype:
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})"

if x.dtype == "object":
eq, msg = _deep_equals(x.tolist(), y.tolist(), depth, ignore_index=True)
for i in range(len(x)):
eq, msg = _deep_equals(x[i], y[i], depth + 1, ignore_index)

if not eq:
return False, msg + f", idx={i}"
else:
eq = np.allclose(x, y, equal_nan=True)
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
return eq, msg
return eq, msg
return True, ""


def _csrmatrix_equals(x, y, depth):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ addopts = '''
--dist worksteal
--reruns 2
--only-rerun "crashed while running"
--only-rerun "zipfile.BadZipFile"
'''
filterwarnings = '''
ignore::UserWarning
Expand Down

0 comments on commit dfc3aed

Please sign in to comment.