Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# from apps.chat.models.chat_model import SQLModel
from apps.terminology.models.terminology_model import SQLModel
#from apps.custom_prompt.models.custom_prompt_model import SQLModel
# from apps.data_training.models.data_training_model import SQLModel
from apps.data_training.models.data_training_model import SQLModel
# from apps.dashboard.models.dashboard_model import SQLModel
from common.core.config import settings # noqa
#from apps.datasource.models.datasource import SQLModel
Expand Down
37 changes: 37 additions & 0 deletions backend/alembic/versions/050_modify_ddl_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""050_modify_ddl.py

Revision ID: 2785e54dc1c4
Revises: b58a71ca6ae3
Create Date: 2025-11-06 13:43:50.820328

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '2785e54dc1c4'
down_revision = 'b58a71ca6ae3'
branch_labels = None
depends_on = None

sql='''
UPDATE data_training SET enabled = true;
UPDATE terminology SET enabled = true;
'''

def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('data_training', sa.Column('enabled', sa.Boolean(), nullable=True))
op.add_column('terminology', sa.Column('enabled', sa.Boolean(), nullable=True))

op.execute(sql)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('terminology', 'enabled')
op.drop_column('data_training', 'enabled')
# ### end Alembic commands ###
8 changes: 7 additions & 1 deletion backend/apps/data_training/api/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from fastapi import APIRouter, Query

from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
enable_training
from apps.data_training.models.data_training_model import DataTrainingInfo
from common.core.deps import SessionDep, CurrentUser, Trans

Expand Down Expand Up @@ -37,3 +38,8 @@ async def create_or_update(session: SessionDep, current_user: CurrentUser, trans
@router.delete("")
async def delete(session: SessionDep, id_list: list[int]):
delete_training(session, id_list)


@router.get("{id}/enable/{enabled}")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
enable_training(session, id, enabled, trans)
17 changes: 16 additions & 1 deletion backend/apps/data_training/curd/data_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
if info.datasource is None:
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
datasource=info.datasource)
datasource=info.datasource, enabled=info.enabled)

exists = session.query(
session.query(DataTraining).filter(
Expand Down Expand Up @@ -135,6 +135,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
question=info.question,
description=info.description,
datasource=info.datasource,
enabled=info.enabled,
)
session.execute(stmt)
session.commit()
Expand All @@ -151,6 +152,20 @@ def delete_training(session: SessionDep, ids: list[int]):
session.commit()


def enable_training(session: SessionDep, id: int, enabled: bool, trans: Trans):
count = session.query(DataTraining).filter(
DataTraining.id == id
).count()
if count == 0:
raise Exception(trans('i18n_data_training.data_training_not_exists'))

stmt = update(DataTraining).where(and_(DataTraining.id == id)).values(
enabled=enabled,
)
session.execute(stmt)
session.commit()


# def run_save_embeddings(ids: List[int]):
# executor.submit(save_embeddings, ids)
#
Expand Down
4 changes: 3 additions & 1 deletion backend/apps/data_training/models/data_training_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pgvector.sqlalchemy import VECTOR
from pydantic import BaseModel
from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean
from sqlmodel import SQLModel, Field


Expand All @@ -16,6 +16,7 @@ class DataTraining(SQLModel, table=True):
question: Optional[str] = Field(max_length=255)
description: Optional[str] = Field(sa_column=Column(Text, nullable=True))
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True))


class DataTrainingInfo(BaseModel):
Expand All @@ -26,3 +27,4 @@ class DataTrainingInfo(BaseModel):
create_time: Optional[datetime] = None
question: Optional[str] = None
description: Optional[str] = None
enabled: Optional[bool] = True
7 changes: 6 additions & 1 deletion backend/apps/terminology/api/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, Query

from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \
delete_terminology
delete_terminology, enable_terminology
from apps.terminology.models.terminology_model import TerminologyInfo
from common.core.deps import SessionDep, CurrentUser, Trans

Expand Down Expand Up @@ -37,3 +37,8 @@ async def create_or_update(session: SessionDep, current_user: CurrentUser, trans
@router.delete("")
async def delete(session: SessionDep, id_list: list[int]):
delete_terminology(session, id_list)


@router.get("{id}/enable/{enabled}")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
enable_terminology(session, id, enabled, trans)
23 changes: 19 additions & 4 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
raise Exception(trans("i18n_terminology.datasource_cannot_be_none"))

parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid,
specific_ds=specific_ds,
specific_ds=specific_ds, enabled=info.enabled,
datasource_ids=datasource_ids)

words = [info.word]
Expand Down Expand Up @@ -289,7 +289,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if other_word.strip() == "":
continue
_list.append(
Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid,
Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, enabled=result.enabled,
specific_ds=specific_ds, datasource_ids=datasource_ids))
session.bulk_save_objects(_list)
session.flush()
Expand Down Expand Up @@ -366,7 +366,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
word=info.word,
description=info.description,
specific_ds=specific_ds,
datasource_ids=datasource_ids
datasource_ids=datasource_ids,
enabled=info.enabled,
)
session.execute(stmt)
session.commit()
Expand All @@ -383,7 +384,7 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
continue
_list.append(
Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid,
specific_ds=specific_ds, datasource_ids=datasource_ids))
specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled))
session.bulk_save_objects(_list)
session.flush()
session.commit()
Expand All @@ -400,6 +401,20 @@ def delete_terminology(session: SessionDep, ids: list[int]):
session.commit()


def enable_terminology(session: SessionDep, id: int, enabled: bool, trans: Trans):
count = session.query(Terminology).filter(
Terminology.id == id
).count()
if count == 0:
raise Exception(trans('i18n_terminology.terminology_not_exists'))

stmt = update(Terminology).where(or_(Terminology.id == id, Terminology.pid == id)).values(
enabled=enabled,
)
session.execute(stmt)
session.commit()


# def run_save_embeddings(ids: List[int]):
# executor.submit(save_embeddings, ids)
#
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/terminology/models/terminology_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Terminology(SQLModel, table=True):
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False))
datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[])
enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True))


class TerminologyInfo(BaseModel):
Expand All @@ -30,3 +31,4 @@ class TerminologyInfo(BaseModel):
specific_ds: Optional[bool] = False
datasource_ids: Optional[list[int]] = []
datasource_names: Optional[list[str]] = []
enabled: Optional[bool] = True