Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def get_chat_record_by_id(session: SessionDep, record_id: int):
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by)
return record

def get_chat(session: SessionDep, chat_id: int) -> Chat:
statement = select(Chat).where(Chat.id == chat_id)
chat = session.exec(statement).scalars().first()
return chat

def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
oid = current_user.oid if current_user.oid is not None else 1
Expand Down Expand Up @@ -57,6 +61,7 @@ def rename_chat(session: SessionDep, rename_object: RenameChat) -> str:
raise Exception(f"Chat with id {rename_object.id} not found")

chat.brief = rename_object.brief.strip()[:20]
chat.brief_generate = rename_object.brief_generate
session.add(chat)
session.flush()
session.refresh(chat)
Expand Down Expand Up @@ -340,6 +345,13 @@ def format_record(record: ChatRecordResult):

return _dict

def get_chat_brief_generate(session: SessionDep, chat_id: int):
chat = get_chat(session=session,chat_id=chat_id)
if chat is not None and chat.brief_generate is not None:
return chat.brief_generate
else:
return False


def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]:
stmt = select(ChatLog).where(
Expand Down
4 changes: 3 additions & 1 deletion backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class Chat(SQLModel, table=True):
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
engine_type: str = Field(max_length=64)
origin: Optional[int] = Field(
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
brief_generate: bool = Field(default=False)


class ChatRecord(SQLModel, table=True):
Expand Down Expand Up @@ -149,6 +150,7 @@ class CreateChat(BaseModel):
class RenameChat(BaseModel):
id: int = None
brief: str = ''
brief_generate: bool = True


class ChatInfo(BaseModel):
Expand Down
20 changes: 11 additions & 9 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
save_select_datasource_answer, save_recommend_question_answer, \
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error, format_json_data, format_chart_fields
get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep, AxisObj
from apps.data_training.curd.data_training import get_training_template
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)

self.change_title = len(self.generate_sql_logs) == 0
self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id)

chat_question.lang = get_lang_name(current_user.language)

Expand Down Expand Up @@ -528,7 +528,8 @@ def select_datasource(self, _session: Session):
def generate_sql(self, _session: Session):
# append current question
self.sql_message.append(HumanMessage(
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title)))
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
change_title=self.change_title)))

self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session,
ai_modal_id=self.chat_question.ai_modal_id,
Expand Down Expand Up @@ -997,11 +998,13 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
# return title
if self.change_title:
llm_brief = self.get_brief_from_sql_answer(full_sql_text)
if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''):
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20]
llm_brief_generated = bool(llm_brief)
if llm_brief_generated or (self.chat_question.question and self.chat_question.question.strip() != ''):
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[
:20]
brief = rename_chat(session=_session,
rename_object=RenameChat(id=self.get_record().chat_id,
brief=save_brief))
brief=save_brief, brief_generate=llm_brief_generated))
if in_chat:
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
if not stream:
Expand Down Expand Up @@ -1084,7 +1087,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
for field in result.get('fields'):
_column_list.append(AxisObj(name=field, value=field))

md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data'))
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list,
result.get('data'))

# data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))

Expand Down Expand Up @@ -1203,8 +1207,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
self.finish(_session)
session_maker.remove()



def run_recommend_questions_task_async(self):
self.future = executor.submit(self.run_recommend_questions_task_cache)

Expand Down