Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion backend/apps/terminology/api/terminology.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import asyncio
import io
from typing import Optional

import pandas as pd
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse

from apps.chat.models.chat_model import AxisObj
from apps.chat.task.llm import LLMService
from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \
delete_terminology, enable_terminology
delete_terminology, enable_terminology, get_all_terminology
from apps.terminology.models.terminology_model import TerminologyInfo
from common.core.deps import SessionDep, CurrentUser, Trans

Expand Down Expand Up @@ -42,3 +48,44 @@ async def delete(session: SessionDep, id_list: list[int]):
@router.get("/{id}/enable/{enabled}")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
    """Forward an enable/disable request for one terminology entry.

    Thin HTTP wrapper: the ``id`` and ``enabled`` path parameters are handed
    to ``enable_terminology`` unchanged, together with ``session`` and
    ``trans``. The handler itself returns nothing.
    """
    enable_terminology(session, id, enabled, trans)


@router.get("/export")
async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
                       word: Optional[str] = Query(None, description="搜索术语(可选)")):
    """Export terminology entries as an Excel (.xlsx) download.

    Args:
        session: Database session used for the terminology lookup.
        trans: Translation callable used to localise the column headers.
        current_user: Authenticated user; its ``oid`` scopes the lookup.
        word: Optional search keyword passed on to ``get_all_terminology``.

    Returns:
        A ``StreamingResponse`` carrying the generated workbook.
    """

    def inner():
        # Runs in a worker thread (see asyncio.to_thread below) so the
        # blocking query and workbook generation stay off the event loop.
        terms = get_all_terminology(session, word, oid=current_user.oid)

        data_list = [
            {
                "word": obj.word,
                "other_words": ', '.join(obj.other_words) if obj.other_words else '',
                "description": obj.description,
                # NOTE(review): this writes 'Y' in the "all data sources"
                # column when ``specific_ds`` is truthy, which looks inverted
                # going by the field names - confirm against the model before
                # changing; behaviour is intentionally left as-is here.
                "all_data_sources": 'Y' if obj.specific_ds else 'N',
                "datasource": ', '.join(obj.datasource_names) if obj.datasource_names else '',
            }
            for obj in terms
        ]

        # Column order of the sheet follows the order of this list.
        fields = [
            AxisObj(name=trans('i18n_terminology.term_name'), value='word'),
            AxisObj(name=trans('i18n_terminology.synonyms'), value='other_words'),
            AxisObj(name=trans('i18n_terminology.term_description'), value='description'),
            AxisObj(name=trans('i18n_terminology.effective_data_sources'), value='datasource'),
            AxisObj(name=trans('i18n_terminology.all_data_sources'), value='all_data_sources'),
        ]

        md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
        df = pd.DataFrame(md_data, columns=_fields_list)

        buffer = io.BytesIO()
        # strings_to_numbers=False keeps numeric-looking text as text cells.
        with pd.ExcelWriter(buffer, engine='xlsxwriter',
                            engine_kwargs={'options': {'strings_to_numbers': False}}) as writer:
            df.to_excel(writer, sheet_name='Sheet1', index=False)

        # The writer has flushed the workbook into ``buffer`` on exit; rewind
        # and hand back the same object instead of copying its whole content
        # into a second BytesIO.
        buffer.seek(0)
        return buffer

    result = await asyncio.to_thread(inner)
    return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
260 changes: 118 additions & 142 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,18 @@
from common.utils.embedding_threads import run_save_terminology_embeddings


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
oid: Optional[int] = 1):
_list: List[TerminologyInfo] = []

def get_terminology_base_query(oid: int, name: Optional[str] = None):
"""
获取术语查询的基础查询结构
"""
child = aliased(Terminology)

current_page = max(1, current_page)
page_size = max(10, page_size)

total_count = 0
total_pages = 0

if name and name.strip() != "":
keyword_pattern = f"%{name.strip()}%"
# 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点)
matched_ids_subquery = (
select(Terminology.id)
.where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid)) # LIKE查询条件
.where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))
.subquery()
)

Expand All @@ -51,161 +45,118 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
)
.where(Terminology.pid.is_(None)) # 只取父节点
)
else:
parent_ids_subquery = (
select(Terminology.id)
.where(and_(Terminology.pid.is_(None), Terminology.oid == oid))
)

return parent_ids_subquery, child

count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
total_count = session.execute(count_stmt).scalar()
total_pages = (total_count + page_size - 1) // page_size

if current_page > total_pages:
current_page = 1
def build_terminology_query(session: SessionDep, oid: int, name: Optional[str] = None,
paginate: bool = True, current_page: int = 1, page_size: int = 10):
"""
构建术语查询的通用方法
"""
parent_ids_subquery, child = get_terminology_base_query(oid, name)

# 计算总数
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
total_count = session.execute(count_stmt).scalar()

if paginate:
# 分页处理
page_size = max(10, page_size)
total_pages = (total_count + page_size - 1) // page_size
current_page = max(1, min(current_page, total_pages)) if total_pages > 0 else 1

# 步骤3:获取分页后的父节点ID
paginated_parent_ids = (
parent_ids_subquery
.order_by(Terminology.create_time.desc())
.offset((current_page - 1) * page_size)
.limit(page_size)
.subquery()
)

# 步骤4:获取这些父节点的childrenNames
children_subquery = (
select(
child.pid,
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
)
.where(child.pid.isnot(None))
.group_by(child.pid)
.subquery()
)

# 创建子查询来获取数据源名称,添加类型转换
datasource_names_subquery = (
select(
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
Terminology.id.label('term_id')
)
.where(Terminology.id.in_(paginated_parent_ids))
.subquery()
)

# 主查询
stmt = (
select(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
Terminology.enabled
)
.outerjoin(
children_subquery,
Terminology.id == children_subquery.c.pid
)
# 关联数据源名称子查询和 CoreDatasource 表
.outerjoin(
datasource_names_subquery,
datasource_names_subquery.c.term_id == Terminology.id
)
.outerjoin(
CoreDatasource,
CoreDatasource.id == datasource_names_subquery.c.ds_id
)
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
.group_by(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
Terminology.enabled
)
.order_by(Terminology.create_time.desc())
)
else:
parent_ids_subquery = (
select(Terminology.id)
.where(and_(Terminology.pid.is_(None), Terminology.oid == oid)) # 只取父节点
)
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
total_count = session.execute(count_stmt).scalar()
total_pages = (total_count + page_size - 1) // page_size

if current_page > total_pages:
current_page = 1
# 不分页,获取所有数据
total_pages = 1
current_page = 1
page_size = total_count if total_count > 0 else 1

paginated_parent_ids = (
parent_ids_subquery
.order_by(Terminology.create_time.desc())
.offset((current_page - 1) * page_size)
.limit(page_size)
.subquery()
)

children_subquery = (
select(
child.pid,
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
)
.where(child.pid.isnot(None))
.group_by(child.pid)
.subquery()
# 构建公共查询部分
children_subquery = (
select(
child.pid,
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
)
.where(child.pid.isnot(None))
.group_by(child.pid)
.subquery()
)

# 创建子查询来获取数据源名称
datasource_names_subquery = (
select(
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
Terminology.id.label('term_id')
)
.where(Terminology.id.in_(paginated_parent_ids))
.subquery()
# 创建子查询来获取数据源名称
datasource_names_subquery = (
select(
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
Terminology.id.label('term_id')
)
.where(Terminology.id.in_(paginated_parent_ids))
.subquery()
)

stmt = (
select(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
Terminology.enabled
)
.outerjoin(
children_subquery,
Terminology.id == children_subquery.c.pid
)
# 关联数据源名称子查询和 CoreDatasource 表
.outerjoin(
datasource_names_subquery,
datasource_names_subquery.c.term_id == Terminology.id
)
.outerjoin(
CoreDatasource,
CoreDatasource.id == datasource_names_subquery.c.ds_id
)
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
.group_by(Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
Terminology.enabled
)
.order_by(Terminology.create_time.desc())
stmt = (
select(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
Terminology.enabled
)
.outerjoin(
children_subquery,
Terminology.id == children_subquery.c.pid
)
.outerjoin(
datasource_names_subquery,
datasource_names_subquery.c.term_id == Terminology.id
)
.outerjoin(
CoreDatasource,
CoreDatasource.id == datasource_names_subquery.c.ds_id
)
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
.group_by(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
Terminology.specific_ds,
Terminology.datasource_ids,
children_subquery.c.other_words,
Terminology.enabled
)
.order_by(Terminology.create_time.desc())
)

return stmt, total_count, total_pages, current_page, page_size


def execute_terminology_query(session: SessionDep, stmt) -> List[TerminologyInfo]:
"""
执行查询并返回术语信息列表
"""
_list = []
result = session.execute(stmt)

for row in result:
Expand All @@ -221,9 +172,34 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
enabled=row.enabled if row.enabled is not None else False,
))

return _list


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10,
                     name: Optional[str] = None, oid: Optional[int] = 1):
    """Return one page of terminology entries together with paging metadata.

    The statement is built by ``build_terminology_query`` in paginated mode
    and executed by ``execute_terminology_query``. The page number and page
    size that are returned are the ones handed back by the query builder,
    which may differ from the arguments passed in.

    Returns:
        A tuple ``(current_page, page_size, total_count, total_pages, items)``.
    """
    query, total, pages, page_no, per_page = build_terminology_query(
        session, oid, name, paginate=True, current_page=current_page, page_size=page_size
    )
    items = execute_terminology_query(session, query)

    return page_no, per_page, total, pages, items


def get_all_terminology(session: SessionDep, name: Optional[str] = None, oid: Optional[int] = 1):
    """Return every matching terminology entry without pagination.

    Builds the statement through ``build_terminology_query`` with
    ``paginate=False`` and returns the list produced by
    ``execute_terminology_query``.
    """
    # Only the statement is needed; the paging metadata is discarded.
    query, *_paging = build_terminology_query(session, oid, name, paginate=False)
    return execute_terminology_query(session, query)


def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
create_time = datetime.datetime.now()

Expand Down
Loading