Skip to content

Commit ff54038

Browse files
authored
Merge pull request #110 from danmcp/singleanswerfile
Use single answer file and model list
2 parents d272c80 + 39b6960 commit ff54038

File tree

3 files changed: +33 additions, −30 deletions

.github/workflows/e2e.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
- name: Run e2e test
9595
run: |
9696
. venv/bin/activate
97-
./instructlab/scripts/basic-workflow-tests.sh -cm
97+
./instructlab/scripts/basic-workflow-tests.sh -m
9898
9999
- name: Remove llama-cpp-python from cache
100100
if: always()

src/instructlab/eval/mt_bench_common.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Optional
88
import ast
99
import dataclasses
10-
import glob
1110
import json
1211
import os
1312
import re
@@ -84,32 +83,44 @@ def load_questions(question_file: str, begin: Optional[int], end: Optional[int])
8483
return questions
8584

8685

87-
def load_model_answers(answer_dir: str, model_name=None) -> dict:
86+
def load_model_answers(answer_dir: str, model_name=None, answer_file=None) -> dict:
8887
"""Load model answers.
8988
9089
The return value is a python dict of type:
9190
Dict[model_name: str -> Dict[question_id: int -> answer: dict]]
9291
"""
9392
logger.debug(locals())
9493
model_answers = {}
95-
for root, _, files in os.walk(answer_dir):
96-
for filename in files:
97-
if filename.endswith(".jsonl"):
98-
# Removing ".jsonl"
99-
file_model_name = filename[:-6]
100-
answer = {}
101-
file_path = os.path.join(root, filename)
102-
with open(file_path, encoding="utf-8") as fin:
103-
for line in fin:
104-
l = json.loads(line)
105-
answer[l["question_id"]] = l
106-
model_answers[model_name or file_model_name] = answer
107-
if model_name == file_model_name:
108-
logger.debug("Found answer file matching: %s", model_name)
109-
break
94+
if answer_file is not None:
95+
filename = os.path.basename(answer_file)
96+
# Removing ".jsonl"
97+
file_model_name = filename[:-6]
98+
model_answers[file_model_name] = _load_answers(answer_file)
99+
else:
100+
for root, _, files in os.walk(answer_dir):
101+
for filename in files:
102+
if filename.endswith(".jsonl"):
103+
# Removing ".jsonl"
104+
file_model_name = filename[:-6]
105+
file_path = os.path.join(root, filename)
106+
model_answers[model_name or file_model_name] = _load_answers(
107+
file_path
108+
)
109+
if model_name == file_model_name:
110+
logger.debug("Found answer file matching: %s", model_name)
111+
break
110112
return model_answers
111113

112114

115+
def _load_answers(answer_file):
116+
answers = {}
117+
with open(answer_file, encoding="utf-8") as fin:
118+
for line in fin:
119+
l = json.loads(line)
120+
answers[l["question_id"]] = l
121+
return answers
122+
123+
113124
def load_judge_prompts(prompt_file: str) -> dict:
114125
"""Load judge prompts.
115126
@@ -304,8 +315,6 @@ def check_data(questions, model_answers, ref_answers, models, judges):
304315
), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}"
305316

306317

307-
def get_model_list(answer_dir):
318+
def get_model_list(answer_file):
308319
logger.debug(locals())
309-
file_paths = glob.glob(f"{answer_dir}/*.jsonl")
310-
file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths]
311-
return file_names
320+
return [os.path.splitext(os.path.basename(answer_file))[0]]

src/instructlab/eval/mt_bench_judgment.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def judge_model(
155155
bench_name="mt_bench",
156156
output_dir="eval_output",
157157
data_dir=None,
158-
model_list=None,
159158
max_workers=1,
160159
first_n=None,
161160
merge_system_user_message=False,
@@ -180,7 +179,7 @@ def judge_model(
180179
questions = load_questions(question_file, None, None)
181180

182181
# Load answers
183-
model_answers = load_model_answers(answer_dir)
182+
model_answers = load_model_answers(answer_dir, answer_file=answer_file)
184183
ref_answers = load_model_answers(ref_answer_dir, judge_model_name)
185184

186185
# Load judge
@@ -189,10 +188,7 @@ def judge_model(
189188
if first_n:
190189
questions = questions[:first_n]
191190

192-
if model_list is None:
193-
models = get_model_list(answer_dir)
194-
else:
195-
models = model_list
191+
models = get_model_list(answer_file)
196192

197193
judges = make_judge_single(judge_model_name, judge_prompts)
198194
output_file = f"{output_base_dir}/model_judgment/{judge_model_name}_single.jsonl"
@@ -280,7 +276,6 @@ def generate_judgment(
280276
output_dir="eval_output",
281277
data_dir=None,
282278
branch=None,
283-
model_list=None,
284279
max_workers=1,
285280
first_n=None,
286281
merge_system_user_message=False,
@@ -302,7 +297,6 @@ def generate_judgment(
302297
output_dir=output_dir,
303298
data_dir=data_dir,
304299
branch=branch,
305-
model_list=model_list,
306300
max_workers=max_workers,
307301
first_n=first_n,
308302
merge_system_user_message=merge_system_user_message,

0 commit comments

Comments (0)