Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions astrbot/builtin_stars/astrbot/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import random
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Literal

from astrbot import logger
from astrbot.api import star
Expand All @@ -11,9 +13,36 @@
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager

"""
聊天记忆增强
"""

@dataclass
class ChatRecord:
"""单条聊天记录,用于长期记忆存储。"""

msg_id: str
"""消息唯一标识(用户消息取 message_obj.message_id,AI 回复取 'ai:<uuid>')"""
role: Literal["user", "assistant"]
"""角色:user 表示用户消息,assistant 表示 AI 回复"""
text: str
"""格式化后的文本,如 '[昵称/HH:MM:SS]: ...' 或 '[You/HH:MM:SS]: ...'"""
created_at: str = field(default_factory=lambda: datetime.datetime.now().isoformat())
"""创建时间(ISO 格式),用于调试/扩展"""


def _get_event_msg_id(event: AstrMessageEvent) -> str:
"""
获取当前事件对应的消息 ID。
以保证同一事件链路(handle_message -> on_req_llm -> after_req_llm)使用同一 ID。
"""
msg_id = getattr(event.message_obj, "message_id", None)
if msg_id:
return str(msg_id)
# fallback: 使用 extra 缓存
cached = event.get_extra("_ltm_msg_id")
if cached:
return cached
generated = f"ltm:{uuid.uuid4().hex}"
event.set_extra("_ltm_msg_id", generated)
return generated


class LongTermMemory:
Expand Down Expand Up @@ -144,7 +173,11 @@ async def handle_message(self, event: AstrMessageEvent):

final_message = "".join(parts)
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
self.session_chats[event.unified_msg_origin].append(final_message)

msg_id = _get_event_msg_id(event)
record = ChatRecord(msg_id=msg_id, role="user", text=final_message)
self.session_chats[event.unified_msg_origin].append(record)

if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
self.session_chats[event.unified_msg_origin].pop(0)

Expand All @@ -153,7 +186,14 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
if event.unified_msg_origin not in self.session_chats:
return

chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin])
current_msg_id = _get_event_msg_id(event)

# 构造历史字符串时按 msg_id 过滤当前轮消息
history_records: list[ChatRecord] = self.session_chats[event.unified_msg_origin]
filtered_texts = [
rec.text for rec in history_records if rec.msg_id != current_msg_id
]
chats_str = "\n---\n".join(filtered_texts)

cfg = self.cfg(event)
if cfg["enable_active_reply"]:
Expand All @@ -180,7 +220,11 @@ async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
logger.debug(
f"Recorded AI response: {event.unified_msg_origin} | {final_message}"
)
self.session_chats[event.unified_msg_origin].append(final_message)

ai_msg_id = f"ai:{uuid.uuid4().hex}"
record = ChatRecord(msg_id=ai_msg_id, role="assistant", text=final_message)
self.session_chats[event.unified_msg_origin].append(record)

cfg = self.cfg(event)
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
self.session_chats[event.unified_msg_origin].pop(0)