|
7 | 7 | from typing import Optional |
8 | 8 | import ast |
9 | 9 | import dataclasses |
10 | | -import glob |
11 | 10 | import json |
12 | 11 | import os |
13 | 12 | import re |
@@ -84,32 +83,44 @@ def load_questions(question_file: str, begin: Optional[int], end: Optional[int]) |
84 | 83 | return questions |
85 | 84 |
|
86 | 85 |
|
def load_model_answers(answer_dir: str, model_name=None, answer_file=None) -> dict:
    """Load model answers.

    When ``answer_file`` is given, only that file is read; ``answer_dir`` and
    ``model_name`` are not consulted, and the result is keyed by the file's
    base name without its extension.

    Otherwise ``answer_dir`` is walked recursively and every ``*.jsonl`` file
    found is loaded. Each file is stored under ``model_name`` when that is
    set, and under the file's base name (without ``.jsonl``) when it is not.
    The walk stops as soon as a file whose base name equals ``model_name``
    has been loaded.

    The return value is a python dict of type:
    Dict[model_name: str -> Dict[question_id: int -> answer: dict]]
    """
    logger.debug(locals())
    model_answers = {}

    if answer_file is not None:
        # splitext (rather than slicing off six characters) keeps the key
        # correct for any extension and matches what get_model_list returns
        # for the same path.
        file_model_name = os.path.splitext(os.path.basename(answer_file))[0]
        model_answers[file_model_name] = _load_answers(answer_file)
        return model_answers

    for root, _, files in os.walk(answer_dir):
        for filename in files:
            if not filename.endswith(".jsonl"):
                continue
            file_model_name = filename[: -len(".jsonl")]
            file_path = os.path.join(root, filename)
            model_answers[model_name or file_model_name] = _load_answers(
                file_path
            )
            if model_name == file_model_name:
                logger.debug("Found answer file matching: %s", model_name)
                # Return instead of break: a break only leaves the inner
                # loop, so files in directories visited later would
                # overwrite the matching answers stored under model_name.
                return model_answers
    return model_answers
111 | 113 |
|
112 | 114 |
|
| 115 | +def _load_answers(answer_file): |
| 116 | + answers = {} |
| 117 | + with open(answer_file, encoding="utf-8") as fin: |
| 118 | + for line in fin: |
| 119 | + l = json.loads(line) |
| 120 | + answers[l["question_id"]] = l |
| 121 | + return answers |
| 122 | + |
| 123 | + |
113 | 124 | def load_judge_prompts(prompt_file: str) -> dict: |
114 | 125 | """Load judge prompts. |
115 | 126 |
|
@@ -304,8 +315,6 @@ def check_data(questions, model_answers, ref_answers, models, judges): |
304 | 315 | ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}" |
305 | 316 |
|
306 | 317 |
|
def get_model_list(answer_file):
    """Return a one-element list holding the model name for ``answer_file``.

    The model name is the file's base name with its final extension removed.
    """
    logger.debug(locals())
    base_name = os.path.basename(answer_file)
    stem, _extension = os.path.splitext(base_name)
    return [stem]
0 commit comments