Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import io
import traceback
from typing import Optional

import orjson
import pandas as pd
Expand Down Expand Up @@ -107,7 +108,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):

@router.post("/recommend_questions/{chat_record_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
current_assistant: CurrentAssistant):
current_assistant: CurrentAssistant, articles_number: Optional[int] = 4):
def _return_empty():
yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n'

Expand All @@ -121,6 +122,7 @@ def _return_empty():

llm_service = await LLMService.create(session, current_user, request_question, current_assistant, True)
llm_service.set_record(record)
llm_service.set_articles_number(articles_number)
llm_service.run_recommend_questions_task_async()
except Exception as e:
traceback.print_exc()
Expand Down
7 changes: 4 additions & 3 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T

def sql_user_question(self, current_time: str, change_title: bool):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title)
rule=self.rule, current_time=current_time, error_msg=self.error_msg,
change_title=change_title)

def chart_sys_question(self):
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
Expand Down Expand Up @@ -240,8 +241,8 @@ def datasource_sys_question(self):
def datasource_user_question(self, datasource_list: str = "[]"):
return get_datasource_template()['user'].format(question=self.question, data=datasource_list)

def guess_sys_question(self):
return get_guess_question_template()['system'].format(lang=self.lang)
def guess_sys_question(self, articles_number: int = 4):
return get_guess_question_template()['system'].format(lang=self.lang, articles_number=articles_number)

def guess_user_question(self, old_questions: str = "[]"):
return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
Expand Down
6 changes: 5 additions & 1 deletion backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class LLMService:
future: Future

last_execute_sql_error: str = None
articles_number: int = 4

def __init__(self, session: Session, current_user: CurrentUser, chat_question: ChatQuestion,
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
Expand Down Expand Up @@ -213,6 +214,9 @@ def get_record(self):
def set_record(self, record: ChatRecord):
self.record = record

def set_articles_number(self, articles_number: int):
self.articles_number = articles_number

def get_fields_from_chart(self, _session: Session):
chart_info = get_chart_config(_session, self.record.id)
return format_chart_fields(chart_info)
Expand Down Expand Up @@ -330,7 +334,7 @@ def generate_recommend_questions_task(self, _session: Session):
embedding=False)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))

old_questions = list(map(lambda q: q.strip(), get_old_questions(_session, self.record.datasource)))
guess_msg.append(
Expand Down
4 changes: 2 additions & 2 deletions backend/templates/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ template:
### 请使用语言:{lang} 回答,不需要输出深度思考过程

### 说明:
您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-4个问题
您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-{articles_number}个问题
请遵循以下规则:
- 推测的问题需要与提供的表结构相关,生成的提问例子如:["查询所有用户数据","使用饼图展示各产品类型的占比","使用折线图展示销售额趋势",...]
- 推测问题如果涉及图形展示,支持的图形类型为:表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie)
Expand All @@ -385,7 +385,7 @@ template:
- 如果用户没有提问且没有以往用户提问,则仅根据提供的表结构推测问题
- 生成的推测问题使用JSON格式返回:
["推测问题1", "推测问题2", "推测问题3", "推测问题4"]
- 最多返回4个你推测出的结果
- 最多返回{articles_number}个你推测出的结果
- 若无法推测,则返回空数据JSON:
[]
- 若你的给出的JSON不是{lang}的,则必须翻译为{lang}
Expand Down
8 changes: 6 additions & 2 deletions frontend/src/api/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,12 @@ export const chatApi = {
predict: (record_id: number | undefined, controller?: AbortController) => {
return request.fetchStream(`/chat/record/${record_id}/predict`, {}, controller)
},
recommendQuestions: (record_id: number | undefined, controller?: AbortController) => {
return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller)
recommendQuestions: (
record_id: number | undefined,
controller?: AbortController,
params: any
) => {
return request.fetchStream(`/chat/recommend_questions/${record_id}${params}`, {}, controller)
},
recentQuestions: (datasource_id?: number): Promise<any> => {
return request.get(`/chat/recent_questions/${datasource_id}`)
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/views/chat/QuickQuestion.vue
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const recommendQuestionRef = ref()
const recentQuestionRef = ref()
const popoverRef = ref()
const getRecommendQuestions = () => {
recommendQuestionRef.value.getRecommendQuestions()
recommendQuestionRef.value.getRecommendQuestions(10)
}

const retrieveQuestions = () => {
Expand Down
5 changes: 3 additions & 2 deletions frontend/src/views/chat/RecommendQuestion.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ function clickQuestion(question: string): void {

const stopFlag = ref(false)

async function getRecommendQuestions() {
async function getRecommendQuestions(articles_number: number) {
stopFlag.value = false
loading.value = true
try {
const controller: AbortController = new AbortController()
const response = await chatApi.recommendQuestions(props.recordId, controller)
const params = articles_number ? '?articles_number=' + articles_number : ''
const response = await chatApi.recommendQuestions(props.recordId, controller, params)
const reader = response.body.getReader()
const decoder = new TextDecoder('utf-8')

Expand Down