Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 84 additions & 27 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from fastapi import APIRouter, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, select
from starlette.responses import JSONResponse

from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
ChatInfo, Chat
ChatInfo, Chat, ChatFinishStep
from apps.chat.task.llm import LLMService
from apps.swagger.i18n import PLACEHOLDER_PREFIX
from apps.system.schemas.permission import SqlbotPermission, require_permissions
Expand Down Expand Up @@ -166,11 +167,18 @@ def find_base_question(record_id: int, session: SessionDep):
@require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id"))
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: CurrentAssistant):
return await question_answer_inner(session, current_user, request_question, current_assistant, embedding=True)


async def question_answer_inner(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True,
stream: bool = True,
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False):
try:
command, text_before_command, record_id, warning_info = parse_quick_command(request_question.question)
if command:
# todo 暂不支持分析和预测,需要改造前端
if command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA:
# todo 对话界面下,暂不支持分析和预测,需要改造前端
if in_chat and (command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA):
raise Exception(f'Command: {command.value} temporary not supported')

if record_id is not None:
Expand Down Expand Up @@ -221,53 +229,83 @@ async def question_answer(session: SessionDep, current_user: CurrentUser, reques
if command == QuickCommand.REGENERATE:
request_question.question = text_before_command
request_question.regenerate_record_id = rec_id
return await stream_sql(session, current_user, request_question, current_assistant)
return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream,
finish_step, embedding)

elif command == QuickCommand.ANALYSIS:
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant)
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant, in_chat, stream)

elif command == QuickCommand.PREDICT_DATA:
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant)
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant, in_chat, stream)
else:
raise Exception(f'Unknown command: {command.value}')
else:
return await stream_sql(session, current_user, request_question, current_assistant)
return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream,
finish_step, embedding)
except Exception as e:
traceback.print_exc()

def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
if stream:
def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

return StreamingResponse(_err(e), media_type="text/event-stream")
return StreamingResponse(_err(e), media_type="text/event-stream")
else:
return JSONResponse(
content={'message': str(e)},
status_code=500,
)


async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: CurrentAssistant):
current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True, stream: bool = True,
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False):
try:
llm_service = await LLMService.create(session, current_user, request_question, current_assistant,
embedding=True)
embedding=embedding)
llm_service.init_record(session=session)
llm_service.run_task_async()
llm_service.run_task_async(in_chat=in_chat, stream=stream, finish_step=finish_step)
except Exception as e:
traceback.print_exc()

def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

return StreamingResponse(_err(e), media_type="text/event-stream")
if stream:
def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
return StreamingResponse(_err(e), media_type="text/event-stream")
else:
return JSONResponse(
content={'message': str(e)},
status_code=500,
)
if stream:
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
else:
res = llm_service.await_result()
raw_data = {}
for chunk in res:
if chunk:
raw_data = chunk
status_code = 200
if not raw_data.get('success'):
status_code = 500

return JSONResponse(
content=raw_data,
status_code=status_code,
)


@router.post("/record/{chat_record_id}/{action_type}", summary=f"{PLACEHOLDER_PREFIX}analysis_or_predict")
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser,
current_assistant: CurrentAssistant, chat_record_id: int,
action_type: str = Path(..., description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
action_type: str = Path(...,
description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)


async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
current_assistant: CurrentAssistant):
current_assistant: CurrentAssistant, in_chat: bool = True, stream: bool = True):
try:
if action_type != 'analysis' and action_type != 'predict':
raise Exception(f"Type {action_type} Not Found")
Expand All @@ -294,16 +332,35 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question)

llm_service = await LLMService.create(session, current_user, request_question, current_assistant)
llm_service.run_analysis_or_predict_task_async(session, action_type, record)
llm_service.run_analysis_or_predict_task_async(session, action_type, record, in_chat, stream)
except Exception as e:
traceback.print_exc()
if stream:
def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

def _err(_e: Exception):
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

return StreamingResponse(_err(e), media_type="text/event-stream")

return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
return StreamingResponse(_err(e), media_type="text/event-stream")
else:
return JSONResponse(
content={'message': str(e)},
status_code=500,
)
if stream:
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
else:
res = llm_service.await_result()
raw_data = {}
for chunk in res:
if chunk:
raw_data = chunk
status_code = 200
if not raw_data.get('success'):
status_code = 500

return JSONResponse(
content=raw_data,
status_code=status_code,
)


@router.get("/record/{chat_record_id}/excel/export", summary=f"{PLACEHOLDER_PREFIX}export_chart_data")
Expand Down
11 changes: 11 additions & 0 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ def format_json_list_data(origin_data: list[dict]):
return data


def get_chat_chart_config(session: SessionDep, chat_record_id: int):
    """Return the parsed chart-config JSON stored on a chat record.

    Args:
        session: active database session.
        chat_record_id: primary key of the ``ChatRecord`` row to look up.

    Returns:
        dict parsed from the record's ``chart`` column, or an empty dict when
        the record does not exist or the stored value is not valid JSON.
    """
    stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chat_record_id))
    res = session.execute(stmt)
    for row in res:
        try:
            # BUG FIX: the statement selects ChatRecord.chart, so the Row
            # exposes the value as ``row.chart``. The original read
            # ``row.data`` (copy-pasted from get_chat_chart_data), which
            # raised AttributeError, was swallowed by the bare except, and
            # made this function always return {}.
            return orjson.loads(row.chart)
        except Exception:
            # Empty/malformed JSON in the column: fall through to default.
            pass
    return {}


def get_chat_chart_data(session: SessionDep, chat_record_id: int):
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id))
res = session.execute(stmt)
Expand Down
Loading