# test_caching.py (forked from moj-analytical-services/splink)
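
"""Tests for Splink's intermediate-table cache.

These tests mock DuckDBAPI._sql_to_splink_dataframe so that cache hits and
misses can be detected by counting database calls: a cache hit should issue
no SQL, while a miss (or an invalidated cache) should hit the database.
"""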
import os
from unittest.mock import create_autospec, patch

import pandas as pd
import pytest

from splink.internals.duckdb.database_api import DuckDBAPI
from splink.internals.duckdb.dataframe import DuckDBDataFrame
from splink.internals.linker import Linker, SplinkDataFrame
from splink.internals.pipeline import CTEPipeline
from splink.internals.vertically_concatenate import (
    compute_df_concat,
    compute_df_concat_with_tf,
    enqueue_df_concat_with_tf,
)

from tests.basic_settings import get_settings_dict

df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

# minimal pandas frame for the mock below; DuckDB's replacement scan lets the
# SQL in make_mock_execute read it by name from the enclosing Python scope
_dummy_pd_frame = pd.DataFrame(["id"])


def make_mock_execute(db_api):
    # creates a mock version of db_api._sql_to_splink_dataframe,
    # so we can count calls
    dummy_table_name = "__splink__dummy_frame"
    dummy_splink_df = DuckDBDataFrame("template", dummy_table_name, db_api)

    def register_and_return_dummy_frame(*args, **kwargs):
        # need to make sure that the dummy frame always exists in the context
        # in which we are running tests.
        # we are not actually interested in the frame itself, but it needs to
        # exist in the connection in case a method tries to access it
        db_api._con.sql(
            f"CREATE TABLE IF NOT EXISTS {dummy_table_name} AS "
            f"SELECT * FROM _dummy_pd_frame"
        )
        return dummy_splink_df

    mock_execute = create_autospec(
        db_api._sql_to_splink_dataframe, side_effect=register_and_return_dummy_frame
    )
    return mock_execute
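

# A minimal usage sketch of the helper above (it mirrors the wiring used by
# every test below):
#
#     with patch.object(
#         db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
#     ) as mock_execute:
#         ...                                # run the operation under test
#         mock_execute.assert_called()       # it hit the database
#         mock_execute.reset_mock()          # zero the call counter
#         ...                                # repeat the operation
#         mock_execute.assert_not_called()   # cache hit: no database round-trip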


def test_cache_id(tmp_path):
    # Test saving and loading a model
    db_api = DuckDBAPI()
    linker = Linker(df, get_settings_dict(), db_api=db_api)
    prior = linker._settings_obj._cache_uid

    path = os.path.join(tmp_path, "model.json")
    linker.misc.save_model_to_json(path, overwrite=True)

    db_api = DuckDBAPI()
    linker_2 = Linker(df, path, db_api=db_api)
    assert linker_2._settings_obj._cache_uid == prior

    # Test uid from settings
    random_uid = "my_random_uid"
    settings = get_settings_dict()
    settings["linker_uid"] = random_uid
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker_uid = linker._cache_uid
    assert linker_uid == random_uid


def test_cache_only_splink_dataframes():
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)

    linker._intermediate_table_cache["new_table"] = DuckDBDataFrame(
        "template", "__splink__dummy_frame", db_api
    )

    # a non-SplinkDataFrame should be rejected with a TypeError,
    # and must not make it into the cache
    with pytest.raises(TypeError):
        linker._intermediate_table_cache["not_a_table"] = 30

    for _, table in linker._intermediate_table_cache.items():
        assert isinstance(table, SplinkDataFrame)


# run the test both in and out of debug mode - the cache shouldn't care
@pytest.mark.parametrize("debug_mode", (False, True))
def test_cache_access_df_concat(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        # shouldn't touch the DB if we don't materialise
        pipeline = CTEPipeline()
        pipeline = enqueue_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

        # this should create the table in the db
        compute_df_concat_with_tf(linker, pipeline)
        # NB we don't specify how many times it is called,
        # as that will depend on debug_mode
        mockexecute_sql_pipeline.assert_called()
        # reset the call counter on the mock
        mockexecute_sql_pipeline.reset_mock()

        # this should NOT touch the database, but instead use the cache
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

        # this should also use the cache - concat will just refer to concat_with_tf
        compute_df_concat(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()


@pytest.mark.parametrize("debug_mode", (False, True))
def test_cache_access_compute_tf_table(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        linker.table_management.compute_tf_table("first_name")
        mockexecute_sql_pipeline.assert_called()
        # reset the call counter on the mock
        mockexecute_sql_pipeline.reset_mock()

        linker.table_management.compute_tf_table("first_name")
        mockexecute_sql_pipeline.assert_not_called()


@pytest.mark.parametrize("debug_mode", (False, True))
def test_invalidate_cache(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()

        # this should NOT touch the database, but instead use the cache
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

        # computing a tf table should hit the database:
        linker.table_management.compute_tf_table("surname")
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()
        # then check the cache
        linker.table_management.compute_tf_table("surname")
        mockexecute_sql_pipeline.assert_not_called()

        linker.table_management.invalidate_cache()

        # now we _SHOULD_ compute afresh:
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()
        # but now draw from the cache
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

        # and should compute this again:
        linker.table_management.compute_tf_table("surname")
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()
        # then check the cache
        linker.table_management.compute_tf_table("surname")
        mockexecute_sql_pipeline.assert_not_called()


@pytest.mark.parametrize("debug_mode", (False, True))
def test_cache_invalidates_with_new_linker(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()

        # should use the cache
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

    db_api = DuckDBAPI()
    new_linker = Linker(df, settings, db_api=db_api)
    new_linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        # the new linker should recalculate df_concat_with_tf
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(new_linker, pipeline)
        mockexecute_sql_pipeline.assert_called()
        mockexecute_sql_pipeline.reset_mock()

        # but now read from the cache
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(new_linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        # the original linker should still have the result cached
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()


@pytest.mark.parametrize("debug_mode", (False, True))
def test_cache_register_compute_concat_with_tf_table(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        # we can genuinely register a frame here, as registration itself is
        # not cached. the contents don't matter, so any frame will do
        linker.table_management.register_table_input_nodes_concat_with_tf(df)

        # now this should be cached, as we have registered the table manually
        pipeline = CTEPipeline()
        compute_df_concat_with_tf(linker, pipeline)
        mockexecute_sql_pipeline.assert_not_called()


@pytest.mark.parametrize("debug_mode", (False, True))
def test_cache_register_compute_tf_table(debug_mode):
    settings = get_settings_dict()
    db_api = DuckDBAPI()
    linker = Linker(df, settings, db_api=db_api)
    linker._debug_mode = debug_mode

    with patch.object(
        db_api, "_sql_to_splink_dataframe", new=make_mock_execute(db_api)
    ) as mockexecute_sql_pipeline:
        # we can genuinely register a frame here, as registration itself is
        # not cached. the contents don't matter, so any frame will do
        linker.table_management.register_term_frequency_lookup(df, "first_name")

        # now this should be cached, as we have registered the lookup manually
        linker.table_management.compute_tf_table("first_name")
        mockexecute_sql_pipeline.assert_not_called()
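

# A minimal sketch of the caching behaviour pinned down above, from a user's
# point of view (it assumes the same demo data and settings helper as these
# tests, and only uses calls exercised in them):
#
#     db_api = DuckDBAPI()
#     linker = Linker(df, get_settings_dict(), db_api=db_api)
#     linker.table_management.compute_tf_table("first_name")  # hits the database
#     linker.table_management.compute_tf_table("first_name")  # served from cache
#     linker.table_management.invalidate_cache()
#     linker.table_management.compute_tf_table("first_name")  # recomputed afresh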