49 changes: 48 additions & 1 deletion backend/apps/data_training/api/data_training.py
@@ -1,9 +1,15 @@
import asyncio
import io
from typing import Optional

import pandas as pd
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse

from apps.chat.models.chat_model import AxisObj
from apps.chat.task.llm import LLMService
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
enable_training
enable_training, get_all_data_training
from apps.data_training.models.data_training_model import DataTrainingInfo
from common.core.deps import SessionDep, CurrentUser, Trans

@@ -43,3 +49,44 @@ async def delete(session: SessionDep, id_list: list[int]):
@router.get("/{id}/enable/{enabled}")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
enable_training(session, id, enabled, trans)


@router.get("/export")
async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
                       word: Optional[str] = Query(None, description="Search term (optional)")):
def inner():
_list = get_all_data_training(session, word, oid=current_user.oid)

data_list = []
for obj in _list:
_data = {
"question": obj.question,
"description": obj.description,
"datasource_name": obj.datasource_name,
"advanced_application_name": obj.advanced_application_name,
}
data_list.append(_data)

fields = []
fields.append(AxisObj(name=trans('i18n_data_training.data_training'), value='question'))
fields.append(AxisObj(name=trans('i18n_data_training.problem_description'), value='description'))
fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'))
if current_user.oid == 1:
fields.append(
AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))

md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)

df = pd.DataFrame(md_data, columns=_fields_list)

buffer = io.BytesIO()

with pd.ExcelWriter(buffer, engine='xlsxwriter',
engine_kwargs={'options': {'strings_to_numbers': False}}) as writer:
df.to_excel(writer, sheet_name='Sheet1', index=False)

buffer.seek(0)
return io.BytesIO(buffer.getvalue())

result = await asyncio.to_thread(inner)
return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
92 changes: 72 additions & 20 deletions backend/apps/data_training/curd/data_training.py
@@ -18,42 +18,61 @@
from common.utils.embedding_threads import run_save_data_training_embeddings


def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
oid: Optional[int] = 1):
_list: List[DataTrainingInfoResult] = []

current_page = max(1, current_page)
page_size = max(10, page_size)

total_count = 0
total_pages = 0

def get_data_training_base_query(oid: int, name: Optional[str] = None):
"""
    Build the base query structure for data-training lookups.
"""
if name and name.strip() != "":
keyword_pattern = f"%{name.strip()}%"
parent_ids_subquery = (
select(DataTraining.id)
        .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid))  # LIKE filter condition
.where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid))
)
else:
parent_ids_subquery = (
select(DataTraining.id).where(and_(DataTraining.oid == oid))
)

return parent_ids_subquery


def build_data_training_query(session: SessionDep, oid: int, name: Optional[str] = None,
paginate: bool = True, current_page: int = 1, page_size: int = 10):
"""
    Shared helper that builds the data-training query.
"""
parent_ids_subquery = get_data_training_base_query(oid, name)

    # Count total rows
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
total_count = session.execute(count_stmt).scalar()
total_pages = (total_count + page_size - 1) // page_size

if current_page > total_pages:
if paginate:
        # Apply pagination
page_size = max(10, page_size)
total_pages = (total_count + page_size - 1) // page_size
current_page = max(1, min(current_page, total_pages)) if total_pages > 0 else 1

paginated_parent_ids = (
parent_ids_subquery
.order_by(DataTraining.create_time.desc())
.offset((current_page - 1) * page_size)
.limit(page_size)
.subquery()
)
else:
        # No pagination: fetch all rows
total_pages = 1
current_page = 1
page_size = total_count if total_count > 0 else 1

paginated_parent_ids = (
parent_ids_subquery
.order_by(DataTraining.create_time.desc())
.offset((current_page - 1) * page_size)
.limit(page_size)
.subquery()
)
paginated_parent_ids = (
parent_ids_subquery
.order_by(DataTraining.create_time.desc())
.subquery()
)

    # Build the main query
stmt = (
select(
DataTraining.id,
@@ -74,6 +93,14 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
.order_by(DataTraining.create_time.desc())
)

return stmt, total_count, total_pages, current_page, page_size


def execute_data_training_query(session: SessionDep, stmt) -> List[DataTrainingInfoResult]:
"""
    Execute the query and return the list of data-training results.
"""
_list = []
result = session.execute(stmt)

for row in result:
@@ -90,9 +117,34 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
advanced_application_name=row.advanced_application_name,
))

return _list


def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10,
name: Optional[str] = None, oid: Optional[int] = 1):
"""
    Paged data-training query (original behaviour preserved).
"""
stmt, total_count, total_pages, current_page, page_size = build_data_training_query(
session, oid, name, True, current_page, page_size
)
_list = execute_data_training_query(session, stmt)

return current_page, page_size, total_count, total_pages, _list


def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid: Optional[int] = 1):
"""
    Fetch all data-training records (no pagination).
"""
stmt, total_count, total_pages, current_page, page_size = build_data_training_query(
session, oid, name, False
)
_list = execute_data_training_query(session, stmt)

return _list


def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
create_time = datetime.datetime.now()
if info.datasource is None and info.advanced_application is None:
7 changes: 6 additions & 1 deletion backend/locales/en.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "Datasource cannot be empty",
"datasource_assistant_cannot_be_none": "Datasource or advanced application cannot both be empty",
"data_training_not_exists": "This example does not exist",
"exists_in_db": "This question already exists"
"exists_in_db": "This question already exists",
"data_training": "SQL Example Library",
"problem_description": "Problem Description",
"sample_sql": "Sample SQL",
"effective_data_sources": "Effective Data Sources",
"advanced_application": "Advanced Application"
},
"i18n_custom_prompt": {
"exists_in_db": "Template name already exists",
7 changes: 6 additions & 1 deletion backend/locales/ko-KR.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "데이터 소스는 비울 수 없습니다",
"datasource_assistant_cannot_be_none": "데이터 소스와 고급 애플리케이션을 모두 비울 수 없습니다",
"data_training_not_exists": "이 예시가 존재하지 않습니다",
"exists_in_db": "이 질문이 이미 존재합니다"
"exists_in_db": "이 질문이 이미 존재합니다",
"data_training": "SQL 예시 라이브러리",
"problem_description": "문제 설명",
"sample_sql": "예시 SQL",
"effective_data_sources": "유효 데이터 소스",
"advanced_application": "고급 애플리케이션"
},
"i18n_custom_prompt": {
"exists_in_db": "템플릿 이름이 이미 존재합니다",
7 changes: 6 additions & 1 deletion backend/locales/zh-CN.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "数据源不能为空",
"datasource_assistant_cannot_be_none": "数据源或高级应用不能都为空",
"data_training_not_exists": "该示例不存在",
"exists_in_db": "该问题已存在"
"exists_in_db": "该问题已存在",
"data_training": "SQL 示例库",
"problem_description": "问题描述",
"sample_sql": "示例 SQL",
"effective_data_sources": "生效数据源",
"advanced_application": "高级应用"
},
"i18n_custom_prompt": {
"exists_in_db": "模版名称已存在",
6 changes: 6 additions & 0 deletions frontend/src/api/training.ts
@@ -9,4 +9,10 @@ export const trainingApi = {
deleteEmbedded: (params: any) => request.delete('/system/data-training', { data: params }),
getOne: (id: any) => request.get(`/system/data-training/${id}`),
enable: (id: any, enabled: any) => request.get(`/system/data-training/${id}/enable/${enabled}`),
export2Excel: (params: any) =>
request.get(`/system/data-training/export`, {
params,
responseType: 'blob',
requestOptions: { customError: true },
}),
}
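
The export2Excel call above returns the response as a blob; the component wiring that consumes it is not part of this diff. A minimal sketch of how a caller might turn that blob into a browser download — the helper name, the '@/api/training' import path, the filename, and the assumption that the request wrapper resolves directly to the Blob are all assumptions, not code from this PR:

import { trainingApi } from '@/api/training' // assumed alias path

// Hypothetical helper: request the Excel export and trigger a browser download.
async function downloadTrainingExcel(word?: string) {
  // Assumes the request wrapper resolves to the raw Blob when responseType is 'blob'.
  const data = await trainingApi.export2Excel(word ? { word } : {})
  const blob = new Blob([data], {
    type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
  })
  const url = window.URL.createObjectURL(blob)
  const link = document.createElement('a')
  link.href = url
  link.download = 'data-training.xlsx' // assumed filename
  document.body.appendChild(link)
  link.click()
  link.remove()
  window.URL.revokeObjectURL(url)
}
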
1 change: 1 addition & 0 deletions frontend/src/i18n/en.json
@@ -40,6 +40,7 @@
"training_data_items": "Do you want to delete the {msg} selected SQL Sample items?",
"sql_statement": "SQL Statement",
"edit_training_data": "Edit SQL Sample",
"all_236_terms": "Export all {msg} sample SQL records?",
"sales_this_year": "Do you want to delete the SQL Sample: {msg}?"
},
"professional": {
1 change: 1 addition & 0 deletions frontend/src/i18n/ko-KR.json
@@ -40,6 +40,7 @@
"training_data_items": "선택된 {msg}개의 예제 SQL을 삭제하시겠습니까?",
"sql_statement": "SQL 문",
"edit_training_data": "예제 SQL 편집",
"all_236_terms": "모든 {msg}개의 예시 SQL 기록을 내보내시겠습니까?",
"sales_this_year": "예제 SQL을 삭제하시겠습니까: {msg}?"
},
"professional": {
1 change: 1 addition & 0 deletions frontend/src/i18n/zh-CN.json
@@ -40,6 +40,7 @@
"training_data_items": "是否删除选中的 {msg} 条示例 SQL?",
"sql_statement": "SQL 语句",
"edit_training_data": "编辑示例 SQL",
"all_236_terms": "是否导出全部 {msg} 条示例 SQL?",
"sales_this_year": "是否删除示例 SQL:{msg}?"
},
"professional": {