Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import asyncio
import io
import traceback
from typing import Optional
from typing import Optional, List

import orjson
import pandas as pd
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, select

from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
ChatInfo, Chat
from apps.chat.task.llm import LLMService
from apps.swagger.i18n import PLACEHOLDER_PREFIX
from apps.system.schemas.permission import SqlbotPermission, require_permissions
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
from common.utils.command_utils import parse_quick_command
Expand All @@ -22,12 +24,12 @@
router = APIRouter(tags=["Data Q&A"], prefix="/chat")


@router.get("/list")
@router.get("/list", response_model=List[Chat], summary=f"{PLACEHOLDER_PREFIX}get_chat_list")
async def chats(session: SessionDep, current_user: CurrentUser):
return list_chats(session, current_user)


@router.get("/{chart_id}")
@router.get("/{chart_id}", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}get_chat")
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant,
trans: Trans):
def inner():
Expand All @@ -37,7 +39,7 @@ def inner():
return await asyncio.to_thread(inner)


@router.get("/{chart_id}/with_data")
@router.get("/{chart_id}/with_data", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}get_chat_with_data")
async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int,
current_assistant: CurrentAssistant):
def inner():
Expand All @@ -47,7 +49,7 @@ def inner():
return await asyncio.to_thread(inner)


@router.get("/record/{chat_record_id}/data")
@router.get("/record/{chat_record_id}/data", summary=f"{PLACEHOLDER_PREFIX}get_chart_data")
async def chat_record_data(session: SessionDep, chat_record_id: int):
def inner():
data = get_chat_chart_data(chat_record_id=chat_record_id, session=session)
Expand All @@ -56,7 +58,7 @@ def inner():
return await asyncio.to_thread(inner)


@router.get("/record/{chat_record_id}/predict_data")
@router.get("/record/{chat_record_id}/predict_data", summary=f"{PLACEHOLDER_PREFIX}get_chart_predict_data")
async def chat_predict_data(session: SessionDep, chat_record_id: int):
def inner():
data = get_chat_predict_data(chat_record_id=chat_record_id, session=session)
Expand All @@ -65,7 +67,7 @@ def inner():
return await asyncio.to_thread(inner)


@router.post("/rename")
@router.post("/rename", response_model=str, summary=f"{PLACEHOLDER_PREFIX}rename_chat")
async def rename(session: SessionDep, chat: RenameChat):
try:
return rename_chat(session=session, rename_object=chat)
Expand All @@ -76,7 +78,7 @@ async def rename(session: SessionDep, chat: RenameChat):
)


@router.delete("/{chart_id}")
@router.delete("/{chart_id}", response_model=str, summary=f"{PLACEHOLDER_PREFIX}delete_chat")
async def delete(session: SessionDep, chart_id: int):
try:
return delete_chat(session=session, chart_id=chart_id)
Expand All @@ -87,7 +89,7 @@ async def delete(session: SessionDep, chart_id: int):
)


@router.post("/start")
@router.post("/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}start_chat")
@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource"))
async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat):
try:
Expand All @@ -99,7 +101,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
)


@router.post("/assistant/start")
@router.post("/assistant/start", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}assistant_start_chat")
async def start_chat(session: SessionDep, current_user: CurrentUser):
try:
return create_chat(session, current_user, CreateChat(origin=2), False)
Expand All @@ -110,9 +112,9 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):
)


@router.post("/recommend_questions/{chat_record_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
current_assistant: CurrentAssistant, articles_number: Optional[int] = 4):
@router.post("/recommend_questions/{chat_record_id}", summary=f"{PLACEHOLDER_PREFIX}ask_recommend_questions")
async def ask_recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
current_assistant: CurrentAssistant, articles_number: Optional[int] = 4):
def _return_empty():
yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n'

Expand All @@ -139,9 +141,11 @@ def _err(_e: Exception):
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")


@router.get("/recent_questions/{datasource_id}")
@router.get("/recent_questions/{datasource_id}", response_model=List[str],
summary=f"{PLACEHOLDER_PREFIX}get_recommend_questions")
@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id"))
async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int):
async def recommend_questions(session: SessionDep, current_user: CurrentUser,
datasource_id: int = Path(..., description=f"{PLACEHOLDER_PREFIX}ds_id")):
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)


Expand All @@ -158,7 +162,7 @@ def find_base_question(record_id: int, session: SessionDep):
return rec_question


@router.post("/question")
@router.post("/question", summary=f"{PLACEHOLDER_PREFIX}ask_question")
@require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id"))
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
current_assistant: CurrentAssistant):
Expand Down Expand Up @@ -255,10 +259,10 @@ def _err(_e: Exception):
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")


@router.post("/record/{chat_record_id}/{action_type}")
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
action_type: str,
current_assistant: CurrentAssistant):
@router.post("/record/{chat_record_id}/{action_type}", summary=f"{PLACEHOLDER_PREFIX}analysis_or_predict")
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser,
current_assistant: CurrentAssistant, chat_record_id: int,
action_type: str = Path(..., description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)


Expand Down Expand Up @@ -302,7 +306,7 @@ def _err(_e: Exception):
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")


@router.get("/record/{chat_record_id}/excel/export")
@router.get("/record/{chat_record_id}/excel/export", summary=f"{PLACEHOLDER_PREFIX}export_chart_data")
async def export_excel(session: SessionDep, chat_record_id: int, trans: Trans):
chat_record = session.get(ChatRecord, chat_record_id)
if not chat_record:
Expand Down
23 changes: 13 additions & 10 deletions backend/apps/data_training/api/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
enable_training, get_all_data_training, batch_create_training
from apps.data_training.models.data_training_model import DataTrainingInfo
from apps.swagger.i18n import PLACEHOLDER_PREFIX
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.data_format import DataFormat
from common.utils.excel import get_excel_column_count

router = APIRouter(tags=["DataTraining"], prefix="/system/data-training")
router = APIRouter(tags=["SQL Examples"], prefix="/system/data-training")


@router.get("/page/{current_page}/{page_size}")
@router.get("/page/{current_page}/{page_size}", summary=f"{PLACEHOLDER_PREFIX}get_dt_page")
async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int,
question: Optional[str] = Query(None, description="搜索问题(可选)")):
current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size,
Expand All @@ -38,7 +39,7 @@ async def pager(session: SessionDep, current_user: CurrentUser, current_page: in
}


@router.put("")
@router.put("", response_model=int, summary=f"{PLACEHOLDER_PREFIX}create_or_update_dt")
async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo):
oid = current_user.oid
if info.id:
Expand All @@ -47,17 +48,17 @@ async def create_or_update(session: SessionDep, current_user: CurrentUser, trans
return create_training(session, info, oid, trans)


@router.delete("")
@router.delete("", summary=f"{PLACEHOLDER_PREFIX}delete_dt")
async def delete(session: SessionDep, id_list: list[int]):
delete_training(session, id_list)


@router.get("/{id}/enable/{enabled}")
@router.get("/{id}/enable/{enabled}", summary=f"{PLACEHOLDER_PREFIX}enable_dt")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
enable_training(session, id, enabled, trans)


@router.get("/export")
@router.get("/export", summary=f"{PLACEHOLDER_PREFIX}export_dt")
async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
question: Optional[str] = Query(None, description="搜索术语(可选)")):
def inner():
Expand Down Expand Up @@ -98,7 +99,7 @@ def inner():
return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")


@router.get("/template")
@router.get("/template", summary=f"{PLACEHOLDER_PREFIX}excel_template_dt")
async def excel_template(trans: Trans, current_user: CurrentUser):
def inner():
data_list = []
Expand All @@ -113,10 +114,12 @@ def inner():
fields = []
fields.append(AxisObj(name=trans('i18n_data_training.problem_description_template'), value='question'))
fields.append(AxisObj(name=trans('i18n_data_training.sample_sql_template'), value='description'))
fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources_template'), value='datasource_name'))
fields.append(
AxisObj(name=trans('i18n_data_training.effective_data_sources_template'), value='datasource_name'))
if current_user.oid == 1:
fields.append(
AxisObj(name=trans('i18n_data_training.advanced_application_template'), value='advanced_application_name'))
AxisObj(name=trans('i18n_data_training.advanced_application_template'),
value='advanced_application_name'))

md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)

Expand Down Expand Up @@ -144,7 +147,7 @@ def inner():
session_maker = scoped_session(sessionmaker(bind=engine, class_=Session))


@router.post("/uploadExcel")
@router.post("/uploadExcel", summary=f"{PLACEHOLDER_PREFIX}upload_excel_dt")
async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile = File(...)):
ALLOWED_EXTENSIONS = {"xlsx", "xls"}
if not file.filename.lower().endswith(tuple(ALLOWED_EXTENSIONS)):
Expand Down
3 changes: 2 additions & 1 deletion backend/apps/settings/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import APIRouter
from fastapi.responses import FileResponse

from apps.swagger.i18n import PLACEHOLDER_PREFIX
from common.core.config import settings
from common.core.file import FileRequest

Expand All @@ -12,7 +13,7 @@
path = settings.EXCEL_PATH


@router.post("/download-fail-info")
@router.post("/download-fail-info", summary=f"{PLACEHOLDER_PREFIX}download-fail-info")
async def download_excel(req: FileRequest):
"""
根据文件路径下载 Excel 文件
Expand Down
16 changes: 16 additions & 0 deletions backend/apps/swagger/i18n.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def load_translation(lang: str) -> Dict[str, str]:

# group tags
tags_metadata = [
{
"name": "Data Q&A",
"description": f"{PLACEHOLDER_PREFIX}data_qa"
},
{
"name": "Datasource",
"description": f"{PLACEHOLDER_PREFIX}ds_api"
Expand Down Expand Up @@ -77,6 +81,18 @@ def load_translation(lang: str) -> Dict[str, str]:
"name": "Data Permission",
"description": f"{PLACEHOLDER_PREFIX}per_api"
},
{
"name": "SQL Examples",
"description": f"{PLACEHOLDER_PREFIX}data_training_api"
},
{
"name": "Terminology",
"description": f"{PLACEHOLDER_PREFIX}terminology_api"
},
{
"name": "CustomPrompt",
"description": f"{PLACEHOLDER_PREFIX}custom_prompt_api"
},

]

Expand Down
47 changes: 46 additions & 1 deletion backend/apps/swagger/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,50 @@

"tr_api": "Table Relation",
"tr_save": "Save Table Relation",
"tr_get": "Get Table Relation"
"tr_get": "Get Table Relation",

"data_qa": "Data Q&A",
"get_chat_list": "Get Chat List",
"get_chat": "Get Chat Details",
"get_chat_with_data": "Get Chat Details (With Data)",
"get_chart_data": "Get Chart Data",
"get_chart_predict_data": "Get Chart Prediction Data",
"rename_chat": "Rename Chat",
"delete_chat": "Delete Chat",
"start_chat": "Create Chat",
"assistant_start_chat": "Assistant Create Chat",
"ask_recommend_questions": "AI Get Recommended Questions",
"get_recommend_questions": "Query Recommended Questions",
"ask_question": "Ask Question",
"analysis_or_predict": "Analyze Data / Predict Data",
"export_chart_data": "Export Chart Data",
"analysis_or_predict_action_type": "Type, allowed values: analysis | predict",

"download-fail-info": "Download Error Information",

"data_training_api": "SQL Examples",
"get_dt_page": "Pagination Query for SQL Examples",
"create_or_update_dt": "Create/Update SQL Example",
"delete_dt": "Delete SQL Example",
"enable_dt": "Enable/Disable",
"export_dt": "Export SQL Examples",
"excel_template_dt": "Download Template",
"upload_excel_dt": "Import SQL Examples",

"terminology_api": "Terminology",
"get_term_page": "Pagination Query for Terms",
"create_or_update_term": "Create/Update Term",
"delete_term": "Delete Term",
"enable_term": "Enable/Disable",
"export_term": "Export Terms",
"excel_template_term": "Download Template",
"upload_term": "Import Terms",

"custom_prompt_api": "Custom Prompts",
"custom_prompt_page": "Pagination Query for Custom Prompts",
"create_or_update_custom_prompt": "Create/Update Custom Prompt",
"delete_custom_prompt": "Delete Custom Prompt",
"export_custom_prompt": "Export Custom Prompts",
"excel_template_custom_prompt": "Download Template",
"upload_custom_prompt": "Import Custom Prompts"
}
47 changes: 46 additions & 1 deletion backend/apps/swagger/locales/zh.json
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,50 @@

"tr_api": "表关联关系",
"tr_save": "保存关联关系",
"tr_get": "查询关联关系"
"tr_get": "查询关联关系",

"data_qa": "智能问数",
"get_chat_list": "获取对话列表",
"get_chat": "获取对话详情",
"get_chat_with_data": "获取对话详情(带数据)",
"get_chart_data": "获取图表数据",
"get_chart_predict_data": "获取图表预测数据",
"rename_chat": "重命名对话",
"delete_chat": "删除对话",
"start_chat": "创建对话",
"assistant_start_chat": "小助手创建对话",
"ask_recommend_questions": "AI获取推荐提问",
"get_recommend_questions": "查询推荐提问",
"ask_question": "提问",
"analysis_or_predict": "分析数据/预测数据",
"export_chart_data": "导出图表数据",
"analysis_or_predict_action_type": "类型,可传入值为:analysis | predict",

"download-fail-info": "下载错误信息",

"data_training_api": "SQL示例",
"get_dt_page": "分页查询SQL示例",
"create_or_update_dt": "创建/更新SQL示例",
"delete_dt": "删除SQL示例",
"enable_dt": "启用/禁用",
"export_dt": "导出SQL示例",
"excel_template_dt": "下载模板",
"upload_excel_dt": "导入SQL示例",

"terminology_api": "术语",
"get_term_page": "分页查询术语",
"create_or_update_term": "创建/更新术语",
"delete_term": "删除术语",
"enable_term": "启用/禁用",
"export_term": "导出术语",
"excel_template_term": "下载模板",
"upload_term": "导入术语",

"custom_prompt_api": "自定义提示词",
"custom_prompt_page": "分页查询自定义提示词",
"create_or_update_custom_prompt": "创建/更新自定义提示词",
"delete_custom_prompt": "删除自定义提示词",
"export_custom_prompt": "导出自定义提示词",
"excel_template_custom_prompt": "下载模板",
"upload_custom_prompt": "导入自定义提示词"
}
Loading