forked from ShayanTalaei/CHESS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset4hf.py
More file actions
99 lines (82 loc) · 3.84 KB
/
dataset4hf.py
File metadata and controls
99 lines (82 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import re
import csv
# Compiled once at import time; the original recompiled all three patterns on
# every call. The literals are byte-identical to the original patterns.
STEP_HEADER_RE = re.compile(r'##############################\s*(Human|AI) at step (.*?)\s*##############################')
TOKEN_COUNT_RE = re.compile(r'######The token count is:\s*(\d+)######')
CONTENT_RE = re.compile(r'######The token count is:\s*\d+######\n\n(.*?)\n\n(?=##############################|$)', re.DOTALL)


def parse_log_file(log_file_path):
    """Parse a CHESS agent log file into a list of step records.

    The log is a sequence of sections, each introduced by a
    ``### ... (Human|AI) at step <name> ...`` banner followed by a token-count
    line and the section content. A ``Human`` banner starts a new record; the
    following ``AI`` banner for the same step name fills in that record's
    output side.

    Args:
        log_file_path: Path to the ``.log`` file to parse.

    Returns:
        list[dict]: One dict per step with keys ``step``, ``input_content``,
        ``output_content``, ``input_length`` and ``output_length``.
    """
    # Explicit encoding: the default is platform-dependent and these logs
    # may contain non-ASCII prompt/response text.
    with open(log_file_path, 'r', encoding='utf-8') as file:
        log_content = file.read()

    steps = []
    headers = list(STEP_HEADER_RE.finditer(log_content))
    for i, header in enumerate(headers):
        role = header.group(1)
        step_name = header.group(2).strip()
        # Each section runs from the end of its banner to the start of the
        # next banner (or to end-of-file for the last section).
        start_index = header.end()
        end_index = headers[i + 1].start() if i + 1 < len(headers) else len(log_content)
        section_text = log_content[start_index:end_index]

        token_match = TOKEN_COUNT_RE.search(section_text)
        tokens = int(token_match.group(1)) if token_match else 0
        content_match = CONTENT_RE.search(section_text)
        content = content_match.group(1).strip() if content_match else ""

        if role == "Human":
            steps.append({
                "step": step_name,
                "input_content": content,
                "output_content": "",
                "input_length": tokens,
                "output_length": 0
            })
        else:  # role == "AI"
            # Pair the AI reply with the immediately preceding Human prompt
            # for the same step; otherwise record it as an output-only step.
            if steps and steps[-1]["step"] == step_name and steps[-1]["output_length"] == 0:
                steps[-1]["output_content"] = content
                steps[-1]["output_length"] = tokens
            else:
                steps.append({
                    "step": step_name,
                    "input_content": "",
                    "output_content": content,
                    "input_length": 0,
                    "output_length": tokens
                })
    return steps
def collect_logs(logs_directory):
    """Parse every target log in *logs_directory* into flat CSV-ready rows.

    Only files whose names end in ``formula_1.log`` or ``financial.log`` are
    parsed. Each file becomes one ``Text2SQLRequest_<n>`` request whose steps
    are flattened into one row apiece.

    Args:
        logs_directory: Directory containing the ``.log`` files.

    Returns:
        list[dict]: One dict per step, keyed to match ``save_to_csv``.
    """
    logs_data = []
    # Sort the matched filenames so request IDs are assigned
    # deterministically; os.listdir() order is arbitrary and can differ
    # between filesystems and runs.
    target_logs = sorted(
        name for name in os.listdir(logs_directory)
        if name.endswith(("formula_1.log", "financial.log"))
    )
    for request_id_counter, log_file in enumerate(target_logs, start=1):
        log_file_path = os.path.join(logs_directory, log_file)
        steps = parse_log_file(log_file_path)
        log_name = f'Text2SQLRequest_{request_id_counter}'
        logs_data.extend({
            "Text2SQLRequest_id": log_name,
            "step_name": step["step"],
            "input_content": step["input_content"],
            "output_content": step["output_content"],
            "input_length": step["input_length"],
            "output_length": step["output_length"]
        } for step in steps)
    return logs_data
def save_to_csv(data, output_file_path):
    """Write step rows to a fully-quoted CSV file.

    Newlines and carriage returns in the content columns are escaped to the
    literal two-character sequences ``\\n`` / ``\\r`` so each record stays on
    a single physical line. Rows are copied before escaping, so the caller's
    dicts in *data* are NOT mutated (the original implementation modified
    them in place).

    Args:
        data: Iterable of dicts carrying the keys in ``fieldnames``.
        output_file_path: Destination CSV path; overwritten if it exists.
    """
    fieldnames = ["Text2SQLRequest_id", "step_name", "input_content", "output_content", "input_length", "output_length"]
    with open(output_file_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_ALL)
        writer.writeheader()
        for row in data:
            # Escape on a shallow copy to avoid side effects on the input.
            escaped = dict(row)
            escaped["input_content"] = escaped["input_content"].replace('\n', '\\n').replace('\r', '\\r')
            escaped["output_content"] = escaped["output_content"].replace('\n', '\\n').replace('\r', '\\r')
            writer.writerow(escaped)
def main():
    """Collect traces from the configured run directories into one CSV.

    The directories below are timestamped CHESS run outputs; any that no
    longer exist are skipped with a warning instead of crashing the whole
    export (os.listdir would raise FileNotFoundError otherwise).
    """
    logs_directories = [
        './results/dev/CHESS_IR_CG_UT/mixed_dev_1/2025-04-08T12:58:08.154132/logs',
        './results/dev/CHESS_IR_CG_UT/financial_dev/2025-04-06T13:35:21.873213/logs'
    ]
    all_logs_data = []
    for directory in logs_directories:
        if not os.path.isdir(directory):
            print(f"Warning: skipping missing logs directory: {directory}")
            continue
        all_logs_data.extend(collect_logs(directory))
    output_file_path = './text2sql_trace.csv'
    save_to_csv(all_logs_data, output_file_path)


if __name__ == "__main__":
    main()