Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData
format_json_data, format_json_list_data, get_chart_config
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj
from apps.chat.task.llm import LLMService
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans

Expand Down Expand Up @@ -42,19 +42,19 @@ def inner():
return await asyncio.to_thread(inner)


@router.get("/record/{chart_record_id}/data")
async def chat_record_data(session: SessionDep, chart_record_id: int):
@router.get("/record/{chat_record_id}/data")
async def chat_record_data(session: SessionDep, chat_record_id: int):
def inner():
data = get_chat_chart_data(chart_record_id=chart_record_id, session=session)
data = get_chat_chart_data(chat_record_id=chat_record_id, session=session)
return format_json_data(data)

return await asyncio.to_thread(inner)


@router.get("/record/{chart_record_id}/predict_data")
async def chat_predict_data(session: SessionDep, chart_record_id: int):
@router.get("/record/{chat_record_id}/predict_data")
async def chat_predict_data(session: SessionDep, chat_record_id: int):
def inner():
data = get_chat_predict_data(chart_record_id=chart_record_id, session=session)
data = get_chat_predict_data(chat_record_id=chat_record_id, session=session)
return format_json_list_data(data)

return await asyncio.to_thread(inner)
Expand Down Expand Up @@ -203,17 +203,49 @@ def _err(_e: Exception):
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")


@router.post("/excel/export")
async def export_excel(excel_data: ExcelData, trans: Trans):
def inner():
@router.get("/record/{chat_record_id}/excel/export")
async def export_excel(session: SessionDep, chat_record_id: int, trans: Trans):
chat_record = session.get(ChatRecord, chat_record_id)
if not chat_record:
raise HTTPException(
status_code=500,
detail=f"ChatRecord with id {chat_record_id} not found"
)

is_predict_data = chat_record.predict_record_id is not None

_origin_data = format_json_data(get_chat_chart_data(chat_record_id=chat_record_id, session=session))

_base_field = _origin_data.get('fields')
_data = _origin_data.get('data')

if not _data:
raise HTTPException(
status_code=500,
detail=trans("i18n_excel_export.data_is_empty")
)

chart_info = get_chart_config(session, chat_record_id)

if not excel_data.data:
raise HTTPException(
status_code=500,
detail=trans("i18n_excel_export.data_is_empty")
)
_title = chart_info.get('title') if chart_info.get('title') else 'Excel'

fields = []
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
fields.append(AxisObj(name=column.get('name'), value=column.get('value')))

_predict_data = []
if is_predict_data:
_predict_data = format_json_list_data(get_chat_predict_data(chat_record_id=chat_record_id, session=session))

def inner():

data, _fields_list, col_formats = LLMService.format_pd_data(excel_data.axis, excel_data.data)
data, _fields_list, col_formats = LLMService.format_pd_data(fields, _data + _predict_data)

df = pd.DataFrame(data, columns=_fields_list)

Expand Down
25 changes: 21 additions & 4 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def get_chart_config(session: SessionDep, chart_record_id: int):
pass
return {}

def format_chart_fields(chart_info: dict):
fields = []
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
return fields

def get_last_execute_sql_error(session: SessionDep, chart_id: int):
stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by(
Expand Down Expand Up @@ -117,8 +134,8 @@ def format_json_list_data(origin_data: list[dict]):
return data


def get_chat_chart_data(session: SessionDep, chart_record_id: int):
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id))
def get_chat_chart_data(session: SessionDep, chat_record_id: int):
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id))
res = session.execute(stmt)
for row in res:
try:
Expand All @@ -128,8 +145,8 @@ def get_chat_chart_data(session: SessionDep, chart_record_id: int):
return {}


def get_chat_predict_data(session: SessionDep, chart_record_id: int):
stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chart_record_id))
def get_chat_predict_data(session: SessionDep, chat_record_id: int):
stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chat_record_id))
res = session.execute(stmt)
for row in res:
try:
Expand Down
19 changes: 2 additions & 17 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
save_select_datasource_answer, save_recommend_question_answer, \
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error, format_json_data
get_last_execute_sql_error, format_json_data, format_chart_fields
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep, AxisObj
from apps.data_training.curd.data_training import get_training_template
Expand Down Expand Up @@ -214,22 +214,7 @@ def set_record(self, record: ChatRecord):

def get_fields_from_chart(self, _session: Session):
chart_info = get_chart_config(_session, self.record.id)
fields = []
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
return fields
return format_chart_fields(chart_info)

def generate_analysis(self, _session: Session):
fields = self.get_fields_from_chart(_session)
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/api/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ export const chatApi = {
return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller)
},
checkLLMModel: () => request.get('/system/aimodel/default', { requestOptions: { silent: true } }),
export2Excel: (data: any) =>
request.post('/chat/excel/export', data, {
export2Excel: (record_id: number | undefined) =>
request.get(`/chat/record/${record_id}/excel/export`, {
responseType: 'blob',
requestOptions: { customError: true },
}),
Expand Down
6 changes: 5 additions & 1 deletion frontend/src/views/chat/answer/ChartAnswer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ import BaseAnswer from './BaseAnswer.vue'
import { Chat, chatApi, ChatInfo, type ChatMessage, ChatRecord, questionApi } from '@/api/chat.ts'
import { computed, nextTick, onBeforeUnmount, onMounted, ref } from 'vue'
import ChartBlock from '@/views/chat/chat-block/ChartBlock.vue'

const props = withDefaults(
defineProps<{
recordId?: number
chatList?: Array<ChatInfo>
currentChatId?: number
currentChat?: ChatInfo
Expand All @@ -13,6 +15,7 @@ const props = withDefaults(
reasoningName: 'sql_answer' | 'chart_answer' | Array<'sql_answer' | 'chart_answer'>
}>(),
{
recordId: undefined,
chatList: () => [],
currentChatId: undefined,
currentChat: () => new ChatInfo(),
Expand Down Expand Up @@ -229,6 +232,7 @@ function getChatData(recordId?: number) {
emits('scrollBottom')
})
}

function stop() {
stopFlag.value = true
_loading.value = false
Expand All @@ -250,7 +254,7 @@ defineExpose({ sendMessage, index: () => index.value, stop })

<template>
<BaseAnswer v-if="message" :message="message" :reasoning-name="reasoningName" :loading="_loading">
<ChartBlock style="margin-top: 6px" :message="message" />
<ChartBlock style="margin-top: 6px" :message="message" :record-id="recordId" />
<slot></slot>
<template #tool>
<slot name="tool"></slot>
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/views/chat/answer/PredictAnswer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import { chatApi, ChatInfo, type ChatMessage, ChatRecord } from '@/api/chat.ts'
import { computed, nextTick, onBeforeUnmount, onMounted, ref } from 'vue'
import MdComponent from '@/views/chat/component/MdComponent.vue'
import ChartBlock from '@/views/chat/chat-block/ChartBlock.vue'

const props = withDefaults(
defineProps<{
recordId?: number
chatList?: Array<ChatInfo>
currentChatId?: number
currentChat?: ChatInfo
message?: ChatMessage
loading?: boolean
}>(),
{
recordId: undefined,
chatList: () => [],
currentChatId: undefined,
currentChat: () => new ChatInfo(),
Expand Down Expand Up @@ -257,6 +260,7 @@ defineExpose({ sendMessage, index: () => index.value, chatList: () => _chatList,
v-if="message.record?.predict_data?.length > 0 && message.record?.data"
ref="chartBlockRef"
style="margin-top: 12px"
:record-id="recordId"
:message="message"
is-predict
/>
Expand Down
6 changes: 4 additions & 2 deletions frontend/src/views/chat/chat-block/ChartBlock.vue
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ import { chatApi } from '@/api/chat'

const props = withDefaults(
defineProps<{
recordId?: number
message: ChatMessage
isPredict?: boolean
chatType?: ChartTypes
enlarge?: boolean
}>(),
{
recordId: undefined,
isPredict: false,
chatType: undefined,
enlarge: false,
Expand Down Expand Up @@ -240,10 +242,10 @@ function copyText() {
const exportRef = ref()

function exportToExcel() {
if (chartRef.value) {
if (chartRef.value && props.recordId) {
loading.value = true
chatApi
.export2Excel({ ...chartRef.value?.getExcelData(), name: chartObject.value.title })
.export2Excel(props.recordId)
.then((res) => {
const blob = new Blob([res], {
type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/views/chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
:chat-list="chatList"
:current-chat="currentChat"
:current-chat-id="currentChatId"
:record-id="message.record?.id"
:loading="isTyping"
:message="message"
:reasoning-name="['sql_answer', 'chart_answer']"
Expand Down Expand Up @@ -358,6 +359,7 @@
:chat-list="chatList"
:current-chat="currentChat"
:current-chat-id="currentChatId"
:record-id="message.record?.id"
:loading="isTyping"
:message="message"
@scroll-bottom="scrollToBottom"
Expand Down