Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 151 additions & 3 deletions backend/apps/data_training/api/data_training.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import asyncio
import hashlib
import io
import os
import uuid
from http.client import HTTPException
from typing import Optional

import pandas as pd
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, File, UploadFile, Query
from fastapi.responses import StreamingResponse, FileResponse

from apps.chat.models.chat_model import AxisObj
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
enable_training, get_all_data_training
enable_training, get_all_data_training, batch_create_training
from apps.data_training.models.data_training_model import DataTrainingInfo
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.data_format import DataFormat

Expand Down Expand Up @@ -90,3 +95,146 @@ def inner():

result = await asyncio.to_thread(inner)
return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")


# Directory taken from settings.EXCEL_PATH; the Excel upload and download
# routes in this module join file names onto it.
path = settings.EXCEL_PATH

from sqlalchemy.orm import sessionmaker, scoped_session
from common.core.db import engine
from sqlmodel import Session

# Scoped session registry bound to the engine imported from common.core.db,
# producing sqlmodel Session objects. upload_excel calls it from inside the
# function it runs via asyncio.to_thread.
session_maker = scoped_session(sessionmaker(bind=engine, class_=Session))


def _clean_cell(value):
    """Return ``value`` as a stripped string, or ``None`` when it is missing or blank."""
    if value is None or pd.isna(value):
        return None
    text = str(value).strip()
    return text or None


@router.post("/uploadExcel")
async def upload_excel(trans: Trans, current_user: CurrentUser, file: UploadFile = File(...)):
    """
    Import data-training records from an uploaded Excel workbook.

    The upload is stored under ``path`` with a random suffix, every sheet is
    read (first row is the header), and each non-blank row becomes a
    ``DataTrainingInfo`` passed to ``batch_create_training``. Columns are read
    by position: question, description, datasource name and, when
    ``current_user.oid == 1``, advanced application name. Records reported as
    failed are written to ``<base>_error.xlsx`` in the same directory.

    Returns:
        dict with ``success_count``, ``failed_count``, ``duplicate_count``,
        ``original_count`` and ``error_excel_filename`` (``None`` when no
        record failed).

    Raises:
        HTTPException: 400 when the upload is not an ``.xlsx``/``.xls`` file.
    """
    # The module-level name comes from http.client and is a plain exception,
    # not an HTTP error response; FastAPI only turns its own HTTPException
    # into a 4xx reply.
    from fastapi import HTTPException

    allowed_extensions = {".xlsx", ".xls"}
    # The filename is client-controlled: drop any directory components before
    # it is used to build a path on disk.
    original_name = os.path.basename((file.filename or "").replace("\\", "/"))
    # splitext keeps dots inside the stem ("my.report.xlsx" -> "my.report", ".xlsx").
    stem, ext = os.path.splitext(original_name)
    if not stem or ext.lower() not in allowed_extensions:
        raise HTTPException(400, "Only support .xlsx/.xls")

    os.makedirs(path, exist_ok=True)
    base_filename = f"{stem}_{hashlib.sha256(uuid.uuid4().bytes).hexdigest()[:10]}"
    save_path = os.path.join(path, f"{base_filename}{ext}")
    with open(save_path, "wb") as f:
        f.write(await file.read())

    oid = current_user.oid

    # Columns to read: question, description, datasource name,
    # plus advanced application name when oid == 1.
    use_cols = [0, 1, 2, 3] if oid == 1 else [0, 1, 2]

    def inner():
        session = session_maker()
        try:
            import_data = []

            # One workbook handle for every sheet, opened with the same engine
            # used for parsing and closed when the block exits.
            with pd.ExcelFile(save_path, engine='calamine') as workbook:
                for sheet_name in workbook.sheet_names:
                    df = pd.read_excel(
                        workbook,
                        sheet_name=sheet_name,
                        header=0,
                        usecols=use_cols,
                        dtype=str
                    )

                    # Plain tuples give positional access regardless of the
                    # header labels found in the sheet.
                    for raw_row in df.itertuples(index=False, name=None):
                        cells = [_clean_cell(value) for value in raw_row]

                        # Skip rows that are blank in every imported column.
                        if not any(cells):
                            continue

                        record = {
                            "oid": oid,
                            "question": cells[0],
                            "description": cells[1],
                            "datasource_name": cells[2],
                        }
                        if oid == 1:
                            record["advanced_application_name"] = cells[3] if len(cells) > 3 else None

                        import_data.append(DataTrainingInfo(**record))

            res = batch_create_training(session, import_data, oid, trans)

            failed_records = res['failed_records']

            error_excel_filename = None

            if failed_records:
                data_list = [
                    {
                        "question": obj['data'].question,
                        "description": obj['data'].description,
                        "datasource_name": obj['data'].datasource_name,
                        "advanced_application_name": obj['data'].advanced_application_name,
                        "errors": obj['errors']
                    }
                    for obj in failed_records
                ]

                fields = [
                    AxisObj(name=trans('i18n_data_training.problem_description'), value='question'),
                    AxisObj(name=trans('i18n_data_training.sample_sql'), value='description'),
                    AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'),
                ]
                if oid == 1:
                    fields.append(
                        AxisObj(name=trans('i18n_data_training.advanced_application'),
                                value='advanced_application_name'))
                fields.append(AxisObj(name=trans('i18n_data_training.error_info'), value='errors'))

                md_data, _fields_list = DataFormat.convert_object_array_for_pandas(fields, data_list)

                error_df = pd.DataFrame(md_data, columns=_fields_list)
                error_excel_filename = f"{base_filename}_error.xlsx"
                save_error_path = os.path.join(path, error_excel_filename)
                # Persist the failed rows so the client can download them later.
                error_df.to_excel(save_error_path, index=False)

            return {
                'success_count': res['success_count'],
                'failed_count': len(failed_records),
                'duplicate_count': res['duplicate_count'],
                'original_count': res['original_count'],
                'error_excel_filename': error_excel_filename,
            }
        finally:
            # asyncio.to_thread reuses worker threads, so the thread-scoped
            # session must be released or it lingers into later requests.
            session_maker.remove()

    return await asyncio.to_thread(inner)


@router.get("/download-fail-info/{filename}")
async def download_excel(filename: str, trans: Trans):
    """
    Download an import error report stored under ``path``.

    Only bare file names ending in ``_error.xlsx`` are served.

    Raises:
        HTTPException: 400 when ``filename`` does not end in ``_error.xlsx``;
            404 when it contains path components or no such file exists.
    """
    # The module-level name comes from http.client and is a plain exception,
    # not an HTTP error response; FastAPI only turns its own HTTPException
    # into a 4xx reply.
    from fastapi import HTTPException

    # Validate the name before touching the filesystem so the endpoint cannot
    # be used to probe whether arbitrary files exist.
    if not filename.endswith('_error.xlsx'):
        raise HTTPException(400, "Only support _error.xlsx")

    # Refuse anything that is not a bare file name (no directory components).
    if os.path.basename(filename.replace("\\", "/")) != filename:
        raise HTTPException(404, "File Not Exists")

    file_path = os.path.join(path, filename)

    # isfile, not exists: a directory with a matching name must not be served.
    if not os.path.isfile(file_path):
        raise HTTPException(404, "File Not Exists")

    return FileResponse(
        path=file_path,
        filename=filename,
        media_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
    )
Loading