@@ -21,72 +21,57 @@ def build_prompt(
2121
@staticmethod
@abstractmethod
def parse_response(response: str) -> list[dict]:
    """Parse the LLM response and return the generated QAs.

    Subclasses must override this hook; ``generate`` hands it the raw text
    returned by the LLM client and returns the parsed list unchanged.

    :param response: raw text produced by the LLM for a single prompt
    :return: a list with one dict per generated QA
        (presumably carrying "question"/"answer" and optionally "options"
        keys, as read by ``format_generation_results`` -- confirm per subclass)
    """
2626
2727 async def generate (
2828 self ,
2929 batch : tuple [
3030 list [tuple [str , dict ]], list [tuple [Any , Any , dict ] | tuple [Any , Any , Any ]]
3131 ],
32- ) -> dict [ str , Any ]:
32+ ) -> list [ dict ]:
3333 """
3434 Generate QAs based on a given batch.
3535 :param batch
3636 :return: QA pairs
3737 """
38- result = {}
3938 prompt = self .build_prompt (batch )
4039 response = await self .llm_client .generate_answer (prompt )
4140 qa_pairs = self .parse_response (response ) # generate one or more QA pairs
42- result .update (qa_pairs )
43- return result
41+ return qa_pairs
4442
@staticmethod
def format_generation_results(
    result: dict, output_data_format: str
) -> dict[str, Any]:
    """Convert one QA record into the requested training-data layout.

    :param result: a single QA dict; "question" and "answer" default to ""
        when absent, and a non-empty "options" mapping is appended to the
        question as a sorted "key. value" list
    :param output_data_format: one of "Alpaca", "Sharegpt" or "ChatML"
    :return: the record reshaped for the chosen format
    :raises ValueError: if ``output_data_format`` is not a known format
    """
    question = result.get("question", "")
    answer = result.get("answer", "")

    # Multiple-choice records: list the options under the question,
    # ordered by option key so the output is deterministic.
    options = result.get("options")
    if options:
        options_str = "\n".join(
            f"{key}. {options[key]}" for key in sorted(options)
        )
        question += f"\nOptions:\n{options_str}"

    if output_data_format == "Alpaca":
        return {
            "instruction": question,
            "input": "",
            "output": answer,
        }

    if output_data_format == "Sharegpt":
        return {
            "conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": answer},
            ]
        }

    if output_data_format == "ChatML":
        return {
            "messages": [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
        }

    raise ValueError(f"Unknown output data format: {output_data_format}")
0 commit comments