Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
example_answer_2=_example_answer_2,
example_answer_3=_example_answer_3)

def sql_user_question(self, current_time: str):
def sql_user_question(self, current_time: str, change_title: bool):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
rule=self.rule, current_time=current_time, error_msg=self.error_msg)
rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title)

def chart_sys_question(self):
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
Expand Down
46 changes: 34 additions & 12 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def select_datasource(self, _session: Session):
def generate_sql(self, _session: Session):
# append current question
self.sql_message.append(HumanMessage(
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))))
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title)))

self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session,
ai_modal_id=self.chat_question.ai_modal_id,
Expand Down Expand Up @@ -756,6 +756,26 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:

return chart_type

@staticmethod
def get_brief_from_sql_answer(res: str) -> Optional[str]:
json_str = extract_nested_json(res)
if json_str is None:
return None

brief: Optional[str]
data: dict
try:
data = orjson.loads(json_str)

if data['success']:
brief = data['brief']
else:
return None
except Exception:
return None

return brief

def check_save_sql(self, session: Session, res: str) -> str:
sql, *_ = self.check_sql(res=res)
save_sql(session=session, sql=sql, record_id=self.record.id)
Expand Down Expand Up @@ -925,17 +945,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
if not stream:
json_result['record_id'] = self.get_record().id

# return title
if self.change_title:
if self.chat_question.question and self.chat_question.question.strip() != '':
brief = rename_chat(session=_session,
rename_object=RenameChat(id=self.get_record().chat_id,
brief=self.chat_question.question.strip()[:20]))
if in_chat:
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
if not stream:
json_result['title'] = brief

# select datasource if datasource is none
if not self.ds:
ds_res = self.select_datasource(_session)
Expand Down Expand Up @@ -981,6 +990,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True,

chart_type = self.get_chart_type_from_sql_answer(full_sql_text)

# return title
if self.change_title:
llm_brief = self.get_brief_from_sql_answer(full_sql_text)
if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''):
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20]
brief = rename_chat(session=_session,
rename_object=RenameChat(id=self.get_record().chat_id,
brief=save_brief))
if in_chat:
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
if not stream:
json_result['title'] = brief

use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types
is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4
dynamic_sql_result = None
Expand Down
13 changes: 10 additions & 3 deletions backend/templates/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ template:
<step>4. 应用其他规则(引号、别名等)</step>
<step>5. <strong>强制检查:检查语法是否正确?</strong></step>
<step>6. 确定图表类型</step>
<step>7. 返回JSON结果</step>
<step>7. 确定对话标题</step>
<step>8. 返回JSON结果</step>
</SQL-Generation-Process>
query_limit: |
<rule priority="critical" id="data-limit-policy">
Expand All @@ -41,7 +42,7 @@ template:
system: |
<Instruction>
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。
你当前的任务是根据给定的表结构和用户问题生成SQL语句、对话标题、可能适合展示的图表类型以及该SQL中所用到的表名。
我们会在<Info>块内提供给你信息,帮助你生成SQL:
<Info>内有<db-engine><m-schema><terminologies>等信息;
其中,<db-engine>:提供数据库引擎及版本信息;
Expand Down Expand Up @@ -72,7 +73,7 @@ template:
</rule>
<rule>
请使用JSON格式返回你的回答:
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}}
若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}}
</rule>
<rule>
Expand Down Expand Up @@ -112,6 +113,9 @@ template:
<rule>
我们目前的情况适用于单指标、多分类的场景(展示table除外)
</rule>
<rule>
是否生成对话标题在<change-title>内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内
</rule>
</Rules>

{process_check}
Expand Down Expand Up @@ -251,6 +255,9 @@ template:
<user-question>
{question}
</user-question>
<change-title>
{change_title}
</change-title>

chart:
system: |
Expand Down