5454
5555warnings .filterwarnings ("ignore" )
5656
57- base_message_count_limit = 6
58-
5957executor = ThreadPoolExecutor (max_workers = 200 )
6058
6159dynamic_ds_types = [1 , 3 ]
@@ -95,6 +93,7 @@ class LLMService:
9593 articles_number : int = 4
9694
9795 enable_sql_row_limit : bool = settings .GENERATE_SQL_QUERY_LIMIT_ENABLED
96+ base_message_round_count_limit : int = settings .GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT
9897
9998 def __init__ (self , session : Session , current_user : CurrentUser , chat_question : ChatQuestion ,
10099 current_assistant : Optional [CurrentAssistant ] = None , no_reasoning : bool = False ,
@@ -185,6 +184,14 @@ async def create(cls, *args, **kwargs):
185184 instance .enable_sql_row_limit = True
186185 else :
187186 instance .enable_sql_row_limit = False
187+ if config .pkey == 'chat.context_record_count' :
188+ count_value = config .pval
189+ if count_value is None :
190+ count_value = settings .GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT
191+ count_value = int (count_value )
192+ if count_value < 0 :
193+ count_value = 0
194+ instance .base_message_round_count_limit = count_value
188195 return instance
189196
190197 def is_running (self , timeout = 0.5 ):
@@ -206,22 +213,23 @@ def init_messages(self):
206213 filter (lambda obj : obj .pid == self .chat_question .regenerate_record_id , self .generate_sql_logs ), None )
207214 last_sql_messages : List [dict [str , Any ]] = _temp_log .messages if _temp_log else []
208215
209- # todo maybe can configure
210- count_limit = 0 - base_message_count_limit
216+ count_limit = self .base_message_round_count_limit
211217
212218 self .sql_message = []
213219 # add sys prompt
214220 self .sql_message .append (SystemMessage (
215221 content = self .chat_question .sql_sys_question (self .ds .type , self .enable_sql_row_limit )))
216222 if last_sql_messages is not None and len (last_sql_messages ) > 0 :
217- # limit count
218- for last_sql_message in last_sql_messages [count_limit :]:
223+ # 获取最后3轮对话
224+ last_rounds = get_last_conversation_rounds (last_sql_messages , rounds = count_limit )
225+
226+ for _msg_dict in last_rounds :
219227 _msg : BaseMessage
220- if last_sql_message [ 'type' ] == 'human' :
221- _msg = HumanMessage (content = last_sql_message [ 'content' ] )
228+ if _msg_dict . get ( 'type' ) == 'human' :
229+ _msg = HumanMessage (content = _msg_dict . get ( 'content' ) )
222230 self .sql_message .append (_msg )
223- elif last_sql_message [ 'type' ] == 'ai' :
224- _msg = AIMessage (content = last_sql_message [ 'content' ] )
231+ elif _msg_dict . get ( 'type' ) == 'ai' :
232+ _msg = AIMessage (content = _msg_dict . get ( 'content' ) )
225233 self .sql_message .append (_msg )
226234
227235 last_chart_messages : List [dict [str , Any ]] = self .generate_chart_logs [- 1 ].messages if len (
@@ -1666,3 +1674,29 @@ def get_lang_name(lang: str):
16661674 if normalized .startswith ('ko' ):
16671675 return '韩语'
16681676 return '简体中文'
1677+
1678+
1679+ def get_last_conversation_rounds (messages , rounds = settings .GENERATE_SQL_QUERY_HISTORY_ROUND_COUNT ):
1680+ """获取最后N轮对话,处理不完整对话的情况"""
1681+ if not messages or rounds <= 0 :
1682+ return []
1683+
1684+ # 找到所有用户消息的位置
1685+ human_indices = []
1686+ for index , msg in enumerate (messages ):
1687+ if msg .get ('type' ) == 'human' :
1688+ human_indices .append (index )
1689+
1690+ # 如果没有用户消息,返回空
1691+ if not human_indices :
1692+ return []
1693+
1694+ # 计算从哪个索引开始
1695+ if len (human_indices ) <= rounds :
1696+ # 如果用户消息数少于等于需要的轮数,从第一个用户消息开始
1697+ start_index = human_indices [0 ]
1698+ else :
1699+ # 否则,从倒数第N个用户消息开始
1700+ start_index = human_indices [- rounds ]
1701+
1702+ return messages [start_index :]
0 commit comments