Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/alembic/versions/047_table_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
def upgrade():
    """Add a nullable text ``embedding`` column to ``core_table`` and ``core_datasource``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Same column definition on both tables; a fresh Column object is built per table.
    for table_name in ('core_table', 'core_datasource'):
        op.add_column(table_name, sa.Column('embedding', sa.Text(), nullable=True))
    # ### end Alembic commands ###


def downgrade():
    """Drop the ``embedding`` column from ``core_table`` and ``core_datasource``."""
    # ### commands auto generated by Alembic - please adjust! ###
    for table_name in ('core_table', 'core_datasource'):
        op.drop_column(table_name, 'embedding')
    # ### end Alembic commands ###
8 changes: 7 additions & 1 deletion backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from apps.db.engine import get_engine_config, get_engine_conn
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.embedding_threads import run_save_table_embeddings
from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
from common.utils.utils import deepcopy_ignore_extra
from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
Expand Down Expand Up @@ -105,6 +105,8 @@ def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreData
setattr(record, field, value)
session.add(record)
session.commit()

run_save_ds_embeddings([ds.id])
return ds


Expand Down Expand Up @@ -197,6 +199,7 @@ def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable])

# do table embedding
run_save_table_embeddings(id_list)
run_save_ds_embeddings([ds.id])


def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
Expand Down Expand Up @@ -238,20 +241,23 @@ def update_table_and_fields(session: SessionDep, data: TableObj):

# do table embedding
run_save_table_embeddings([data.table.id])
run_save_ds_embeddings([data.table.ds_id])


def updateTable(session: SessionDep, table: CoreTable):
    """Persist ``table`` and schedule a refresh of the embeddings that depend on it.

    After ``update_table`` has run, the embedding of the table itself and of
    the datasource it belongs to (``table.ds_id``) are both re-queued.
    """
    update_table(session, table)

    # The table text changed, so its own embedding is stale ...
    stale_table_ids = [table.id]
    run_save_table_embeddings(stale_table_ids)
    # ... and so is the embedding of the datasource that owns it.
    stale_ds_ids = [table.ds_id]
    run_save_ds_embeddings(stale_ds_ids)


def updateField(session: SessionDep, field: CoreField):
    """Persist ``field`` and schedule a refresh of the embeddings that depend on it.

    After ``update_field`` has run, the embedding of the owning table
    (``field.table_id``) and of the owning datasource (``field.ds_id``) are
    both re-queued.
    """
    update_field(session, field)

    # A field change alters the text of its table ...
    stale_table_ids = [field.table_id]
    run_save_table_embeddings(stale_table_ids)
    # ... and of the datasource the table belongs to.
    stale_ds_ids = [field.ds_id]
    run_save_ds_embeddings(stale_ds_ids)


def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
Expand Down
73 changes: 67 additions & 6 deletions backend/apps/datasource/crud/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from common.core.config import settings
from common.core.deps import SessionDep
from common.utils.utils import SQLBotLogUtil
from ..models.datasource import CoreTable, CoreField
from ..models.datasource import CoreTable, CoreField, CoreDatasource


def delete_table_by_ds_id(session: SessionDep, id: int):
Expand All @@ -30,18 +30,24 @@ def update_table(session: SessionDep, item: CoreTable):
session.commit()


def run_fill_empty_table_embedding(session_maker):
def run_fill_empty_table_and_ds_embedding(session_maker):
try:
if not settings.TABLE_EMBEDDING_ENABLED:
return

SQLBotLogUtil.info('get tables')
session = session_maker()

SQLBotLogUtil.info('get tables')
stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
results = session.execute(stmt).scalars().all()
SQLBotLogUtil.info('result: ' + str(len(results)))

save_table_embedding(session, results)
SQLBotLogUtil.info('table result: ' + str(len(results)))
save_table_embedding(session_maker, results)

SQLBotLogUtil.info('get datasource')
ds_stmt = select(CoreDatasource.id).where(and_(CoreDatasource.embedding.is_(None)))
ds_results = session.execute(ds_stmt).scalars().all()
SQLBotLogUtil.info('datasource result: ' + str(len(ds_results)))
save_ds_embedding(session_maker, ds_results)
except Exception:
traceback.print_exc()
finally:
Expand Down Expand Up @@ -98,3 +104,58 @@ def save_table_embedding(session_maker, ids: List[int]):
traceback.print_exc()
finally:
session_maker.remove()


def _describe_ds_table(table, fields) -> str:
    """Render one table and its fields as a block of the datasource schema text."""
    table_comment = table.custom_comment.strip() if table.custom_comment else ''
    header = f"# Table: {table.table_name}"
    if table_comment:
        header += f", {table_comment}"

    field_list = []
    for field in fields:
        field_comment = field.custom_comment.strip() if field.custom_comment else ''
        if field_comment:
            field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
        else:
            field_list.append(f"({field.field_name}:{field.field_type})")

    # An empty field list still yields the "[ ... ]" frame, with nothing inside it.
    return f"{header}\n[\n" + ",\n".join(field_list) + '\n]\n'


def save_ds_embedding(session_maker, ids: List[int]):
    """Compute and store the embedding of each datasource in ``ids``.

    For every id the datasource name/description plus the rendered schema of
    all its tables and fields is embedded with the cached embedding model, and
    the JSON-encoded vector is written to ``CoreDatasource.embedding``. Each
    datasource is committed on its own.

    Does nothing when ``settings.TABLE_EMBEDDING_ENABLED`` is falsy or ``ids``
    is empty/None. Exceptions are printed, not raised.

    Args:
        session_maker: callable returning a session; its ``remove()`` is
            always called on exit (presumably a SQLAlchemy ``scoped_session``
            -- confirm against the caller).
        ids: datasource primary keys to (re-)embed.
    """
    if not settings.TABLE_EMBEDDING_ENABLED:
        return

    if not ids:
        return
    try:
        SQLBotLogUtil.info('start datasource embedding')
        start_time = time.time()
        model = EmbeddingModelCache.get_model()
        session = session_maker()
        for _id in ids:
            ds = session.query(CoreDatasource).filter(CoreDatasource.id == _id).first()
            if ds is None:
                # The datasource can be deleted between scheduling and execution;
                # skip it instead of letting an AttributeError abort the rest of the batch.
                SQLBotLogUtil.info(f'datasource {_id} not found, skip embedding')
                continue

            parts = [f"{ds.name}, {ds.description}\n"]
            tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
            for table in tables:
                fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
                parts.append(_describe_ds_table(table, fields))

            emb = json.dumps(model.embed_query(''.join(parts)))

            stmt = update(CoreDatasource).where(CoreDatasource.id == _id).values(embedding=emb)
            session.execute(stmt)
            # Commit per datasource so finished work survives a later failure.
            session.commit()

        end_time = time.time()
        SQLBotLogUtil.info('datasource embedding finished in: ' + str(end_time - start_time) + ' seconds')
    except Exception:
        # Best-effort background job: report the error but never propagate it.
        traceback.print_exc()
    finally:
        session_maker.remove()
83 changes: 56 additions & 27 deletions backend/apps/datasource/embedding/ds_embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Author: Junjun
# Date: 2025/9/18
import json
import time
import traceback
from typing import Optional

from apps.ai_model.embedding import EmbeddingModelCache
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.embedding.utils import cosine_similarity
from apps.datasource.models.datasource import CoreDatasource
from apps.system.crud.assistant import AssistantOutDs
Expand All @@ -18,42 +18,71 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
question: str,
current_assistant: Optional[CurrentAssistant] = None):
_list = []
if current_assistant and current_assistant.type != 4:
if current_assistant and current_assistant.type == 1:
if out_ds.ds_list:
for _ds in out_ds.ds_list:
ds = out_ds.get_ds(_ds.id)
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
ds_info = f"{ds.name}, {ds.description}\n"
ds_schema = ds_info + table_schema
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})

if _list:
try:
text = [s.get('ds_schema') for s in _list]

model = EmbeddingModelCache.get_model()
results = model.embed_documents(text)

q_embedding = model.embed_query(question)
for index in range(len(results)):
item = results[index]
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)

_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
# print(len(_list))
SQLBotLogUtil.info(json.dumps(
[{"id": ele.get("id"), "name": ele.get("ds").name,
"cosine_similarity": ele.get("cosine_similarity")}
for ele in _list]))
ds = _list[0].get('ds')
return {"id": ds.id, "name": ds.name, "description": ds.description}
except Exception:
traceback.print_exc()
else:
for _ds in _ds_list:
if _ds.get('id'):
ds = session.get(CoreDatasource, _ds.get('id'))
table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
ds_info = f"{ds.name}, {ds.description}\n"
ds_schema = ds_info + table_schema
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
# table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
# ds_info = f"{ds.name}, {ds.description}\n"
# ds_schema = ds_info + table_schema
_list.append({"id": ds.id, "cosine_similarity": 0.0, "ds": ds, "embedding": ds.embedding})

if _list:
try:
# text = [s.get('ds_schema') for s in _list]

model = EmbeddingModelCache.get_model()
start_time = time.time()
# results = model.embed_documents(text)
results = [item.get('embedding') for item in _list]

q_embedding = model.embed_query(question)
for index in range(len(results)):
item = results[index]
if item:
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)

if _list:
try:
text = [s.get('ds_schema') for s in _list]

model = EmbeddingModelCache.get_model()
results = model.embed_documents(text)

q_embedding = model.embed_query(question)
for index in range(len(results)):
item = results[index]
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)

_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
# print(len(_list))
SQLBotLogUtil.info(json.dumps(
[{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")}
for ele in _list]))
ds = _list[0].get('ds')
return {"id": ds.id, "name": ds.name, "description": ds.description}
except Exception:
traceback.print_exc()
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
# print(len(_list))
end_time = time.time()
SQLBotLogUtil.info(str(end_time - start_time))
SQLBotLogUtil.info(json.dumps(
[{"id": ele.get("id"), "name": ele.get("ds").name,
"cosine_similarity": ele.get("cosine_similarity")}
for ele in _list]))
ds = _list[0].get('ds')
return {"id": ds.id, "name": ds.name, "description": ds.description}
except Exception:
traceback.print_exc()
return _list
8 changes: 6 additions & 2 deletions backend/apps/datasource/embedding/table_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def calc_table_embedding(tables: list[dict], question: str):
# text = [s.get('schema_table') for s in _list]
#
model = EmbeddingModelCache.get_model()
# start_time = time.time()
start_time = time.time()
# results = model.embed_documents(text)
# end_time = time.time()
# SQLBotLogUtil.info(str(end_time - start_time))
Expand All @@ -67,7 +67,11 @@ def calc_table_embedding(tables: list[dict], question: str):
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
# print(len(_list))
SQLBotLogUtil.info(json.dumps(_list))
end_time = time.time()
SQLBotLogUtil.info(str(end_time - start_time))
SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'),
"cosine_similarity": ele.get('cosine_similarity')}
for ele in _list]))
return _list
except Exception:
traceback.print_exc()
Expand Down
1 change: 1 addition & 0 deletions backend/apps/datasource/models/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class CoreDatasource(SQLModel, table=True):
num: str = Field(max_length=256, nullable=True)
oid: int = Field(sa_column=Column(BigInteger()))
table_relation: List = Field(sa_column=Column(JSONB, nullable=True))
embedding: str = Field(sa_column=Column(Text, nullable=True))


class CoreTable(SQLModel, table=True):
Expand Down
11 changes: 8 additions & 3 deletions backend/common/utils/embedding_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def run_save_table_embeddings(ids: List[int]):
executor.submit(save_table_embedding, session_maker, ids)


def fill_empty_table_embeddings():
    """Submit ``run_fill_empty_table_embedding`` to the module's ``executor``.

    The job receives the shared ``session_maker``; the returned future is discarded.
    """
    # Function-scope import, presumably to avoid a circular import with the crud package.
    from apps.datasource.crud.table import run_fill_empty_table_embedding as backfill_job
    executor.submit(backfill_job, session_maker)
def run_save_ds_embeddings(ids: List[int]):
    """Submit ``save_ds_embedding`` for the given datasource ids to the module's ``executor``.

    The job receives the shared ``session_maker`` and ``ids``; the returned
    future is discarded, so callers do not wait for the result.
    """
    # Function-scope import, presumably to avoid a circular import with the crud package.
    from apps.datasource.crud.table import save_ds_embedding as embed_job
    executor.submit(embed_job, session_maker, ids)


def fill_empty_table_and_ds_embeddings():
    """Submit ``run_fill_empty_table_and_ds_embedding`` to the module's ``executor``.

    The job receives the shared ``session_maker``; the returned future is discarded.
    """
    # Function-scope import, presumably to avoid a circular import with the crud package.
    from apps.datasource.crud.table import run_fill_empty_table_and_ds_embedding as backfill_job
    executor.submit(backfill_job, session_maker)
8 changes: 4 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from alembic import command
from apps.api import api_router
from common.utils.embedding_threads import fill_empty_table_embeddings
from common.utils.embedding_threads import fill_empty_table_and_ds_embeddings
from apps.system.crud.aimodel_manage import async_model_info
from apps.system.crud.assistant import init_dynamic_cors
from apps.system.middleware.auth import TokenMiddleware
Expand All @@ -36,8 +36,8 @@ def init_data_training_embedding_data():
fill_empty_data_training_embeddings()


def init_table_embedding():
    """Startup step: delegate to ``fill_empty_table_embeddings``."""
    fill_empty_table_embeddings()
def init_table_and_ds_embedding():
    """Startup step: delegate to ``fill_empty_table_and_ds_embeddings``."""
    fill_empty_table_and_ds_embeddings()


@asynccontextmanager
Expand All @@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
init_dynamic_cors(app)
init_terminology_embedding_data()
init_data_training_embedding_data()
init_table_embedding()
init_table_and_ds_embedding()
SQLBotLogUtil.info("✅ SQLBot 初始化完成")
await sqlbot_xpack.core.clean_xpack_cache()
await async_model_info() # 异步加密已有模型的密钥和地址
Expand Down