Skip to content

Commit 51e7247

Browse files
committed
fix: Fix the data source-related issues in the PR
1 parent b0fd439 commit 51e7247

6 files changed

Lines changed: 61 additions & 54 deletions

File tree

backend/apps/chat/task/llm.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -380,7 +380,7 @@ def choose_table_schema(self, _session: Session):
380380
operate=OperationEnum.CHOOSE_TABLE,
381381
record_id=self.record.id,
382382
local_operation=True)
383-
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
383+
self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema(
384384
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
385385
session=_session,
386386
current_user=self.current_user,
@@ -392,7 +392,8 @@ def choose_table_schema(self, _session: Session):
392392
self.chat_question.sample_data = get_tables_sample_data(
393393
session=_session,
394394
current_user=self.current_user,
395-
ds=self.ds)
395+
ds=self.ds,
396+
table_list=tables)
396397

397398
self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
398399
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
@@ -508,19 +509,19 @@ def generate_recommend_questions_task(self, _session: Session):
508509

509510
# get schema
510511
if self.ds and not self.chat_question.db_schema:
511-
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
512+
self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema(
512513
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
513514
session=_session,
514515
current_user=self.current_user, ds=self.ds,
515516
question=self.chat_question.question,
516517
embedding=False)
517518

518519
# Get sample data for all tables
519-
if not self.out_ds_instance:
520-
self.chat_question.sample_data = get_tables_sample_data(
521-
session=_session,
522-
current_user=self.current_user,
523-
ds=self.ds)
520+
# if not self.out_ds_instance:
521+
# self.chat_question.sample_data = get_tables_sample_data(
522+
# session=_session,
523+
# current_user=self.current_user,
524+
# ds=self.ds)
524525

525526
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
526527
guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
@@ -1356,7 +1357,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
13561357
return
13571358

13581359
# generate chart
1359-
used_tables_schema = self.out_ds_instance.get_db_schema(
1360+
used_tables_schema, used_tables = self.out_ds_instance.get_db_schema(
13601361
self.ds.id, self.chat_question.question, embedding=False,
13611362
table_list=tables) if self.out_ds_instance else get_table_schema(
13621363
session=_session,

backend/apps/datasource/crud/datasource.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat
6565
if ds_list is not None and len(ds_list) > 0:
6666
raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))
6767

68+
6869
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="user.oid")
6970
async def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
7071
ds = CoreDatasource()
@@ -490,7 +491,7 @@ def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) ->
490491

491492

492493
def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource,
493-
table_list: list[str] = None) -> str:
494+
table_list: list[str] = None) -> str:
494495
"""Get sample data (3 rows) for all tables to help AI understand the data"""
495496
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
496497
if len(table_objs) == 0:
@@ -508,15 +509,16 @@ def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: C
508509

509510

510511
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
511-
embedding: bool = True, table_list: list[str] = None) -> str:
512+
embedding: bool = True, table_list: list[str] = None) -> tuple[str, list]:
512513
schema_str = ""
513514
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
514515
if len(table_objs) == 0:
515-
return schema_str
516+
return schema_str, []
516517
db_name = table_objs[0].schema
517518
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
518519
tables = []
519520
all_tables = [] # temp save all tables
521+
table_name_list = []
520522
for obj in table_objs:
521523
# 如果传入了table_list,则只处理在列表中的表
522524
if table_list is not None and obj.table.table_name not in table_list:
@@ -546,13 +548,14 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
546548
schema_table += ",\n".join(field_list)
547549
schema_table += '\n]\n'
548550

549-
t_obj = {"id": obj.table.id, "schema_table": schema_table, "embedding": obj.table.embedding}
551+
t_obj = {"id": obj.table.id, "table_name": obj.table.table_name, "schema_table": schema_table,
552+
"embedding": obj.table.embedding}
550553
tables.append(t_obj)
551554
all_tables.append(t_obj)
552555

553556
# 如果没有符合过滤条件的表,直接返回
554557
if not tables:
555-
return schema_str
558+
return schema_str, []
556559

557560
# do table embedding
558561
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
@@ -561,6 +564,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
561564
if tables:
562565
for s in tables:
563566
schema_str += s.get('schema_table')
567+
table_name_list.append(s.get('table_name'))
564568

565569
# field relation
566570
if tables and ds.table_relation:
@@ -592,6 +596,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
592596
if lost_tables:
593597
for s in lost_tables:
594598
schema_str += s.get('schema_table')
599+
table_name_list.append(s.get('table_name'))
595600

596601
# get field dict
597602
relation_field_ids = []
@@ -609,13 +614,16 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
609614
for ele in all_relations:
610615
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"
611616

612-
return schema_str
617+
return schema_str, table_name_list
618+
613619

614620
@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
615621
async def get_ws_ds(session, oid) -> list:
616622
stmt = select(CoreDatasource.id).distinct().where(CoreDatasource.oid == oid)
617623
db_list = session.exec(stmt).all()
618624
return db_list
625+
626+
619627
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
620628
async def clear_ws_ds_cache(oid):
621-
SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned")
629+
SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned")

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
2323
if out_ds.ds_list:
2424
for _ds in out_ds.ds_list:
2525
ds = out_ds.get_ds(_ds.id)
26-
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
26+
table_schema, tables = out_ds.get_db_schema(_ds.id, question, embedding=False)
2727
ds_info = f"{ds.name}, {ds.description}\n"
2828
ds_schema = ds_info + table_schema
2929
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def calc_table_embedding(tables: list[dict], question: str):
4545
for table in tables:
4646
_list.append(
4747
{"id": table.get('id'), "schema_table": table.get('schema_table'), "embedding": table.get('embedding'),
48-
"cosine_similarity": 0.0})
48+
"cosine_similarity": 0.0, "table_name": table.get('table_name')})
4949

5050
if _list:
5151
try:
@@ -70,7 +70,7 @@ def calc_table_embedding(tables: list[dict], question: str):
7070
end_time = time.time()
7171
SQLBotLogUtil.info(str(end_time - start_time))
7272
SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'),
73-
"cosine_similarity": ele.get('cosine_similarity')}
73+
"cosine_similarity": ele.get('cosine_similarity'), "table_name": ele.get('table_name')}
7474
for ele in _list]))
7575
return _list
7676
except Exception:

backend/apps/system/crud/assistant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def get_simple_ds_list(self):
181181
raise Exception("Datasource list is not found.")
182182

183183
def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
184-
table_list: list[str] = None) -> str:
184+
table_list: list[str] = None) -> tuple[str, list]:
185185
ds = self.get_ds(ds_id)
186186
schema_str = ""
187187
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
@@ -222,7 +222,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
222222
for s in tables:
223223
schema_str += s.get('schema_table')
224224

225-
return schema_str
225+
return schema_str, []
226226

227227
def get_ds(self, ds_id: int, trans: Trans = None):
228228
if self.ds_list:

frontend/src/views/chat/execution-component/LogWithAi.vue

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,46 +51,44 @@ const recordsBeforeCurrentQuestion = computed(() =>
5151
<template v-if="item.error">
5252
{{ error }}
5353
</template>
54-
<template v-else>
55-
<div class="item-list">
56-
<div class="inner-title">{{ t('chat.log_system') }}</div>
57-
<div class="inner-item">
58-
<div class="inner-item-title">
59-
{{ systemRecord.type }}
60-
</div>
61-
<div class="inner-item-description">
62-
<SQLComponent :sql="systemRecord.content" />
63-
</div>
54+
<div class="item-list">
55+
<div class="inner-title">{{ t('chat.log_system') }}</div>
56+
<div class="inner-item">
57+
<div class="inner-item-title">
58+
{{ systemRecord.type }}
59+
</div>
60+
<div class="inner-item-description">
61+
<SQLComponent :sql="systemRecord.content" />
6462
</div>
65-
<template v-if="recordsBeforeCurrentQuestion.length > 0">
66-
<div class="inner-title">{{ t('chat.log_history') }}</div>
67-
<div class="inner-item">
68-
<div v-for="(ele, index) in recordsBeforeCurrentQuestion" :key="index">
69-
<div class="inner-item-title">
70-
{{ ele.type }}
71-
</div>
72-
<div class="inner-item-description">
73-
<SQLComponent :sql="ele.content" />
74-
</div>
63+
</div>
64+
<template v-if="recordsBeforeCurrentQuestion.length > 0">
65+
<div class="inner-title">{{ t('chat.log_history') }}</div>
66+
<div class="inner-item">
67+
<div v-for="(ele, index) in recordsBeforeCurrentQuestion" :key="index">
68+
<div class="inner-item-title">
69+
{{ ele.type }}
70+
</div>
71+
<div class="inner-item-description">
72+
<SQLComponent :sql="ele.content" />
7573
</div>
7674
</div>
77-
</template>
78-
<div class="inner-title">{{ t('chat.log_question') }}</div>
75+
</div>
76+
</template>
77+
<div class="inner-title">{{ t('chat.log_question') }}</div>
78+
<div class="inner-item">
79+
<div class="inner-item-description">
80+
<SQLComponent :sql="lastHumanRecord.content" />
81+
</div>
82+
</div>
83+
<template v-if="lastAiAfterHuman">
84+
<div class="inner-title">{{ t('chat.log_answer') }}</div>
7985
<div class="inner-item">
8086
<div class="inner-item-description">
81-
<SQLComponent :sql="lastHumanRecord.content" />
87+
<SQLComponent :sql="lastAiAfterHuman.content" />
8288
</div>
8389
</div>
84-
<template v-if="lastAiAfterHuman">
85-
<div class="inner-title">{{ t('chat.log_answer') }}</div>
86-
<div class="inner-item">
87-
<div class="inner-item-description">
88-
<SQLComponent :sql="lastAiAfterHuman.content" />
89-
</div>
90-
</div>
91-
</template>
92-
</div>
93-
</template>
90+
</template>
91+
</div>
9492
</BaseContent>
9593
</template>
9694

0 commit comments

Comments
 (0)