Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions backend/apps/data_training/curd/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid:
return _list


def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans, skip_embedding: bool = False):
"""
创建单个数据训练记录
Args:
skip_embedding: 是否跳过embedding处理(用于批量插入)
"""
# 基本验证
if not info.question or not info.question.strip():
Expand Down Expand Up @@ -203,8 +205,9 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
session.refresh(data_training)
session.commit()

# 处理embedding
run_save_data_training_embeddings([data_training.id])
# 处理embedding(批量插入时跳过)
if not skip_embedding:
run_save_data_training_embeddings([data_training.id])

return data_training.id

Expand Down Expand Up @@ -247,11 +250,11 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
raise Exception(trans("i18n_data_training.exists_in_db"))

stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values(
question=info.question,
description=info.description,
question=info.question.strip(),
description=info.description.strip(),
datasource=info.datasource,
enabled=info.enabled,
advanced_application=info.advanced_application,
enabled=info.enabled if info.enabled is not None else True
)
session.execute(stmt)
session.commit()
Expand Down Expand Up @@ -374,8 +377,8 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
if valid_records:
for info in valid_records:
try:
# 直接复用create_training方法
training_id = create_training(session, info, oid, trans)
# 直接复用create_training方法,跳过embedding处理
training_id = create_training(session, info, oid, trans, skip_embedding=True)
inserted_ids.append(training_id)
success_count += 1

Expand All @@ -387,6 +390,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
'errors': [str(e)]
})

# 批量处理embedding(只在最后执行一次)
if success_count > 0 and inserted_ids:
try:
run_save_data_training_embeddings(inserted_ids)
except Exception as e:
# 如果embedding处理失败,记录错误但不回滚数据
print(f"Embedding processing failed: {str(e)}")
# 可以选择将embedding失败的信息记录到日志或返回给调用方

return {
'success_count': success_count,
'failed_records': failed_records,
Expand Down
74 changes: 44 additions & 30 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,12 @@ def get_all_terminology(session: SessionDep, name: Optional[str] = None, oid: Op
return _list


def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans,
skip_embedding: bool = False):
"""
创建单个术语记录
Args:
skip_embedding: 是否跳过embedding处理(用于批量插入)
"""
# 基本验证
if not info.word or not info.word.strip():
Expand All @@ -221,16 +224,16 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
raise Exception(trans("i18n_terminology.datasource_cannot_be_none"))

parent = Terminology(
word=info.word,
word=info.word.strip(),
create_time=create_time,
description=info.description,
description=info.description.strip(),
oid=oid,
specific_ds=specific_ds,
enabled=info.enabled,
datasource_ids=datasource_ids
)

words = [info.word]
words = [info.word.strip()]
for child_word in info.other_words:
# 先检查是否为空字符串
if not child_word or child_word.strip() == "":
Expand All @@ -239,7 +242,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if child_word in words:
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
else:
words.append(child_word)
words.append(child_word.strip())

# 基础查询条件(word 和 oid 必须满足)
base_query = and_(
Expand Down Expand Up @@ -288,7 +291,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
child_list.append(
Terminology(
pid=parent.id,
word=other_word,
word=other_word.strip(),
create_time=create_time,
oid=oid,
enabled=info.enabled,
Expand All @@ -303,8 +306,9 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra

session.commit()

# 处理embedding
run_save_terminology_embeddings([parent.id])
# 处理embedding(批量插入时跳过)
if not skip_embedding:
run_save_terminology_embeddings([parent.id])

return parent.id

Expand Down Expand Up @@ -380,19 +384,9 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
# 基本验证
if not info.word or not info.word.strip():
error_messages.append(trans("i18n_terminology.word_cannot_be_empty"))
failed_records.append({
'data': info,
'errors': error_messages
})
continue

if not info.description or not info.description.strip():
error_messages.append(trans("i18n_terminology.description_cannot_be_empty"))
failed_records.append({
'data': info,
'errors': error_messages
})
continue

# 根据specific_ds决定是否验证数据源
specific_ds = info.specific_ds if info.specific_ds is not None else False
Expand Down Expand Up @@ -455,8 +449,8 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
if valid_records:
for info in valid_records:
try:
# 直接复用create_terminology方法
terminology_id = create_terminology(session, info, oid, trans)
# 直接复用create_terminology方法,跳过embedding处理
terminology_id = create_terminology(session, info, oid, trans, skip_embedding=True)
inserted_ids.append(terminology_id)
success_count += 1

Expand All @@ -468,6 +462,15 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
'errors': [str(e)]
})

# 批量处理embedding(只在最后执行一次)
if success_count > 0 and inserted_ids:
try:
run_save_terminology_embeddings(inserted_ids)
except Exception as e:
# 如果embedding处理失败,记录错误但不回滚数据
print(f"Terminology embedding processing failed: {str(e)}")
# 可以选择将embedding失败的信息记录到日志或返回给调用方

return {
'success_count': success_count,
'failed_records': failed_records,
Expand All @@ -492,12 +495,12 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if not datasource_ids:
raise Exception(trans("i18n_terminology.datasource_cannot_be_none"))

words = [info.word]
words = [info.word.strip()]
for child in info.other_words:
if child in words:
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
else:
words.append(child)
words.append(child.strip())

# 基础查询条件(word 和 oid 必须满足)
base_query = and_(
Expand Down Expand Up @@ -539,8 +542,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
raise Exception(trans("i18n_terminology.exists_in_db"))

stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
word=info.word,
description=info.description,
word=info.word.strip(),
description=info.description.strip(),
specific_ds=specific_ds,
datasource_ids=datasource_ids,
enabled=info.enabled,
Expand All @@ -553,16 +556,27 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
session.commit()

create_time = datetime.datetime.now()
_list: List[Terminology] = []
# 插入子记录(其他词)
child_list = []
if info.other_words:
for other_word in info.other_words:
if other_word.strip() == "":
continue
_list.append(
Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid,
specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled))
session.bulk_save_objects(_list)
session.flush()
child_list.append(
Terminology(
pid=info.id,
word=other_word.strip(),
create_time=create_time,
oid=oid,
enabled=info.enabled,
specific_ds=specific_ds,
datasource_ids=datasource_ids
)
)

if child_list:
session.bulk_save_objects(child_list)
session.flush()
session.commit()

# embedding
Expand Down
7 changes: 6 additions & 1 deletion backend/locales/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@
"prompt_word_name": "提示词名称",
"prompt_word_content": "提示词内容",
"effective_data_sources": "生效数据源",
"all_data_sources": "所有数据源"
"all_data_sources": "所有数据源",
"name_cannot_be_empty": "名称不能为空",
"prompt_cannot_be_empty": "提示词内容不能为空",
"type_cannot_be_empty": "类型不能为空",
"datasource_not_found": "找不到数据源",
"datasource_cannot_be_none": "数据源不能为空"
},
"i18n_excel_export": {
"data_is_empty": "表单数据为空,无法导出数据"
Expand Down
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
"pyyaml (>=6.0.2,<7.0.0)",
"fastapi-mcp (>=0.3.4,<0.4.0)",
"tabulate>=0.9.0",
"sqlbot-xpack>=0.0.3.45,<1.0.0",
"sqlbot-xpack>=0.0.3.46,<1.0.0",
"fastapi-cache2>=0.2.2",
"sqlparse>=0.5.3",
"redis>=6.2.0",
Expand Down
Loading