Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Role,
TextContent,
UriContent,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
from agent_framework.observability import use_agent_instrumentation
Expand Down Expand Up @@ -236,7 +237,7 @@ async def run_stream(
Yields:
An agent response item.
"""
messages = self._normalize_messages(messages)
messages = normalize_messages(messages)
a2a_message = self._prepare_message_for_a2a(messages[-1])

response_stream = self.client.send_message(a2a_message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ContextProvider,
Role,
TextContent,
normalize_messages,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework.exceptions import ServiceException, ServiceInitializationError
Expand Down Expand Up @@ -237,7 +238,7 @@ async def run(
thread = self.get_new_thread()
thread.service_thread_id = await self._start_new_conversation()

input_messages = self._normalize_messages(messages)
input_messages = normalize_messages(messages)

question = "\n".join([message.text for message in input_messages])

Expand Down Expand Up @@ -278,7 +279,7 @@ async def run_stream(
thread = self.get_new_thread()
thread.service_thread_id = await self._start_new_conversation()

input_messages = self._normalize_messages(messages)
input_messages = normalize_messages(messages)

question = "\n".join([message.text for message in input_messages])

Expand Down
21 changes: 3 additions & 18 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
ChatMessage,
ChatResponse,
ChatResponseUpdate,
Role,
normalize_messages,
)
from .exceptions import AgentExecutionException, AgentInitializationError
from .observability import use_agent_instrumentation
Expand Down Expand Up @@ -498,21 +498,6 @@ async def agent_wrapper(**kwargs: Any) -> str:
agent_tool._forward_runtime_kwargs = True # type: ignore
return agent_tool

def _normalize_messages(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
) -> list[ChatMessage]:
if messages is None:
return []

if isinstance(messages, str):
return [ChatMessage(role=Role.USER, text=messages)]

if isinstance(messages, ChatMessage):
return [messages]

return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages]


# region ChatAgent

Expand Down Expand Up @@ -797,7 +782,7 @@ async def run(
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)

input_messages = self._normalize_messages(messages)
input_messages = normalize_messages(messages)
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
thread=thread, input_messages=input_messages, **kwargs
)
Expand Down Expand Up @@ -925,7 +910,7 @@ async def run_stream(
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)

input_messages = self._normalize_messages(messages)
input_messages = normalize_messages(messages)
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
thread=thread, input_messages=input_messages, **kwargs
)
Expand Down
6 changes: 3 additions & 3 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar

from ._serialization import SerializationMixin
from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, prepare_messages
from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, normalize_messages, prepare_messages
from .exceptions import MiddlewareException

if TYPE_CHECKING:
Expand Down Expand Up @@ -1225,7 +1225,7 @@ async def middleware_enabled_run(
if chat_middlewares:
kwargs["middleware"] = chat_middlewares

normalized_messages = self._normalize_messages(messages)
normalized_messages = normalize_messages(messages)

# Execute with middleware if available
if agent_pipeline.has_middlewares:
Expand Down Expand Up @@ -1273,7 +1273,7 @@ def middleware_enabled_run_stream(
if chat_middlewares:
kwargs["middleware"] = chat_middlewares

normalized_messages = self._normalize_messages(messages)
normalized_messages = normalize_messages(messages)

# Execute with middleware if available
if agent_pipeline.has_middlewares:
Expand Down
17 changes: 17 additions & 0 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"UsageContent",
"UsageDetails",
"merge_chat_options",
"normalize_messages",
"normalize_tools",
"prepare_function_call_results",
"prepend_instructions_to_messages",
Expand Down Expand Up @@ -2495,6 +2496,22 @@ def prepare_messages(
return return_messages


def normalize_messages(
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
) -> list[ChatMessage]:
"""Normalize message inputs to a list of ChatMessage objects."""
if messages is None:
return []

if isinstance(messages, str):
return [ChatMessage(role=Role.USER, text=messages)]

if isinstance(messages, ChatMessage):
return [messages]

return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages]


def prepend_instructions_to_messages(
messages: list[ChatMessage],
instructions: str | Sequence[str] | None,
Expand Down
56 changes: 56 additions & 0 deletions python/packages/core/tests/core/test_middleware_with_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,3 +1902,59 @@ async def kwargs_middleware(
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there


class TestMiddlewareWithProtocolOnlyAgent:
"""Test use_agent_middleware with agents implementing only AgentProtocol."""

async def test_middleware_with_protocol_only_agent(self) -> None:
"""Verify middleware works without BaseAgent inheritance for both run and run_stream."""
from collections.abc import AsyncIterable

from agent_framework import AgentProtocol, AgentResponse, AgentResponseUpdate, use_agent_middleware

execution_order: list[str] = []

class TrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("before")
await next(context)
execution_order.append("after")

@use_agent_middleware
class ProtocolOnlyAgent:
"""Minimal agent implementing only AgentProtocol, not inheriting from BaseAgent."""

def __init__(self):
self.id = "protocol-only-agent"
self.name = "Protocol Only Agent"
self.description = "Test agent"
self.middleware = [TrackingMiddleware()]

async def run(self, messages=None, *, thread=None, **kwargs) -> AgentResponse:
return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

def run_stream(self, messages=None, *, thread=None, **kwargs) -> AsyncIterable[AgentResponseUpdate]:
async def _stream():
yield AgentResponseUpdate()

return _stream()

def get_new_thread(self, **kwargs):
return None

agent = ProtocolOnlyAgent()
assert isinstance(agent, AgentProtocol)

# Test run (non-streaming)
response = await agent.run("test message")
assert response is not None
assert execution_order == ["before", "after"]

# Test run_stream (streaming)
execution_order.clear()
async for _ in agent.run_stream("test message"):
pass
assert execution_order == ["before", "after"]
Loading