-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocess.py
More file actions
75 lines (57 loc) · 1.96 KB
/
preprocess.py
File metadata and controls
75 lines (57 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import json
from transformer.Constants import SQL_SEPARATOR
# Aggregation keywords, looked up by the integer stored under sql['agg'];
# the empty string at index 0 means "no aggregation".
agg_ops = [
    "",
    "MAX",
    "MIN",
    "COUNT",
    "SUM",
    "AVG",
]

# Comparison operators, looked up by the integer in the second slot of a
# condition triple.
cond_ops = [
    "=",
    ">",
    "<",
    "OP",
]

# Module-level state filled in by main() and read by process().
queryset = list()    # raw JSON lines, one query record each
tableset = list()    # raw JSON lines, one table record each
table_dict = dict()  # table id -> list of column headers
def process(sentence_file, sql_file):
    """Emit one question line and one linearised SQL line per query record.

    Walks the module-level ``queryset`` (JSON strings), resolving column
    indices to header names through the module-level ``table_dict``.

    Args:
        sentence_file: Writable text stream; receives the stripped
            ``question`` of every record, one per line.
        sql_file: Writable text stream; receives the SQL tokens of every
            record joined with ``SQL_SEPARATOR``, one record per line.

    Note:
        The ``where`` token is always emitted, followed by the conditions
        joined with ``|AND|`` (an empty string when there are none).
    """
    for raw_record in queryset:
        record = json.loads(raw_record)
        # The question is written before any SQL lookup takes place.
        sentence_file.write(record['question'].strip() + "\n")

        tokens = ["select"]
        aggregate = agg_ops[record['sql']['agg']]
        if aggregate != '':
            tokens.append(aggregate)

        table_id = record['table_id']
        tokens.append(table_dict[table_id][record['sql']['sel']])
        tokens.extend(['from table', 'where'])

        # Each condition is a (column index, operator index, value) triple.
        clauses = [
            SQL_SEPARATOR.join([
                table_dict[table_id][cond[0]],
                cond_ops[cond[1]],
                str(cond[2]),
            ])
            for cond in record['sql']['conds']
        ]
        tokens.append("|AND|".join(clauses))

        sql_file.write(SQL_SEPARATOR.join(tokens) + "\n")
def main():
    """Command-line entry point: load queries and tables, write the outputs.

    Command-line arguments (all required):
        --queries: Path to a file with one JSON query record per line.
        --tables:  Path to a file with one JSON table record per line; each
                   record supplies an ``id`` and a ``header`` list.
        --output:  Output path *without* extension; ``.en`` (questions) and
                   ``.sql`` (linearised SQL) are appended to it.

    Side effects:
        Appends the raw lines to the module-level ``queryset`` and
        ``tableset`` lists, fills ``table_dict`` (table id -> header list),
        and creates/truncates ``<output>.en`` and ``<output>.sql``.

    Raises:
        OSError: If an input file cannot be read or an output file cannot
            be created.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--queries', required=True)
    parser.add_argument('--tables', required=True)
    # path without extension !
    parser.add_argument('--output', required=True)
    opt = parser.parse_args()

    # JSON text is UTF-8; read it explicitly as such instead of relying on
    # the platform's locale encoding. Blank lines are skipped because they
    # would otherwise make json.loads() fail later on.
    with open(opt.queries, encoding="utf-8") as qfile:
        queryset.extend(line.rstrip() for line in qfile if line.strip())

    with open(opt.tables, encoding="utf-8") as tbfile:
        tableset.extend(line.rstrip() for line in tbfile if line.strip())

    for rec in tableset:
        rec_dict = json.loads(rec)
        table_dict[rec_dict['id']] = rec_dict['header']

    # Context managers guarantee both outputs are flushed and closed even
    # when process() raises part-way through the data.
    with open(opt.output + ".en", "w+", encoding="utf-8") as sentence_file, \
            open(opt.output + ".sql", "w+", encoding="utf-8") as sql_file:
        process(sentence_file, sql_file)
# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()