forked from salesforce/decaNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_to_logical_forms.py
81 lines (74 loc) · 3.15 KB
/
convert_to_logical_forms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
from text.torchtext.datasets.generic import Query
from argparse import ArgumentParser
import os
import re
import ujson as json
from metrics import to_lf
def correct_format(x):
    """Wrap a logical form in the ``{'query': ..., 'error': ...}`` record.

    A logical form is accepted when it is non-empty, has a ``'conds'``
    entry, and every condition in ``'conds'`` has at least three elements.

    Args:
        x: Dict describing a logical form; an empty dict marks a failed
            conversion.

    Returns:
        ``{'query': x, 'error': ''}`` when ``x`` is accepted, otherwise
        ``{'query': None, 'error': 'Invalid'}``.
    """
    invalid = {'query': None, 'error': 'Invalid'}
    # A missing 'conds' entry is reported as invalid rather than raising
    # KeyError, so malformed input always yields a well-formed record.
    if not x or 'conds' not in x:
        return invalid
    if any(len(cond) < 3 for cond in x['conds']):
        return invalid
    return {'query': x, 'error': ''}
def write_logical_forms(greedy, args):
    """Convert predictions to logical forms and append them to ``args.output``.

    Reads ``dev.jsonl``/``dev.tables.jsonl`` from
    ``<args.data>/wikisql/data`` when ``'valid'`` occurs in
    ``args.evaluate``, otherwise ``test.jsonl``/``test.tables.jsonl``.
    Each prediction is paired positionally with the example on the same
    line of the split file and converted with ``to_lf`` against that
    example's table.

    Exactly one JSON line is appended per (prediction, example) pair, in
    the format produced by ``correct_format``. Pairing uses ``zip``, so
    surplus predictions or surplus examples are ignored.

    Args:
        greedy: Iterable of predictions, in the same order as the examples
            in the split file.
        args: Namespace providing ``data``, ``evaluate`` and ``output``.

    Returns:
        None. The only effect is appending to ``args.output``.
    """
    data_dir = os.path.join(args.data, 'wikisql', 'data')
    split = 'dev' if 'valid' in args.evaluate else 'test'
    path = os.path.join(data_dir, split + '.jsonl')
    table_path = os.path.join(data_dir, split + '.tables.jsonl')

    id_to_tables = {}
    with open(table_path) as tables_file:
        for line in tables_file:
            table = json.loads(line)
            id_to_tables[table['id']] = table

    # Only the table of each example is needed to build the logical form.
    with open(path) as example_file:
        example_tables = [
            id_to_tables[json.loads(line)['table_id']]
            for line in example_file
        ]

    with open(args.output, 'a') as f:
        for prediction, table in zip(greedy, example_tables):
            # Build the whole output line before writing, so a failure can
            # never leave two lines behind for a single example.
            try:
                record = json.dumps(correct_format(to_lf(prediction, table)))
            except Exception:
                # Best effort: whatever goes wrong while converting a
                # prediction, the example is recorded as invalid.
                record = json.dumps(correct_format({}))
            f.write(record + '\n')
if __name__ == '__main__':
    # Command-line entry point: read the predictions, put them in dataset
    # order using the ids file, and write one logical form per line.
    parser = ArgumentParser()
    parser.add_argument('data', help='path to the directory containing data for WikiSQL')
    parser.add_argument('predictions', help='path to prediction file, containing one prediction per line')
    parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predictions\'')
    parser.add_argument('output', help='path for logical forms output line by line')
    parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set')
    args = parser.parse_args()

    with open(args.predictions) as f:
        greedy = list(f)
    if args.ids is not None:
        with open(args.ids) as f:
            ids = [int(line.strip()) for line in f]
        # write_logical_forms pairs predictions with examples by position,
        # so order the predictions by their dataset index first.
        # NOTE(review): this presumes the ids cover the split contiguously
        # from its first example -- confirm against the prediction writer.
        greedy = [g for _, g in sorted(zip(ids, greedy))]
    write_logical_forms(greedy, args)