forked from AllenInstitute/AllenSDK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
comparison_utils.py
70 lines (65 loc) · 2.31 KB
/
comparison_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import datetime
import math
from typing import Any, Optional, Set
import SimpleITK as sitk
import numpy as np
import pandas as pd
import xarray as xr
from pandas.util.testing import assert_frame_equal
def compare_fields(x1: Any, x2: Any, err_msg="",
ignore_keys: Optional[Set[str]] = None):
"""Helper function to compare if two fields (attributes)
are equal to one another.
Parameters
----------
x1 : Any
The first field
x2 : Any
The other field
err_msg : str, optional
The error message to display if two compared fields do not equal
one another, by default "" (an empty string)
ignore_keys
For dictionary comparison, ignore these keys
"""
if ignore_keys is None:
ignore_keys = set()
if isinstance(x1, pd.DataFrame):
try:
assert_frame_equal(x1, x2, check_like=True)
except Exception:
print(err_msg)
raise
elif isinstance(x1, np.ndarray):
np.testing.assert_array_almost_equal(x1, x2, err_msg=err_msg)
elif isinstance(x1, xr.DataArray):
xr.testing.assert_allclose(x1, x2)
elif isinstance(x1, (list, tuple)):
assert len(x1) == len(x2)
for i in range(len(x1)):
compare_fields(x1=x1[i], x2=x2[i])
elif isinstance(x1, (sitk.Image,)):
assert x1.GetSize() == x2.GetSize(), err_msg
assert x1 == x2, err_msg
elif isinstance(x1, (datetime.datetime, pd.Timestamp)):
if isinstance(x1, pd.Timestamp):
x1 = x1.to_pydatetime()
if isinstance(x2, pd.Timestamp):
x2 = x2.to_pydatetime()
time_delta = (x1 - x2).total_seconds()
# Timestamp differences should be less than 60 seconds
assert abs(time_delta) < 60
elif isinstance(x1, (float,)):
if math.isnan(x1) or math.isnan(x2):
both_nan = (math.isnan(x1) and math.isnan(x2))
assert both_nan, err_msg
else:
assert x1 == x2, err_msg
elif isinstance(x1, (dict,)):
for key in set(x1.keys()).union(set(x2.keys())):
if key in ignore_keys:
continue
key_err_msg = f"Mismatch when checking key {key}. {err_msg}"
compare_fields(x1[key], x2[key], err_msg=key_err_msg)
else:
assert x1 == x2, err_msg