"""Pytest fixtures for the Chemformer test suite (forked from MolecularAI/Chemformer)."""

import pathlib
from argparse import Namespace

import numpy as np
import omegaconf as oc
import pandas as pd
import pytest

import molbart.utils.data_utils as util
from molbart.data import SynthesisDataModule
from molbart.models import Chemformer
from molbart.utils.tokenizers import ChemformerTokenizer, SpanTokensMasker


@pytest.fixture
def example_tokens():
    # Two pre-tokenized sequences with begin ("^") and end ("&") markers,
    # an out-of-vocabulary token ("unknown") and a "<SEP>" separator.
    return [
        ["^", "C", "(", "=", "O", ")", "unknown", "&"],
        ["^", "C", "C", "<SEP>", "C", "Br", "&"],
    ]


@pytest.fixture
def regex_tokens():
    regex = r"\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]"
    return regex.split("|")


@pytest.fixture
def smiles_data():
    # Small SMILES strings used to seed the test tokenizer's vocabulary.
    return ["CCO.Ccc", "CCClCCl", "C(=O)CBr"]


@pytest.fixture
def mock_random_choice(mocker):
    """Replace random.choices inside the tokenizers module with a deterministic
    stub that alternates True/False per drawn element, so masking tests are
    reproducible."""

    class ToggleBool:
        def __init__(self):
            self.state = True

        def __call__(self, *args, **kwargs):
            states = []
            for _ in range(kwargs["k"]):
                states.append(self.state)
                self.state = not self.state
            return states

    mocker.patch("molbart.utils.tokenizers.tokenizers.random.choices", side_effect=ToggleBool())
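
# Illustration of the stub's behavior: while the patch is active, each drawn
# element alternates, e.g. a call with k=4 returns [True, False, True, False]
# and a following call with k=2 returns [True, False], regardless of the
# population or weights passed in.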


@pytest.fixture
def setup_tokenizer(regex_tokens, smiles_data):
    # Factory fixture: lets a test construct the tokenizer with optional extra tokens.
    def wrapper(tokens=None):
        return ChemformerTokenizer(smiles=smiles_data, tokens=tokens, regex_token_patterns=regex_tokens)

    return wrapper


@pytest.fixture
def setup_masker(setup_tokenizer):
    # Factory fixture: builds a (tokenizer, masker) pair; the masker class is
    # injectable so other maskers can be tested against the same tokenizer.
    def wrapper(cls=SpanTokensMasker):
        tokenizer = setup_tokenizer()
        return tokenizer, cls(tokenizer)

    return wrapper
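
# A sketch of how a test module might combine the fixtures above (assumption:
# the masker is callable on a batch of token lists and returns the masked
# tokens plus a boolean mask; check the SpanTokensMasker API before relying on
# this signature):
#
#   def test_span_masking_is_deterministic(example_tokens, mock_random_choice, setup_masker):
#       _, masker = setup_masker()
#       masked_tokens, token_mask = masker(example_tokens)
#       assert len(masked_tokens) == len(example_tokens)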


@pytest.fixture
def round_trip_params(shared_datadir):
    # Sampling settings plus input data for the round-trip inference tests;
    # shared_datadir is provided by the pytest-datadir plugin.
    params = {
        "n_samples": 3,
        "beam_size": 5,
        "batch_size": 3,
        "round_trip_input_data": shared_datadir / "round_trip_input_data.csv",
    }
    return params


@pytest.fixture
def round_trip_namespace_args(shared_datadir):
    args = Namespace()
    args.input_data = shared_datadir / "example_data_uspto.csv"
    args.backward_predictions = shared_datadir / "example_data_backward_sampled_smiles_uspto50k.json"
    args.output_score_data = "temp_metrics.csv"
    args.dataset_part = "test"
    args.working_directory = "tests"
    args.target_column = "products"
    return args
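
# The attributes above mimic the argparse-style namespace handed to the
# round-trip scoring utilities: the input reactions, the backward model's
# sampled SMILES, and the file to which metrics are written.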


@pytest.fixture
def round_trip_raw_prediction_data(shared_datadir):
    round_trip_df = pd.read_json(shared_datadir / "round_trip_predictions_raw.json", orient="table")
    round_trip_predictions = [np.array(smiles_lst) for smiles_lst in round_trip_df["round_trip_smiles"].values]
    data = {
        "sampled_smiles": round_trip_predictions,
        "target_smiles": round_trip_df["target_smiles"].values,
    }
    return data


@pytest.fixture
def round_trip_converted_prediction_data(shared_datadir):
    round_trip_df = pd.read_json(shared_datadir / "round_trip_predictions_converted.json", orient="table")
    round_trip_predictions = [np.array(smiles_lst) for smiles_lst in round_trip_df["round_trip_smiles"].values]
    data = {
        "sampled_smiles": round_trip_predictions,
        "target_smiles": round_trip_df["target_smiles"].values,
    }
    return data
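
# Both prediction fixtures return the same structure, with one array of
# sampled SMILES per target:
#   {"sampled_smiles": [np.ndarray, ...], "target_smiles": np.ndarray}
# Judging by the file names, "raw" holds the sampler output as-is, while
# "converted" holds the same predictions after post-processing.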


@pytest.fixture
def model_batch_setup(round_trip_namespace_args):
    config = oc.OmegaConf.load("molbart/config/round_trip_inference.yaml")
    data = pd.read_csv(round_trip_namespace_args.input_data, sep="\t")

    # Shrink the model so it builds quickly on CPU; these dimensions are far
    # smaller than any real Chemformer checkpoint.
    config.d_model = 4
    config.batch_size = 3
    config.n_beams = 3
    config.n_layers = 1
    config.n_heads = 2
    config.d_feedforward = 2
    config.task = "forward_prediction"
    config.datamodule = None
    config.vocabulary_path = "bart_vocab_downstream.json"
    config.n_gpus = 0
    config.device = "cpu"
    config.data_device = "cpu"

    chemformer = Chemformer(config)

    datamodule = SynthesisDataModule(
        reactants=data["reactants"].values,
        products=data["products"].values,
        dataset_path="",
        tokenizer=chemformer.tokenizer,
        batch_size=config.batch_size,
        max_seq_len=util.DEFAULT_MAX_SEQ_LEN,
        reverse=False,
    )
    datamodule.setup()

    # Grab the first batch from the full dataloader for batch-level tests.
    dataloader = datamodule.full_dataloader()
    batch_idx, batch_input = next(enumerate(dataloader))

    output_data = {
        "chemformer": chemformer,
        "tokenizer": chemformer.tokenizer,
        "batch_idx": batch_idx,
        "batch_input": batch_input,
        "max_seq_len": util.DEFAULT_MAX_SEQ_LEN,
    }
    return output_data
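
# A sketch (not an existing test) of consuming model_batch_setup; the exact
# keys inside batch_input depend on SynthesisDataModule's collate function,
# so only generic assertions are shown:
#
#   def test_first_batch(model_batch_setup):
#       assert model_batch_setup["batch_idx"] == 0
#       assert model_batch_setup["tokenizer"] is model_batch_setup["chemformer"].tokenizer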