Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions backend/apps/data_training/curd/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import dicttoxml
from sqlalchemy import and_, select, func, delete, update, or_
from sqlalchemy import text
from sqlalchemy.orm.session import Session

from apps.ai_model.embedding import EmbeddingModelCache
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining
Expand Down Expand Up @@ -160,24 +159,30 @@ def delete_training(session: SessionDep, ids: list[int]):
# executor.submit(run_fill_empty_embeddings)


def run_fill_empty_embeddings(session_maker):
    """Backfill embeddings for every DataTraining row whose embedding is NULL.

    Args:
        session_maker: Session factory. It is called to obtain a session and
            must also expose ``remove()`` (for example a SQLAlchemy
            ``scoped_session``), which is invoked once the work is done.

    Returns:
        None. Does nothing when ``settings.EMBEDDING_ENABLED`` is falsy.

    Any exception is printed with ``traceback.print_exc()`` instead of being
    raised, so a failure never propagates to the caller.
    """
    try:
        if not settings.EMBEDDING_ENABLED:
            return

        # Obtain the session inside this function so the session used is the
        # one the factory hands out here, not one created by the submitter.
        session = session_maker()
        stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None)))
        results = session.execute(stmt).scalars().all()

        # Hand over the factory, not the session: save_embeddings obtains its
        # own session from it.
        save_embeddings(session_maker, results)
    except Exception:
        traceback.print_exc()
    finally:
        # Release the session obtained from the factory, even on failure.
        session_maker.remove()


def save_embeddings(session: Session, ids: List[int]):
def save_embeddings(session_maker, ids: List[int]):
if not settings.EMBEDDING_ENABLED:
return

if not ids or len(ids) == 0:
return
try:

session = session_maker()
_list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all()

_question_list = [item.question for item in _list]
Expand All @@ -194,6 +199,8 @@ def save_embeddings(session: Session, ids: List[int]):

except Exception:
traceback.print_exc()
finally:
session_maker.remove()


embedding_sql = f"""
Expand Down
2 changes: 1 addition & 1 deletion backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from apps.db.engine import get_engine_config, get_engine_conn
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from apps.datasource.crud.table import run_save_table_embeddings
from common.utils.embedding_threads import run_save_table_embeddings
from common.utils.utils import deepcopy_ignore_extra
from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
Expand Down
31 changes: 8 additions & 23 deletions backend/apps/datasource/crud/table.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
import json
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import List

from sqlalchemy import and_, select, update
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session

from apps.ai_model.embedding import EmbeddingModelCache
from common.core.config import settings
from common.core.deps import SessionDep
from common.utils.utils import SQLBotLogUtil
from ..models.datasource import CoreTable, CoreField

executor = ThreadPoolExecutor(max_workers=200)

from common.core.db import engine

session_maker = sessionmaker(bind=engine)
session = session_maker()


def delete_table_by_ds_id(session: SessionDep, id: int):
session.query(CoreTable).filter(CoreTable.ds_id == id).delete(synchronize_session=False)
Expand All @@ -40,22 +30,25 @@ def update_table(session: SessionDep, item: CoreTable):
session.commit()


def run_fill_empty_table_embedding(session_maker):
    """Backfill embeddings for every CoreTable row whose embedding is NULL.

    Args:
        session_maker: Session factory. It is called to obtain a session and
            must also expose ``remove()`` (for example a SQLAlchemy
            ``scoped_session``), which is invoked once the work is done.

    Returns:
        None. Does nothing when ``settings.TABLE_EMBEDDING_ENABLED`` is falsy.

    Any exception is printed with ``traceback.print_exc()`` instead of being
    raised, so a failure never propagates to the caller.
    """
    try:
        if not settings.TABLE_EMBEDDING_ENABLED:
            return

        SQLBotLogUtil.info('get tables')
        session = session_maker()
        stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
        results = session.execute(stmt).scalars().all()
        SQLBotLogUtil.info('result: ' + str(len(results)))

        # save_table_embedding expects the factory (it calls session_maker()
        # and session_maker.remove() itself); passing the Session instance
        # here would fail as soon as it tried to call it.
        save_table_embedding(session_maker, results)
    except Exception:
        traceback.print_exc()
    finally:
        # Release the session obtained from the factory, even on failure.
        session_maker.remove()


def save_table_embedding(session: Session, ids: List[int]):
def save_table_embedding(session_maker, ids: List[int]):
if not settings.TABLE_EMBEDDING_ENABLED:
return

Expand All @@ -65,6 +58,7 @@ def save_table_embedding(session: Session, ids: List[int]):
SQLBotLogUtil.info('start table embedding')
start_time = time.time()
model = EmbeddingModelCache.get_model()
session = session_maker()
for _id in ids:
table = session.query(CoreTable).filter(CoreTable.id == _id).first()
fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
Expand Down Expand Up @@ -102,14 +96,5 @@ def save_table_embedding(session: Session, ids: List[int]):
SQLBotLogUtil.info('table embedding finished in: ' + str(end_time - start_time) + ' seconds')
except Exception:
traceback.print_exc()


def run_save_table_embeddings(ids: List[int]):
executor.submit(save_table_embedding, session, ids)


def fill_empty_table_embeddings():
try:
executor.submit(run_fill_empty_table_embedding, session)
except Exception:
traceback.print_exc()
finally:
session_maker.remove()
35 changes: 23 additions & 12 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import dicttoxml
from sqlalchemy import and_, or_, select, func, delete, update, union, text, BigInteger
from sqlalchemy.orm import aliased
from sqlalchemy.orm.session import Session

from apps.ai_model.embedding import EmbeddingModelCache
from apps.datasource.models.datasource import CoreDatasource
Expand Down Expand Up @@ -407,26 +406,36 @@ def delete_terminology(session: SessionDep, ids: list[int]):
#
# def fill_empty_embeddings():
# executor.submit(run_fill_empty_embeddings)
# from sqlalchemy import create_engine
# from sqlalchemy.orm import sessionmaker,scoped_session
# engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
# session_maker = scoped_session(sessionmaker(bind=engine))


def run_fill_empty_embeddings(session_maker):
    """Backfill embeddings for Terminology entries that have none.

    Two id sets are collected and unioned:
      * ids of rows with no embedding and no ``pid``;
      * distinct ``pid`` values of rows with no embedding that do have a
        ``pid``.
    The combined ids are then handed to ``save_embeddings``.

    Args:
        session_maker: Session factory. It is called to obtain a session and
            must also expose ``remove()`` (for example a SQLAlchemy
            ``scoped_session``), which is invoked once the work is done.

    Returns:
        None. Does nothing when ``settings.EMBEDDING_ENABLED`` is falsy.

    Any exception is printed with ``traceback.print_exc()`` instead of being
    raised, so a failure never propagates to the caller.
    """
    try:
        if not settings.EMBEDDING_ENABLED:
            return

        session = session_maker()
        stmt1 = select(Terminology.id).where(
            and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
        stmt2 = select(Terminology.pid).where(
            and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
        combined_stmt = union(stmt1, stmt2)
        results = session.execute(combined_stmt).scalars().all()

        # Hand over the factory, not the session: save_embeddings obtains its
        # own session from it.
        save_embeddings(session_maker, results)
    except Exception:
        traceback.print_exc()
    finally:
        # Release the session obtained from the factory, even on failure.
        session_maker.remove()


def save_embeddings(session: Session, ids: List[int]):
def save_embeddings(session_maker, ids: List[int]):
if not settings.EMBEDDING_ENABLED:
return

if not ids or len(ids) == 0:
return
try:

session = session_maker()
_list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()

_words_list = [item.word for item in _list]
Expand All @@ -443,6 +452,8 @@ def save_embeddings(session: Session, ids: List[int]):

except Exception:
traceback.print_exc()
finally:
session_maker.remove()


embedding_sql = f"""
Expand Down
32 changes: 21 additions & 11 deletions backend/common/utils/embedding_threads.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from common.core.config import settings
from sqlalchemy.orm import sessionmaker, scoped_session

# Worker pool used by the helper functions below to run embedding jobs off
# the calling thread.
executor = ThreadPoolExecutor(max_workers=200)

# NOTE(review): the next three statements look like the pre-change side of the
# diff; `engine` and `session_maker` are both rebound a few lines further down.
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
session_maker = sessionmaker(bind=engine)
session = session_maker()
# Rebinds `engine` to the one exported by common.core.db.
from common.core.db import engine

# Rebinds `session_maker` to a scoped_session registry wrapping a sessionmaker
# bound to that engine; this is the object handed to the worker functions.
session_maker = scoped_session(sessionmaker(bind=engine))


# The module-level session is intentionally left commented out here.
# session = session_maker()


def run_save_terminology_embeddings(ids: List[int]):
    """Queue a background job that saves embeddings for the given terminology ids.

    Args:
        ids: Terminology ids forwarded unchanged to ``save_embeddings``.

    The job is submitted to the module-level ``executor``; the returned future
    is not kept.
    """
    # Imported at call time rather than at module load — presumably to avoid a
    # circular import with the terminology module; confirm before moving it.
    from apps.terminology.curd.terminology import save_embeddings

    # Submit once, passing the session factory so the worker obtains its own
    # session instead of sharing a module-level one across threads.
    executor.submit(save_embeddings, session_maker, ids)


def fill_empty_terminology_embeddings():
    """Queue a background job that backfills missing terminology embeddings.

    The job is submitted to the module-level ``executor``; the returned future
    is not kept.
    """
    # Imported at call time rather than at module load — presumably to avoid a
    # circular import with the terminology module; confirm before moving it.
    from apps.terminology.curd.terminology import run_fill_empty_embeddings

    # Submit once, passing the session factory so the worker obtains its own
    # session instead of sharing a module-level one across threads.
    executor.submit(run_fill_empty_embeddings, session_maker)


def run_save_data_training_embeddings(ids: List[int]):
    """Queue a background job that saves embeddings for the given data-training ids.

    Args:
        ids: DataTraining ids forwarded unchanged to ``save_embeddings``.

    The job is submitted to the module-level ``executor``; the returned future
    is not kept.
    """
    # Imported at call time rather than at module load — presumably to avoid a
    # circular import with the data_training module; confirm before moving it.
    from apps.data_training.curd.data_training import save_embeddings

    # Submit once, passing the session factory so the worker obtains its own
    # session instead of sharing a module-level one across threads.
    executor.submit(save_embeddings, session_maker, ids)


def fill_empty_data_training_embeddings():
    """Queue a background job that backfills missing data-training embeddings.

    The job is submitted to the module-level ``executor``; the returned future
    is not kept.
    """
    # Imported at call time rather than at module load — presumably to avoid a
    # circular import with the data_training module; confirm before moving it.
    from apps.data_training.curd.data_training import run_fill_empty_embeddings

    # Submit once, passing the session factory so the worker obtains its own
    # session instead of sharing a module-level one across threads.
    executor.submit(run_fill_empty_embeddings, session_maker)


def run_save_table_embeddings(ids: List[int]):
    """Queue a background job that saves embeddings for the given table ids.

    Args:
        ids: Table ids forwarded unchanged to ``save_table_embedding``.

    The job runs on the module-level ``executor`` and receives the
    module-level ``session_maker``; the returned future is not kept.
    """
    # Resolved at call time, as in the other helpers of this module.
    from apps.datasource.crud.table import save_table_embedding as worker

    executor.submit(worker, session_maker, ids)


def fill_empty_table_embeddings():
    """Queue a background job that backfills missing table embeddings.

    The job runs on the module-level ``executor`` and receives the
    module-level ``session_maker``; the returned future is not kept.
    """
    # Resolved at call time, as in the other helpers of this module.
    from apps.datasource.crud.table import run_fill_empty_table_embedding as worker

    executor.submit(worker, session_maker)
2 changes: 1 addition & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from alembic import command
from apps.api import api_router
from apps.datasource.crud.table import fill_empty_table_embeddings
from common.utils.embedding_threads import fill_empty_table_embeddings
from apps.system.crud.aimodel_manage import async_model_info
from apps.system.crud.assistant import init_dynamic_cors
from apps.system.middleware.auth import TokenMiddleware
Expand Down