Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,12 @@ def select_datasource(self, _session: Session):

self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, oid,
ds_id)
self.chat_question.data_training = get_training_template(_session, self.chat_question.question, ds_id,
oid)
if self.current_assistant and self.current_assistant.type == 1:
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
oid, None, self.current_assistant.id)
else:
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
oid, ds_id)
if SQLBotLicenseUtil.valid():
self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL,
oid, ds_id)
Expand Down Expand Up @@ -902,8 +906,12 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question,
oid, ds_id)
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
ds_id, oid)
if self.current_assistant and self.current_assistant.type == 1:
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
oid, None, self.current_assistant.id)
else:
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
oid, ds_id)
if SQLBotLicenseUtil.valid():
self.chat_question.custom_prompt = find_custom_prompts(_session,
CustomPromptTypeEnum.GENERATE_SQL,
Expand Down
109 changes: 82 additions & 27 deletions backend/apps/data_training/curd/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from sqlalchemy import text

from apps.ai_model.embedding import EmbeddingModelCache
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining, DataTrainingInfoResult
from apps.datasource.models.datasource import CoreDatasource
from apps.system.models.system_model import AssistantModel
from apps.template.generate_chart.generator import get_base_data_training_template
from common.core.config import settings
from common.core.deps import SessionDep, Trans
Expand All @@ -19,7 +20,7 @@

def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
oid: Optional[int] = 1):
_list: List[DataTrainingInfo] = []
_list: List[DataTrainingInfoResult] = []

current_page = max(1, current_page)
page_size = max(10, page_size)
Expand Down Expand Up @@ -63,40 +64,60 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
DataTraining.create_time,
DataTraining.description,
DataTraining.enabled,
DataTraining.advanced_application,
AssistantModel.name.label('advanced_application_name'),
)
.outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id))
.outerjoin(AssistantModel,
and_(DataTraining.advanced_application == AssistantModel.id, AssistantModel.type == 1))
.where(and_(DataTraining.id.in_(paginated_parent_ids)))
.order_by(DataTraining.create_time.desc())
)

result = session.execute(stmt)

for row in result:
_list.append(DataTrainingInfo(
id=row.id,
oid=row.oid,
_list.append(DataTrainingInfoResult(
id=str(row.id),
oid=str(row.oid),
datasource=row.datasource,
datasource_name=row.name,
question=row.question,
create_time=row.create_time,
description=row.description,
enabled=row.enabled,
advanced_application=str(row.advanced_application) if row.advanced_application else None,
advanced_application_name=row.advanced_application_name,
))

return current_page, page_size, total_count, total_pages, _list


def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
create_time = datetime.datetime.now()
if info.datasource is None:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
if info.datasource is None and info.advanced_application is None:
if oid == 1:
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
else:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))

parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
datasource=info.datasource, enabled=info.enabled)
datasource=info.datasource, enabled=info.enabled,
advanced_application=info.advanced_application)

stmt = select(DataTraining.id).where(and_(DataTraining.question == info.question, DataTraining.oid == oid))

if info.datasource is not None and info.advanced_application is not None:
stmt = stmt.where(
or_(DataTraining.datasource == info.datasource,
DataTraining.advanced_application == info.advanced_application))
elif info.datasource is not None and info.advanced_application is None:
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
elif info.datasource is None and info.advanced_application is not None:
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))

exists = session.query(stmt.exists()).scalar()

exists = session.query(
session.query(DataTraining).filter(
and_(DataTraining.question == info.question, DataTraining.oid == oid,
DataTraining.datasource == info.datasource)).exists()).scalar()
if exists:
raise Exception(trans("i18n_data_training.exists_in_db"))

Expand All @@ -116,20 +137,32 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans


def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
if info.datasource is None:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
if info.datasource is None and info.advanced_application is None:
if oid == 1:
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
else:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))

count = session.query(DataTraining).filter(
DataTraining.id == info.id
).count()
if count == 0:
raise Exception(trans('i18n_data_training.data_training_not_exists'))

exists = session.query(
session.query(DataTraining).filter(
and_(DataTraining.question == info.question, DataTraining.oid == oid,
DataTraining.datasource == info.datasource,
DataTraining.id != info.id)).exists()).scalar()
stmt = select(DataTraining.id).where(
and_(DataTraining.question == info.question, DataTraining.oid == oid, DataTraining.id != info.id))

if info.datasource is not None and info.advanced_application is not None:
stmt = stmt.where(
or_(DataTraining.datasource == info.datasource,
DataTraining.advanced_application == info.advanced_application))
elif info.datasource is not None and info.advanced_application is None:
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
elif info.datasource is None and info.advanced_application is not None:
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))

exists = session.query(stmt.exists()).scalar()

if exists:
raise Exception(trans("i18n_data_training.exists_in_db"))

Expand All @@ -138,6 +171,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
description=info.description,
datasource=info.datasource,
enabled=info.enabled,
advanced_application=info.advanced_application,
)
session.execute(stmt)
session.commit()
Expand Down Expand Up @@ -231,9 +265,21 @@ def save_embeddings(session_maker, ids: List[int]):
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
"""
# Similarity lookup scoped to a single advanced application instead of a
# datasource. Bind parameters: :embedding_array, :oid, :advanced_application.
# `similarity` is computed as 1 - (embedding <=> query vector); with pgvector
# `<=>` is the cosine-distance operator (NOTE(review): confirm the column is
# queried through pgvector in this deployment).
#
# The inner SELECT must project `advanced_application`: the outer WHERE can
# only see columns exposed by the derived table TEMP, so filtering on a column
# that is not in the inner select list fails with "column does not exist".
embedding_sql_in_advanced_application = f"""
SELECT id, datasource, question, similarity
FROM
(SELECT id, datasource, question, oid, enabled, advanced_application,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM data_training AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and advanced_application = :advanced_application and enabled = true
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
"""


def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int):
def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: Optional[int] = None,
advanced_application_id: Optional[int] = None):
if question.strip() == "":
return []

Expand All @@ -248,10 +294,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da
.where(
and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")),
DataTraining.oid == oid,
DataTraining.datasource == datasource,
DataTraining.enabled == True,)
DataTraining.enabled == True)
)
)
if advanced_application_id is not None:
stmt = stmt.where(and_(DataTraining.advanced_application == advanced_application_id))
else:
stmt = stmt.where(and_(DataTraining.datasource == datasource))

results = session.execute(stmt, {'sentence': question}).fetchall()

Expand All @@ -264,8 +313,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da

embedding = model.embed_query(question)

results = session.execute(text(embedding_sql),
{'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})
if advanced_application_id is not None:
results = session.execute(text(embedding_sql_in_advanced_application),
{'embedding_array': str(embedding), 'oid': oid,
'advanced_application': advanced_application_id})
else:
results = session.execute(text(embedding_sql),
{'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})

for row in results:
_list.append(DataTraining(id=row.id, question=row.question))
Expand Down Expand Up @@ -328,12 +382,13 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str:
return pretty_xml


def get_training_template(session: SessionDep, question: str, datasource: int, oid: Optional[int] = 1) -> str:
def get_training_template(session: SessionDep, question: str, oid: Optional[int] = 1, datasource: Optional[int] = None,
advanced_application_id: Optional[int] = None) -> str:
if not oid:
oid = 1
if not datasource:
if not datasource and not advanced_application_id:
return ''
_results = select_training_by_question(session, question, oid, datasource)
_results = select_training_by_question(session, question, oid, datasource, advanced_application_id)
if _results and len(_results) > 0:
data_training = to_xml_string(_results)
template = get_base_data_training_template().format(data_training=data_training)
Expand Down
16 changes: 16 additions & 0 deletions backend/apps/data_training/models/data_training_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class DataTraining(SQLModel, table=True):
description: Optional[str] = Field(sa_column=Column(Text, nullable=True))
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True))
advanced_application: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))


class DataTrainingInfo(BaseModel):
Expand All @@ -28,3 +29,18 @@ class DataTrainingInfo(BaseModel):
question: Optional[str] = None
description: Optional[str] = None
enabled: Optional[bool] = True
advanced_application: Optional[int] = None
advanced_application_name: Optional[str] = None


# Response-side counterpart of DataTrainingInfo used by the paginated listing.
# The identifier-like fields (id, oid, advanced_application) are declared as
# strings here; the listing code fills them via str(...) on the row values.
# NOTE(review): presumably done so 64-bit ids survive JSON clients without
# precision loss — confirm with the frontend contract.
class DataTrainingInfoResult(BaseModel):
    # Identifiers, stringified by the caller.
    id: str | None = None
    oid: str | None = None

    # Datasource binding and its display name.
    datasource: int | None = None
    datasource_name: str | None = None

    # Training entry payload.
    create_time: datetime | None = None
    question: str | None = None
    description: str | None = None
    enabled: bool | None = True

    # Advanced-application binding (stringified id) and its display name.
    advanced_application: str | None = None
    advanced_application_name: str | None = None
Loading