forked from rxn4chemistry/rxnmapper
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
37 lines (29 loc) · 1.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Any, Dict, Iterable
import numpy as np
def assert_correct_map(value_1: Dict[str, Any], value_2: Dict[str, Any]) -> None:
    """Assert that two mapping-result dicts agree on the keys they share.

    Keys in ``mandatory_keys`` are always compared. Every other key is compared
    only when it is present in ``value_1``, so ``value_1`` decides which
    optional outputs are checked.

    Comparison rules:
      * ``mapped_rxn``, ``pxr_mapping_vector``, ``tokens``: exact equality.
      * ``confidence``: ``np.isclose`` with default tolerances.
      * ``pxr_confidences``, ``pxrrxp_attns``, ``tokensxtokens_attns``,
        ``mapping_tuples``: identical shape, then ``np.allclose`` with
        ``rtol=1e-4, atol=1e-7``.

    Args:
        value_1: First result dict; determines which optional keys are checked.
        value_2: Second result dict to compare against.

    Raises:
        AssertionError: If any compared entry differs.
        KeyError: If a key that is checked is missing from either dict.
    """
    mandatory_keys = ["mapped_rxn", "confidence"]

    def _should_check(key: str) -> bool:
        # Optional keys are only compared when the first dict provides them.
        return key in mandatory_keys or key in value_1

    # Exact matches
    for key in ["mapped_rxn", "pxr_mapping_vector", "tokens"]:
        if _should_check(key):
            assert value_1[key] == value_2[key]

    # Close match on a single number
    for key in ["confidence"]:
        if _should_check(key):
            assert np.isclose(value_1[key], value_2[key])

    # Close match on multiple values
    for key in [
        "pxr_confidences",
        "pxrrxp_attns",
        "tokensxtokens_attns",
        "mapping_tuples",
    ]:
        if not _should_check(key):
            continue
        # np.allclose broadcasts its inputs, so e.g. [0.5] vs [0.5, 0.5]
        # would be reported as close; require identical shapes first.
        assert np.shape(value_1[key]) == np.shape(value_2[key])
        assert np.allclose(value_1[key], value_2[key], rtol=1e-4, atol=1e-7)
def assert_correct_maps(
    values_1: Iterable[Dict[str, Any]], values_2: Iterable[Dict[str, Any]]
) -> None:
    """Assert that two sequences of mapping-result dicts agree pairwise.

    Both iterables are fully consumed. They must contain the same number of
    results, and each pair is then checked with ``assert_correct_map``.

    Args:
        values_1: First sequence of result dicts.
        values_2: Second sequence of result dicts.

    Raises:
        AssertionError: If the lengths differ or any pair differs.
    """
    # Materialize so the lengths can be compared; a bare zip() would silently
    # drop trailing entries of the longer input and let the check pass.
    list_1 = list(values_1)
    list_2 = list(values_2)
    assert len(list_1) == len(list_2), (
        f"Different number of results: {len(list_1)} vs {len(list_2)}"
    )
    for value_1, value_2 in zip(list_1, list_2):
        assert_correct_map(value_1, value_2)