-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_lstm_features.py
61 lines (46 loc) · 1.54 KB
/
test_lstm_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest
from route_distances.lstm.features import (
add_fingerprints,
remove_reactions,
preprocess_reaction_tree,
)
@pytest.fixture
def toy_tree():
    """Provide a small nested-dict tree shared by the tests in this module.

    Layout: one root node of type "mol" with a single child of type
    "reaction" (its SMILES is the placeholder string "dummy"), which in
    turn has two leaf nodes of type "mol". A fresh structure is built on
    every call, so a test may mutate what it receives.
    """
    first_leaf = {
        "smiles": "CCn1nc(CC(C)C)cc1C(=O)O",
        "type": "mol",
    }
    second_leaf = {
        "smiles": "COc1nc(C)cc(C)c1CN",
        "type": "mol",
    }
    reaction_node = {
        "smiles": "dummy",
        "type": "reaction",
        "children": [first_leaf, second_leaf],
    }
    root = {
        "smiles": "CCn1nc(CC(C)C)cc1C(=O)NCc1c(C)cc(C)nc1OC",
        "type": "mol",
        "children": [reaction_node],
    }
    return root
def test_remove_reactions(toy_tree):
    """Check the root's child count before and after ``remove_reactions``.

    The fixture's root starts with one child (the reaction node); the tree
    returned by ``remove_reactions`` is expected to have two children at
    the root.
    """
    children_before = toy_tree["children"]
    assert len(children_before) == 1

    stripped = remove_reactions(toy_tree)

    children_after = stripped["children"]
    assert len(children_after) == 2
def test_add_fingerprints(toy_tree):
    """Check the "fingerprint" entries that ``add_fingerprints`` attaches.

    The return value of ``add_fingerprints`` is ignored: the test reads the
    "fingerprint" key from the tree it passed in, so it relies on the tree
    being updated in place. With ``nbits=10`` both the root and its first
    child are expected to carry ten bits that are all set.
    """
    tree = remove_reactions(toy_tree)
    add_fingerprints(tree, nbits=10)

    all_ones = [1] * 10
    root_fingerprint = tree["fingerprint"]
    assert len(root_fingerprint) == 10
    assert list(root_fingerprint) == all_ones

    first_child = tree["children"][0]
    assert list(first_child["fingerprint"]) == all_ones
def test_preprocessing(toy_tree):
    """Check the output of ``preprocess_reaction_tree`` for the toy tree.

    Called with ``nfeatures=10``, the result is expected to hold three
    feature rows, node order [1, 0, 0], edge order [1, 1], adjacency rows
    [0, 1] and [0, 2], and a node count of three.
    """
    result = preprocess_reaction_tree(toy_tree, nfeatures=10)

    assert len(result["features"]) == 3
    assert list(result["node_order"]) == [1, 0, 0]
    assert list(result["edge_order"]) == [1, 1]

    # Only the first two adjacency rows are inspected, one per expected edge.
    adjacency = result["adjacency_list"]
    expected_edges = ([0, 1], [0, 2])
    for row_index, expected in enumerate(expected_edges):
        assert list(adjacency[row_index]) == expected

    assert result["num_nodes"] == 3