Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions agent-langgraph-long-term-memory/agent_server/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_user_workspace_client,
process_agent_astream_events,
)
from agent_server.utils_agent_memory import agent_memory_tools, read_agent_instructions
from agent_server.utils_memory import (
get_lakebase_access_error_message,
get_user_id,
Expand Down Expand Up @@ -56,6 +57,7 @@ def get_current_time() -> str:
_LAKEBASE_INSTANCE_NAME_RAW = os.getenv("LAKEBASE_INSTANCE_NAME") or None
EMBEDDING_ENDPOINT = "databricks-gte-large-en"
EMBEDDING_DIMS = 1024
UC_VOLUME = os.getenv("UC_VOLUME", "")
LAKEBASE_AUTOSCALING_PROJECT = os.getenv("LAKEBASE_AUTOSCALING_PROJECT") or None
LAKEBASE_AUTOSCALING_BRANCH = os.getenv("LAKEBASE_AUTOSCALING_BRANCH") or None

Expand Down Expand Up @@ -100,7 +102,11 @@ def get_current_time() -> str:
- Trivial or one-off details (e.g., what they ate for lunch, a single troubleshooting step)
- Highly sensitive personal information (health conditions, political affiliation, sexual orientation, \
religion, criminal history) — unless the user explicitly asks you to store it
- Information that could feel intrusive or overly personal to store"""
- Information that could feel intrusive or overly personal to store

## Agent Memory (shared across all users)
- Use save_agent_instruction to save learnings that apply to ALL users: team preferences, process rules, best practices
- Use get_agent_instructions to read the current shared instructions"""


def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient:
Expand All @@ -116,8 +122,10 @@ def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerM
)


async def init_agent(store: BaseStore, workspace_client: Optional[WorkspaceClient] = None):
async def init_agent(store: BaseStore, workspace_client: Optional[WorkspaceClient] = None, system_prompt: str = SYSTEM_PROMPT):
tools = [get_current_time] + memory_tools()
if UC_VOLUME:
tools += agent_memory_tools(workspace_client or sp_workspace_client, UC_VOLUME)
# To use MCP server tools instead, replace the line above with:
# mcp_client = init_mcp_client(workspace_client or sp_workspace_client)
# try:
Expand All @@ -128,7 +136,7 @@ async def init_agent(store: BaseStore, workspace_client: Optional[WorkspaceClien
return create_agent(
model=ChatDatabricks(endpoint=LLM_ENDPOINT_NAME),
tools=tools,
system_prompt=SYSTEM_PROMPT,
system_prompt=system_prompt,
store=store,
)

Expand Down Expand Up @@ -175,9 +183,16 @@ async def stream_handler(
if user_id:
config["configurable"]["user_id"] = user_id

# Inject agent-scoped instructions from UC Volume into system prompt
full_prompt = SYSTEM_PROMPT
if UC_VOLUME:
instructions = read_agent_instructions(sp_workspace_client, UC_VOLUME)
if instructions.strip():
full_prompt += f"\n\n## Current Agent Instructions\n{instructions}"

# By default, uses service principal credentials (sp_workspace_client).
# For on-behalf-of user authentication, use get_user_workspace_client() instead.
agent = await init_agent(workspace_client=sp_workspace_client, store=store)
agent = await init_agent(workspace_client=sp_workspace_client, store=store, system_prompt=full_prompt)
async for event in process_agent_astream_events(
agent.astream(messages, config, stream_mode=["updates", "messages"])
):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from io import BytesIO

from databricks.sdk import WorkspaceClient
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool

logger = logging.getLogger(__name__)

MAX_INSTRUCTION_LINES = 50


def _volume_base_path(volume: str) -> str:
return f"/Volumes/{volume.replace('.', '/')}"


def read_agent_instructions(w: WorkspaceClient, volume: str) -> str:
    """Read instructions.md from a UC Volume. Returns empty string if not found.

    Args:
        w: Workspace client used for the Files API download.
        volume: Dotted UC volume name, e.g. "catalog.schema.volume".

    Returns:
        The file contents decoded as UTF-8, or "" when the file is missing or
        unreadable — callers treat "no instructions" and "failed read" alike.
    """
    path = f"{_volume_base_path(volume)}/instructions.md"
    try:
        resp = w.files.download(path)
        return resp.contents.read().decode("utf-8")
    except Exception:
        # Best-effort read: a missing file is expected on first run, but other
        # failures (permissions, network) were previously swallowed silently —
        # log them so they are diagnosable without changing caller behavior.
        logger.debug("Could not read agent instructions from %s", path, exc_info=True)
        return ""


def agent_memory_tools(workspace_client: WorkspaceClient, volume: str) -> list:
    """Build LangChain tools for agent-scoped (cross-user) memory stored as
    instructions.md in a UC Volume.

    Args:
        workspace_client: Client used for all Files API reads and writes.
        volume: Dotted UC volume name ("catalog.schema.volume").

    Returns:
        A list of two @tool callables: save_agent_instruction and
        get_agent_instructions. Their docstrings are the tool descriptions
        shown to the LLM, so they are left exactly as written.
    """
    @tool
    def save_agent_instruction(instruction: str, config: RunnableConfig) -> str:
        """Save a new instruction to the shared agent memory. Use for learnings that
        apply to ALL users: team preferences, process rules, best practices."""
        # NOTE(review): unlocked read-modify-write — two concurrent saves can
        # drop each other's instruction. Confirm low write volume makes this OK.
        current = read_agent_instructions(workspace_client, volume)
        # Keep only non-blank lines; a missing/empty file yields an empty list.
        lines = [l for l in current.strip().split("\n") if l.strip()] if current.strip() else []
        # Cap counts only "- " bullet lines, so headers or stray text in the
        # file do not consume the instruction budget.
        if sum(1 for l in lines if l.startswith("- ")) >= MAX_INSTRUCTION_LINES:
            return f"Cannot save — already at {MAX_INSTRUCTION_LINES} instructions."
        lines.append(f"- {instruction}")
        path = f"{_volume_base_path(volume)}/instructions.md"
        # Rewrite the whole file (overwrite=True) with a trailing newline.
        workspace_client.files.upload(path, BytesIO(("\n".join(lines) + "\n").encode("utf-8")), overwrite=True)
        return f"Saved agent instruction: {instruction}"

    @tool
    def get_agent_instructions(config: RunnableConfig) -> str:
        """Read the current shared agent instructions."""
        content = read_agent_instructions(workspace_client, volume)
        return content if content.strip() else "No agent instructions saved yet."

    return [save_agent_instruction, get_agent_instructions]
88 changes: 51 additions & 37 deletions agent-openai-agents-sdk/agent_server/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import litellm
import mlflow
from agents import Agent, Runner, function_tool, set_default_openai_api, set_default_openai_client
from agents import Agent, ModelSettings, Runner, function_tool, set_default_openai_api, set_default_openai_client
from agents.tracing import set_trace_processors
from databricks.sdk import WorkspaceClient
from databricks_openai import AsyncDatabricksOpenAI
Expand Down Expand Up @@ -40,41 +40,62 @@ def get_current_time() -> str:
return datetime.now().isoformat()


async def init_mcp_server(workspace_client: WorkspaceClient):
return McpServer(
url=build_mcp_url("/api/2.0/mcp/functions/system/ai", workspace_client=workspace_client),
name="system.ai UC function MCP server",
workspace_client=workspace_client,
MEMORY_MCP_HOST = "https://eng-ml-agent-platform.staging.cloud.databricks.com"
memory_ws_client = WorkspaceClient(host=MEMORY_MCP_HOST, profile="agent-platform")
MEMORY_STORE = "test-embed"

MEMORY_SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. You proactively remember things about users.

Always use memory_store="{MEMORY_STORE}" for all memory operations.

## Before every response
1. Call search_memory scope="agent", query="response preferences and procedures" to load shared instructions.
2. Call search_memory scope="user" to check for personal context about the current user.

## Saving memories
Proactively save anything the user shares about themselves (location, role, preferences, interests, etc.) using write_memory. Use scope="user" for personal facts, scope="agent" for shared rules that apply to all users.

## Conversation history
Refer to the current chat history for questions about this session. Only search memory for info from previous sessions."""


async def init_mcp_servers():
    """Construct the memory and GitHub MCP servers.

    Returns:
        A (memory, github) pair of McpServer instances; callers are expected
        to enter them as async context managers before use.
    """
    # NOTE(review): MEMORY_MCP_HOST points at an internal staging workspace and
    # the traffic-id header below ("testenv://liteswap/jennymemorysa") looks
    # like personal test routing — confirm this is not committed test residue.
    memory = McpServer(
        url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/sql",
        name="memory-mcp",
        workspace_client=memory_ws_client,
        params={
            "headers": {"x-databricks-traffic-id": "testenv://liteswap/jennymemorysa"},
        },
    )
    github = McpServer(
        url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/external/github_demo",
        name="github-mcp",
        workspace_client=memory_ws_client,
    )
    return memory, github


def create_agent(mcp_servers: list[McpServer] | None = None) -> Agent:
return Agent(
name="Agent",
instructions="You are a helpful assistant.",
name="Code review agent",
instructions=MEMORY_SYSTEM_PROMPT,
model="databricks-gpt-5-2",
tools=[get_current_time],
mcp_servers=mcp_servers or [],
model_settings=ModelSettings(parallel_tool_calls=False),
)


@invoke()
async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse:
if session_id := get_session_id(request):
mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
# To use MCP server tools, wrap the code below with this async context manager.
# By default, uses service principal credentials via WorkspaceClient().
# For on-behalf-of user authentication, use get_user_workspace_client() instead.
# try:
# async with await init_mcp_server(WorkspaceClient()) as mcp_server:
# agent = create_agent(mcp_servers=[mcp_server])
# except Exception:
# logger.warning("MCP server unavailable. Continuing without MCP tools.", exc_info=True)
# agent = create_agent()
agent = create_agent()
messages = [i.model_dump() for i in request.input]
result = await Runner.run(agent, messages)
return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items])
memory_srv, github_srv = await init_mcp_servers()
async with memory_srv as mem, github_srv as gh:
agent = create_agent(mcp_servers=[mem, gh])
messages = [i.model_dump() for i in request.input]
result = await Runner.run(agent, messages)
return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items])


@stream()
Expand All @@ -83,18 +104,11 @@ async def stream_handler(
) -> AsyncGenerator[ResponsesAgentStreamEvent, None]:
if session_id := get_session_id(request):
mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
# To use MCP server tools, wrap the code below with this async context manager.
# By default, uses service principal credentials via WorkspaceClient().
# For on-behalf-of user authentication, use get_user_workspace_client() instead.
# try:
# async with await init_mcp_server(WorkspaceClient()) as mcp_server:
# agent = create_agent(mcp_servers=[mcp_server])
# except Exception:
# logger.warning("MCP server unavailable. Continuing without MCP tools.", exc_info=True)
# agent = create_agent()
agent = create_agent()
messages = [i.model_dump() for i in request.input]
result = Runner.run_streamed(agent, input=messages)

async for event in process_agent_stream_events(result.stream_events()):
yield event
memory_srv, github_srv = await init_mcp_servers()
async with memory_srv as mem, github_srv as gh:
agent = create_agent(mcp_servers=[mem, gh])
messages = [i.model_dump() for i in request.input]
result = Runner.run_streamed(agent, input=messages)

async for event in process_agent_stream_events(result.stream_events()):
yield event