forked from moj-analytical-services/splink
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_columns_used.py
69 lines (56 loc) · 2.11 KB
/
test_columns_used.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from splink.parse_sql import get_columns_used_from_sql
def test_get_columns_used():
    """Check which column names are reported for a range of SQL expressions.

    Every case compares the result as a set, so neither ordering nor
    duplicates in the returned collection affect the outcome.  The cases
    cover: table-qualified column references (with and without
    ``retain_table_prefix=True``), subscript access into a column, and
    lambda expressions whose bound parameters are expected to be absent
    from the result.
    """

    def columns_in(expression, **options):
        # Collapse the result to a set so the comparison is order-insensitive.
        return set(get_columns_used_from_sql(expression, **options))

    # One column is written with a table prefix, the rest are bare.
    name_comparison = """
jaro_winkler_sim(mytable.surname_l, surname_r) > 0.99 or
substr(mytable.surname_l || initial_l ,1,2) = substr(surname_r || initial_r,1,2)
"""
    # Without the flag, the prefixed column is expected back unqualified.
    assert columns_in(name_comparison) == {
        "surname_l",
        "surname_r",
        "initial_l",
        "initial_r",
    }
    # With the flag, the prefix is expected to be kept where it was written.
    assert columns_in(name_comparison, retain_table_prefix=True) == {
        "mytable.surname_l",
        "surname_r",
        "initial_l",
        "initial_r",
    }

    # Subscripting a column with a string key: only the column is expected.
    subscript_difference = """
lat_lng_uncommon_l['lat'] - lat_lng_uncommon_r['lat']
"""
    assert columns_in(subscript_difference) == {
        "lat_lng_uncommon_l",
        "lat_lng_uncommon_r",
    }

    # Single-parameter lambda: ``x`` is expected to be left out of the result.
    transform_lambda = """
transform(latlongexplode(lat_lng_arr_uncommon_l,lat_lng_arr_uncommon_r ),
x -> sin(radians(x['place2']['lat'] - x['place1']['lat'])) )
"""
    assert columns_in(transform_lambda) == {
        "lat_lng_arr_uncommon_l",
        "lat_lng_arr_uncommon_r",
    }

    # Two-parameter lambda: neither ``x`` nor ``y`` is expected in the result.
    aggregate_two_params = "AGGREGATE(cities, 0, (x, y) -> x + length(y))"
    assert columns_in(aggregate_two_params) == {"cities"}

    # Lambda whose parameter is subscripted inside a function call.
    aggregate_subscripted = "AGGREGATE(cities, 0, x -> length(x['a']))"
    assert columns_in(aggregate_subscripted) == {"cities"}

    # Large nested expression with a parenthesised lambda parameter ``(x)``.
    nested_distance_filter = """
ARRAY_MIN(TRANSFORM(LATLONGEXPLODE(lat_lng_arr_uncommon_l, lat_lng_arr_uncommon_r),
(x) -> (CAST(ATAN2(SQRT((POW(SIN(RADIANS(x['place2']['lat'] - x['place1']['lat']))
/ 2, 2) + COS(RADIANS(x['place1']['lat'])) * COS(RADIANS(x['place2']['lat']))
* POW(SIN(RADIANS(x['place2']['long'] - x['place1']['long']) / 2), 2))),
SQRT(-1 * (POW(SIN(RADIANS(x['place2']['lat'] - x['place1']['lat'])) / 2, 2) +
COS(RADIANS(x['place1']['lat'])) * COS(RADIANS(x['place2']['lat'])) *
POW(SIN(RADIANS(x['place2']['long'] - x['place1']['long']) / 2), 2)) + 1))
* 12742 AS FLOAT)))) < 5
"""
    assert columns_in(nested_distance_filter) == {
        "lat_lng_arr_uncommon_l",
        "lat_lng_arr_uncommon_r",
    }