2 changes: 1 addition & 1 deletion backend/alembic/env.py
@@ -24,7 +24,7 @@

# from apps.system.models.user import SQLModel # noqa
# from apps.settings.models.setting_models import SQLModel
# from apps.chat.models.chat_model import SQLModel
from apps.chat.models.chat_model import SQLModel
from apps.terminology.models.terminology_model import SQLModel
#from apps.custom_prompt.models.custom_prompt_model import SQLModel
from apps.data_training.models.data_training_model import SQLModel
29 changes: 29 additions & 0 deletions backend/alembic/versions/054_update_chat_record_dll.py
@@ -0,0 +1,29 @@
"""054_update_chat_record_dll

Revision ID: 24e961f6326b
Revises: 5755c0b95839
Create Date: 2025-12-04 15:51:42.900778

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '24e961f6326b'
down_revision = '5755c0b95839'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('chat_record', sa.Column('regenerate_record_id', sa.BigInteger(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('chat_record', 'regenerate_record_id')
# ### end Alembic commands ###
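
The new regenerate_record_id column lands via this revision. A minimal sketch of applying it programmatically with Alembic's command API, assuming the standard alembic.ini lives under backend/ (the path is an assumption; running alembic upgrade head from backend/ does the same):

```python
# Minimal sketch: apply pending migrations up to this revision programmatically.
# Assumes backend/alembic.ini points at the project's database; adjust the path as needed.
from alembic import command
from alembic.config import Config

cfg = Config("backend/alembic.ini")   # hypothetical location
command.upgrade(cfg, "24e961f6326b")  # or "head" to apply every pending revision
```
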
110 changes: 98 additions & 12 deletions backend/apps/chat/api/chat.py
@@ -12,9 +12,10 @@
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand
from apps.chat.task.llm import LLMService
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
from common.utils.command_utils import parse_quick_command
from common.utils.data_format import DataFormat

router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -141,20 +142,99 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, da
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)


def find_base_question(record_id: int, session: SessionDep):
stmt = select(ChatRecord.question, ChatRecord.regenerate_record_id).where(
and_(ChatRecord.id == record_id))
_record = session.execute(stmt).fetchone()
if not _record:
raise Exception(f'Cannot find base chat record')
rec_question, rec_regenerate_record_id = _record
if rec_regenerate_record_id:
return find_base_question(rec_regenerate_record_id, session)
else:
return rec_question
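
find_base_question follows the regenerate_record_id chain upward until it reaches a record that is not itself a regeneration and returns that record's question. A self-contained illustration of the traversal with hypothetical in-memory records (no database involved):

```python
# Hypothetical records: id -> (question, regenerate_record_id)
records = {
    1: ("list monthly sales", None),  # original question
    2: ("regenerated answer A", 1),   # regeneration of record 1
    3: ("regenerated answer B", 2),   # regeneration of record 2
}

def base_question(record_id: int) -> str:
    # Walk parent links until a record with no regenerate_record_id is found.
    question, parent_id = records[record_id]
    return base_question(parent_id) if parent_id else question

assert base_question(3) == "list monthly sales"
```
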


@router.post("/question")
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: CurrentAssistant):
try:
command, text_before_command, record_id, warning_info = parse_quick_command(request_question.question)
if command:
# todo: analysis and predict are not supported here yet; the frontend needs to be updated first
if command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA:
raise Exception(f'Command: {command.value} is temporarily not supported')

if record_id is not None:
# exclude analysis and predict records
stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.analysis_record_id,
ChatRecord.predict_record_id, ChatRecord.regenerate_record_id,
ChatRecord.first_chat).where(
and_(ChatRecord.id == record_id))
_record = session.execute(stmt).fetchone()
if not _record:
raise Exception(f'Record id: {record_id} does not exist')

rec_id, rec_chat_id, rec_analysis_record_id, rec_predict_record_id, rec_regenerate_record_id, rec_first_chat = _record

if rec_chat_id != request_question.chat_id:
raise Exception(f'Record id: {record_id} does not belong to this chat')
if rec_first_chat:
raise Exception(f'Record id: {record_id} does not support this operation')

if command == QuickCommand.REGENERATE:
if rec_analysis_record_id:
raise Exception('Analysis record does not support this operation')
if rec_predict_record_id:
raise Exception('Predict data record does not support this operation')

else: # get last record id
stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.regenerate_record_id).where(
and_(ChatRecord.chat_id == request_question.chat_id,
ChatRecord.first_chat == False,
ChatRecord.analysis_record_id.is_(None),
ChatRecord.predict_record_id.is_(None))).order_by(
ChatRecord.create_time.desc()).limit(1)
_record = session.execute(stmt).fetchone()

if not _record:
raise Exception('You have not asked any question yet')

rec_id, rec_chat_id, rec_regenerate_record_id = _record

# if no record id was specified, fall back to the most recent record
if not rec_regenerate_record_id:
rec_regenerate_record_id = rec_id

# if the target is itself a regenerated question, trace back to the original question
base_question_text = find_base_question(rec_regenerate_record_id, session)
text_before_command = text_before_command + ("\n" if text_before_command else "") + base_question_text

if command == QuickCommand.REGENERATE:
request_question.question = text_before_command
request_question.regenerate_record_id = rec_id
return await stream_sql(session, current_user, request_question, current_assistant)

elif command == QuickCommand.ANALYSIS:
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant)

elif command == QuickCommand.PREDICT_DATA:
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant)
else:
raise Exception(f'Unknown command: {command.value}')
else:
return await stream_sql(session, current_user, request_question, current_assistant)
except Exception as e:
traceback.print_exc()

def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

return StreamingResponse(_err(e), media_type="text/event-stream")
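
parse_quick_command comes from common.utils.command_utils and is not part of this diff; the call site shows it returning (command, text_before_command, record_id, warning_info). The sketch below is an illustrative stand-in for that contract only, not the real helper, and the exact command syntax (such as an optional trailing record id) is an assumption:

```python
import re
from typing import Optional, Tuple

from apps.chat.models.chat_model import QuickCommand

COMMANDS = {c.value: c for c in QuickCommand}  # '/regenerate', '/analysis', '/predict'

def parse_quick_command_sketch(question: str) -> Tuple[Optional[QuickCommand], str, Optional[int], Optional[str]]:
    # Look for a trailing slash command with an optional numeric record id, e.g. "... /regenerate 42".
    match = re.search(r"(/[a-z_]+)(?:\s+#?(\d+))?\s*$", question)
    if not match or match.group(1) not in COMMANDS:
        return None, question, None, None
    command = COMMANDS[match.group(1)]
    record_id = int(match.group(2)) if match.group(2) else None
    text_before = question[:match.start()].rstrip()
    return command, text_before, record_id, None  # warning_info is unused in this sketch
```
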


async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: CurrentAssistant):
"""Stream SQL analysis results

Args:
session: Database session
current_user: Current authenticated user
request_question: User question model
current_assistant: Current assistant context

Returns:
Streaming response with analysis results
"""

try:
llm_service = await LLMService.create(session, current_user, request_question, current_assistant,
embedding=True)
@@ -172,6 +252,12 @@ def _err(_e: Exception):


@router.post("/record/{chat_record_id}/{action_type}")
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
action_type: str,
current_assistant: CurrentAssistant):
return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)


async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
current_assistant: CurrentAssistant):
try:
22 changes: 16 additions & 6 deletions backend/apps/chat/curd/chat.py
@@ -1,16 +1,16 @@
import datetime
from typing import List
from sqlalchemy import desc, func

import orjson
import sqlparse
from sqlalchemy import and_, select, update
from sqlalchemy import desc, func
from sqlalchemy.orm import aliased

from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \
TypeEnum, OperationEnum, ChatRecordResult
from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart
from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem
from apps.datasource.crud.recommended_problem import get_datasource_recommended_chart
from apps.datasource.models.datasource import CoreDatasource
from apps.system.crud.assistant import AssistantOutDsFactory
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
from common.utils.utils import extract_nested_json
@@ -28,11 +28,13 @@ def get_chat_record_by_id(session: SessionDep, record_id: int):
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by)
return record


def get_chat(session: SessionDep, chat_id: int) -> Chat:
statement = select(Chat).where(Chat.id == chat_id)
chat = session.exec(statement).scalars().first()
return chat


def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
oid = current_user.oid if current_user.oid is not None else 1
chart_list = session.query(Chat).filter(and_(Chat.create_by == current_user.id, Chat.oid == oid)).order_by(
@@ -57,6 +59,7 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso
)
return [record[0] for record in chat_records] if chat_records else []


def rename_chat(session: SessionDep, rename_object: RenameChat) -> str:
chat = session.get(Chat, rename_object.id)
if not chat:
@@ -191,7 +194,8 @@ def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_


def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser,
current_assistant: CurrentAssistant, with_data: bool = False,trans: Trans = None) -> ChatInfo:
current_assistant: CurrentAssistant, with_data: bool = False,
trans: Trans = None) -> ChatInfo:
chat = session.get(Chat, chart_id)
if not chat:
raise Exception(f"Chat with id {chart_id} not found")
@@ -200,7 +204,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr

if current_assistant and current_assistant.type in dynamic_ds_types:
out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant)
ds = out_ds_instance.get_ds(chat.datasource,trans)
ds = out_ds_instance.get_ds(chat.datasource, trans)
else:
ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None

@@ -221,6 +225,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
ChatRecord.regenerate_record_id,
ChatRecord.recommended_question, ChatRecord.first_chat,
ChatRecord.finish, ChatRecord.error,
sql_alias_log.reasoning_content.label('sql_reasoning_content'),
@@ -247,6 +252,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
ChatRecord.regenerate_record_id,
ChatRecord.recommended_question, ChatRecord.first_chat,
ChatRecord.finish, ChatRecord.error, ChatRecord.data, ChatRecord.predict_data).where(
and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(
@@ -264,6 +270,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
analysis=row.analysis, predict=row.predict,
datasource_select_answer=row.datasource_select_answer,
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
regenerate_record_id=row.regenerate_record_id,
recommended_question=row.recommended_question, first_chat=row.first_chat,
finish=row.finish, error=row.error,
sql_reasoning_content=row.sql_reasoning_content,
@@ -280,6 +287,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
analysis=row.analysis, predict=row.predict,
datasource_select_answer=row.datasource_select_answer,
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
regenerate_record_id=row.regenerate_record_id,
recommended_question=row.recommended_question, first_chat=row.first_chat,
finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data))

@@ -347,8 +355,9 @@ def format_record(record: ChatRecordResult):

return _dict


def get_chat_brief_generate(session: SessionDep, chat_id: int):
chat = get_chat(session=session,chat_id=chat_id)
chat = get_chat(session=session, chat_id=chat_id)
if chat is not None and chat.brief_generate is not None:
return chat.brief_generate
else:
@@ -468,6 +477,7 @@ def save_question(session: SessionDep, current_user: CurrentUser, question: Chat
record.datasource = chat.datasource
record.engine_type = chat.engine_type
record.ai_modal_id = question.ai_modal_id
record.regenerate_record_id = question.regenerate_record_id

result = ChatRecord(**record.model_dump())

16 changes: 14 additions & 2 deletions backend/apps/chat/models/chat_model.py
@@ -48,6 +48,12 @@ class ChatFinishStep(Enum):
GENERATE_CHART = 3


class QuickCommand(Enum):
REGENERATE = '/regenerate'
ANALYSIS = '/analysis'
PREDICT_DATA = '/predict'


# TODO choose table / check connection / generate description

class ChatLog(SQLModel, table=True):
@@ -78,7 +84,7 @@ class Chat(SQLModel, table=True):
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
engine_type: str = Field(max_length=64)
origin: Optional[int] = Field(
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
brief_generate: bool = Field(default=False)


@@ -110,6 +116,7 @@ class ChatRecord(SQLModel, table=True):
error: str = Field(sa_column=Column(Text, nullable=True))
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
regenerate_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))


class ChatRecordResult(BaseModel):
@@ -134,6 +141,7 @@ class ChatRecordResult(BaseModel):
error: Optional[str] = None
analysis_record_id: Optional[int] = None
predict_record_id: Optional[int] = None
regenerate_record_id: Optional[int] = None
sql_reasoning_content: Optional[str] = None
chart_reasoning_content: Optional[str] = None
analysis_reasoning_content: Optional[str] = None
Expand Down Expand Up @@ -184,6 +192,7 @@ class AiModelQuestion(BaseModel):
data_training: str = ""
custom_prompt: str = ""
error_msg: str = ""
regenerate_record_id: Optional[int] = None

def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
_sql_template = get_sql_example_template(db_type)
Expand Down Expand Up @@ -213,7 +222,10 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
example_answer_3=_example_answer_3)

def sql_user_question(self, current_time: str, change_title: bool):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
_question = self.question
if self.regenerate_record_id:
_question = get_sql_template()['regenerate_hint'] + self.question
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=_question,
rule=self.rule, current_time=current_time, error_msg=self.error_msg,
change_title=change_title)

10 changes: 10 additions & 0 deletions backend/apps/chat/task/llm.py
@@ -167,6 +167,11 @@ def is_running(self, timeout=0.5):
def init_messages(self):
last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len(
self.generate_sql_logs) > 0 else []
if self.chat_question.regenerate_record_id:
# reuse the message history of the record being regenerated
_temp_log = next(
filter(lambda obj: obj.pid == self.chat_question.regenerate_record_id, self.generate_sql_logs), None)
last_sql_messages: List[dict[str, Any]] = _temp_log.messages if _temp_log else []

# todo maybe can configure
count_limit = 0 - base_message_count_limit
@@ -947,6 +952,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
# return id
if in_chat:
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
if self.get_record().regenerate_record_id:
yield 'data:' + orjson.dumps({'type': 'regenerate_record_id',
'regenerate_record_id': self.get_record().regenerate_record_id}).decode() + '\n\n'
yield 'data:' + orjson.dumps(
{'type': 'question', 'question': self.get_record().question}).decode() + '\n\n'
if not stream:
json_result['record_id'] = self.get_record().id
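
With these changes, run_task emits two additional SSE events for a regenerated answer, a 'regenerate_record_id' event and a 'question' event carrying the resolved question, ahead of the regular stream. A minimal client sketch for consuming them; the base URL and authentication details are assumptions and are omitted here:

```python
import json
import httpx

payload = {"chat_id": 1, "question": "give me monthly sales /regenerate"}
# POST to the /chat/question endpoint and read the text/event-stream response line by line.
with httpx.stream("POST", "http://localhost:8000/chat/question", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):])
        if event.get("type") == "regenerate_record_id":
            print("regenerating record:", event["regenerate_record_id"])
        elif event.get("type") == "question":
            print("resolved question:", event["question"])
```
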
