Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 62 additions & 116 deletions backend/apps/data_training/curd/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid:


def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
"""
创建单个数据训练记录
"""
# 基本验证
if not info.question or not info.question.strip():
raise Exception(trans("i18n_data_training.question_cannot_be_empty"))
Expand All @@ -154,45 +157,56 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
raise Exception(trans("i18n_data_training.description_cannot_be_empty"))

create_time = datetime.datetime.now()

# 检查数据源和高级应用不能同时为空
if info.datasource is None and info.advanced_application is None:
if oid == 1:
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
else:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))

parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
datasource=info.datasource, enabled=info.enabled,
advanced_application=info.advanced_application)

stmt = select(DataTraining.id).where(and_(DataTraining.question == info.question, DataTraining.oid == oid))
# 检查重复记录
stmt = select(DataTraining.id).where(
and_(DataTraining.question == info.question.strip(), DataTraining.oid == oid)
)

if info.datasource is not None and info.advanced_application is not None:
stmt = stmt.where(
or_(DataTraining.datasource == info.datasource,
DataTraining.advanced_application == info.advanced_application))
or_(
DataTraining.datasource == info.datasource,
DataTraining.advanced_application == info.advanced_application
)
)
elif info.datasource is not None and info.advanced_application is None:
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
stmt = stmt.where(DataTraining.datasource == info.datasource)
elif info.datasource is None and info.advanced_application is not None:
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))
stmt = stmt.where(DataTraining.advanced_application == info.advanced_application)

exists = session.query(stmt.exists()).scalar()

if exists:
raise Exception(trans("i18n_data_training.exists_in_db"))

result = DataTraining(**parent.model_dump())
# 创建记录
data_training = DataTraining(
question=info.question.strip(),
description=info.description.strip(),
oid=oid,
datasource=info.datasource,
advanced_application=info.advanced_application,
create_time=create_time,
enabled=info.enabled if info.enabled is not None else True
)

session.add(parent)
session.add(data_training)
session.flush()
session.refresh(parent)

result.id = parent.id
session.refresh(data_training)
session.commit()

# embedding
run_save_data_training_embeddings([result.id])
# 处理embedding
run_save_data_training_embeddings([data_training.id])

return result.id
return data_training.id


def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
Expand Down Expand Up @@ -250,14 +264,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans

def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo], oid: int, trans: Trans):
"""
批量创建数据训练记录
Args:
session: 数据库会话
info_list: DataTrainingInfo对象列表
oid: 组织ID
trans: 翻译对象
Returns:
dict: 包含成功数量、失败记录和统计信息的结果字典
批量创建数据训练记录(复用单条插入逻辑)
"""
if not info_list:
return {
Expand All @@ -268,48 +275,45 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
'deduplicated_count': 0
}

create_time = datetime.datetime.now()
failed_records = [] # 存储失败的数据和原因
failed_records = []
success_count = 0
inserted_ids = [] # 存储成功插入的ID
inserted_ids = []

# 第一步:数据去重
unique_records = {}
duplicate_records = [] # 存储重复的数据
duplicate_records = []

for info in info_list:
# 创建唯一标识:问题 + 数据源名称 + 高级应用名称
# 创建唯一标识
unique_key = (
info.question.strip().lower() if info.question else "",
info.datasource_name.strip().lower() if info.datasource_name else "",
info.advanced_application_name.strip().lower() if info.advanced_application_name else ""
)

if unique_key in unique_records:
# 如果是重复数据,记录到重复列表中
duplicate_records.append(info)
else:
unique_records[unique_key] = info

# 将去重后的数据转换为列表
deduplicated_list = list(unique_records.values())

# 预加载数据源名称到ID的映射(CoreDatasource需要判断oid)
# 预加载数据源和高级应用名称到ID的映射
datasource_name_to_id = {}
datasource_stmt = select(CoreDatasource.id, CoreDatasource.name).where(CoreDatasource.oid == oid)
datasource_result = session.execute(datasource_stmt).all()
for ds in datasource_result:
datasource_name_to_id[ds.name.strip()] = ds.id

# 只有在oid=1时才预加载高级应用名称到ID的映射
assistant_name_to_id = {}
if oid == 1:
assistant_stmt = select(AssistantModel.id, AssistantModel.name).where(AssistantModel.type == 1)
assistant_result = session.execute(assistant_stmt).all()
for assistant in assistant_result:
assistant_name_to_id[assistant.name.strip()] = assistant.id

# 验证和准备数据
# 验证和转换数据
valid_records = []
for info in deduplicated_list:
error_messages = []
Expand All @@ -321,15 +325,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
if not info.description or not info.description.strip():
error_messages.append(trans("i18n_data_training.description_cannot_be_empty"))

# 数据源验证
# 数据源验证和转换
datasource_id = None
if info.datasource_name and info.datasource_name.strip():
if info.datasource_name.strip() in datasource_name_to_id:
datasource_id = datasource_name_to_id[info.datasource_name.strip()]
else:
error_messages.append(trans("i18n_data_training.datasource_not_found").format(info.datasource_name))

# 高级应用验证(只有在oid=1时才需要)
# 高级应用验证和转换
advanced_application_id = None
if oid == 1 and info.advanced_application_name and info.advanced_application_name.strip():
if info.advanced_application_name.strip() in assistant_name_to_id:
Expand All @@ -346,101 +350,43 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
if not datasource_id:
error_messages.append(trans("i18n_data_training.datasource_cannot_be_none"))

# 如果有错误,添加到失败列表
if error_messages:
# 返回原始的info对象,不包含转换后的ID
failed_records.append({
'data': info, # 直接返回原始传入的数据
'data': info,
'errors': error_messages
})
continue

# 检查数据库中是否已存在重复记录
stmt = select(DataTraining.id).where(
and_(
DataTraining.question == info.question.strip(),
DataTraining.oid == oid
)
# 创建处理后的DataTrainingInfo对象
processed_info = DataTrainingInfo(
question=info.question.strip(),
description=info.description.strip(),
datasource=datasource_id,
datasource_name=info.datasource_name,
advanced_application=advanced_application_id,
advanced_application_name=info.advanced_application_name,
enabled=info.enabled if info.enabled is not None else True
)

# 根据oid决定重复检查条件
if oid == 1:
if datasource_id is not None and advanced_application_id is not None:
stmt = stmt.where(
or_(
DataTraining.datasource == datasource_id,
DataTraining.advanced_application == advanced_application_id
)
)
elif datasource_id is not None:
stmt = stmt.where(DataTraining.datasource == datasource_id)
elif advanced_application_id is not None:
stmt = stmt.where(DataTraining.advanced_application == advanced_application_id)
else:
# oid != 1时,只检查数据源
if datasource_id is not None:
stmt = stmt.where(DataTraining.datasource == datasource_id)

exists = session.query(stmt.exists()).scalar()

if exists:
# 返回原始的info对象
failed_records.append({
'data': info, # 直接返回原始传入的数据
'errors': [trans("i18n_data_training.exists_in_db")]
})
continue

# 验证通过,添加到有效记录
valid_records.append({
'info': info,
'datasource_id': datasource_id,
'advanced_application_id': advanced_application_id
})
valid_records.append(processed_info)

# 批量插入有效记录
# 使用事务处理有效记录
if valid_records:
data_training_objects = []
for record in valid_records:
info = record['info']
data_training = DataTraining(
question=info.question.strip(),
description=info.description.strip(),
oid=oid,
datasource=record['datasource_id'],
advanced_application=record['advanced_application_id'] if oid == 1 else None, # 只有oid=1才设置高级应用
create_time=create_time,
enabled=info.enabled if info.enabled is not None else True
)
data_training_objects.append(data_training)

try:
# 批量插入
session.bulk_save_objects(data_training_objects, return_defaults=True)
session.commit()
for info in valid_records:
try:
# 直接复用create_training方法
training_id = create_training(session, info, oid, trans)
inserted_ids.append(training_id)
success_count += 1

# 获取插入的ID
for obj in data_training_objects:
if obj.id is not None: # 确保ID已经被赋值
inserted_ids.append(obj.id)
success_count += 1

except Exception as e:
session.rollback()
# 将所有的有效记录标记为失败
for record in valid_records:
# 返回原始的info对象
except Exception as e:
# 如果单条插入失败,回滚当前记录
session.rollback()
failed_records.append({
'data': record['info'], # 直接返回原始传入的数据
'data': info,
'errors': [str(e)]
})
success_count = 0

# 批量处理embedding
if success_count > 0 and inserted_ids:
run_save_data_training_embeddings(inserted_ids)

# 返回结果,包含去重统计信息
return {
'success_count': success_count,
'failed_records': failed_records,
Expand Down