@@ -65,6 +65,7 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat
6565 if ds_list is not None and len (ds_list ) > 0 :
6666 raise HTTPException (status_code = 500 , detail = trans ('i18n_ds_name_exist' ))
6767
68+
6869@clear_cache (namespace = CacheNamespace .AUTH_INFO , cacheName = CacheName .DS_ID_LIST , keyExpression = "user.oid" )
6970async def create_ds (session : SessionDep , trans : Trans , user : CurrentUser , create_ds : CreateDatasource ):
7071 ds = CoreDatasource ()
@@ -490,7 +491,7 @@ def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) ->
490491
491492
492493def get_tables_sample_data (session : SessionDep , current_user : CurrentUser , ds : CoreDatasource ,
493- table_list : list [str ] = None ) -> str :
494+ table_list : list [str ] = None ) -> str :
494495 """Get sample data (3 rows) for all tables to help AI understand the data"""
495496 table_objs = get_table_obj_by_ds (session = session , current_user = current_user , ds = ds )
496497 if len (table_objs ) == 0 :
@@ -508,15 +509,16 @@ def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: C
508509
509510
510511def get_table_schema (session : SessionDep , current_user : CurrentUser , ds : CoreDatasource , question : str ,
511- embedding : bool = True , table_list : list [str ] = None ) -> str :
512+ embedding : bool = True , table_list : list [str ] = None ) -> tuple [ str , list ] :
512513 schema_str = ""
513514 table_objs = get_table_obj_by_ds (session = session , current_user = current_user , ds = ds )
514515 if len (table_objs ) == 0 :
515- return schema_str
516+ return schema_str , []
516517 db_name = table_objs [0 ].schema
517518 schema_str += f"【DB_ID】 { db_name } \n 【Schema】\n "
518519 tables = []
519520 all_tables = [] # temp save all tables
521+ table_name_list = []
520522 for obj in table_objs :
521523 # 如果传入了table_list,则只处理在列表中的表
522524 if table_list is not None and obj .table .table_name not in table_list :
@@ -546,13 +548,14 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
546548 schema_table += ",\n " .join (field_list )
547549 schema_table += '\n ]\n '
548550
549- t_obj = {"id" : obj .table .id , "schema_table" : schema_table , "embedding" : obj .table .embedding }
551+ t_obj = {"id" : obj .table .id , "table_name" : obj .table .table_name , "schema_table" : schema_table ,
552+ "embedding" : obj .table .embedding }
550553 tables .append (t_obj )
551554 all_tables .append (t_obj )
552555
553556 # 如果没有符合过滤条件的表,直接返回
554557 if not tables :
555- return schema_str
558+ return schema_str , []
556559
557560 # do table embedding
558561 if embedding and tables and settings .TABLE_EMBEDDING_ENABLED :
@@ -561,6 +564,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
561564 if tables :
562565 for s in tables :
563566 schema_str += s .get ('schema_table' )
567+ table_name_list .append (s .get ('table_name' ))
564568
565569 # field relation
566570 if tables and ds .table_relation :
@@ -592,6 +596,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
592596 if lost_tables :
593597 for s in lost_tables :
594598 schema_str += s .get ('schema_table' )
599+ table_name_list .append (s .get ('table_name' ))
595600
596601 # get field dict
597602 relation_field_ids = []
@@ -609,13 +614,16 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
609614 for ele in all_relations :
610615 schema_str += f"{ table_dict .get (int (ele .get ('source' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('source' ).get ('port' )))} ={ table_dict .get (int (ele .get ('target' ).get ('cell' )))} .{ field_dict .get (int (ele .get ('target' ).get ('port' )))} \n "
611616
612- return schema_str
617+ return schema_str , table_name_list
618+
613619
614620@cache (namespace = CacheNamespace .AUTH_INFO , cacheName = CacheName .DS_ID_LIST , keyExpression = "oid" )
615621async def get_ws_ds (session , oid ) -> list :
616622 stmt = select (CoreDatasource .id ).distinct ().where (CoreDatasource .oid == oid )
617623 db_list = session .exec (stmt ).all ()
618624 return db_list
625+
626+
619627@clear_cache (namespace = CacheNamespace .AUTH_INFO , cacheName = CacheName .DS_ID_LIST , keyExpression = "oid" )
620628async def clear_ws_ds_cache (oid ):
621- SQLBotLogUtil .info (f"ds cache for ws [{ oid } ] has been cleaned" )
629+ SQLBotLogUtil .info (f"ds cache for ws [{ oid } ] has been cleaned" )
0 commit comments