Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/packages/core/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ agent_framework/
- **`AgentMiddleware`** - Intercepts agent `run()` calls
- **`ChatMiddleware`** - Intercepts chat client `get_response()` calls
- **`FunctionMiddleware`** - Intercepts function/tool invocations
- **`AgentRunContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware
- **`AgentContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware

### Threads (`_threads.py`)

Expand Down Expand Up @@ -114,10 +114,10 @@ agent = OpenAIChatClient().as_agent(
### Middleware Pipeline

```python
from agent_framework import ChatAgent, AgentMiddleware, AgentRunContext
from agent_framework import ChatAgent, AgentMiddleware, AgentContext

class LoggingMiddleware(AgentMiddleware):
async def invoke(self, context: AgentRunContext, next) -> AgentResponse:
async def process(self, context: AgentContext, next) -> AgentResponse:
print(f"Input: {context.messages}")
response = await next(context)
print(f"Output: {response}")
Expand Down
44 changes: 22 additions & 22 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)

__all__ = [
"AgentContext",
"AgentMiddleware",
"AgentMiddlewareLayer",
"AgentMiddlewareTypes",
"AgentRunContext",
"ChatAndFunctionMiddlewareTypes",
"ChatContext",
"ChatMiddleware",
Expand Down Expand Up @@ -109,7 +109,7 @@ class MiddlewareType(str, Enum):
CHAT = "chat"


class AgentRunContext:
class AgentContext:
"""Context object for agent middleware invocations.

This context is passed through the agent middleware pipeline and contains all information
Expand All @@ -131,11 +131,11 @@ class AgentRunContext:
Examples:
.. code-block:: python

from agent_framework import AgentMiddleware, AgentRunContext
from agent_framework import AgentMiddleware, AgentContext


class LoggingMiddleware(AgentMiddleware):
async def process(self, context: AgentRunContext, next):
async def process(self, context: AgentContext, next):
print(f"Agent: {context.agent.name}")
print(f"Messages: {len(context.messages)}")
print(f"Thread: {context.thread}")
Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(
| None = None,
stream_cleanup_hooks: Sequence[Callable[[], Awaitable[None] | None]] | None = None,
) -> None:
"""Initialize the AgentRunContext.
"""Initialize the AgentContext.

Args:
agent: The agent being invoked.
Expand Down Expand Up @@ -356,14 +356,14 @@ class AgentMiddleware(ABC):
Examples:
.. code-block:: python

from agent_framework import AgentMiddleware, AgentRunContext, ChatAgent
from agent_framework import AgentMiddleware, AgentContext, ChatAgent


class RetryMiddleware(AgentMiddleware):
def __init__(self, max_retries: int = 3):
self.max_retries = max_retries

async def process(self, context: AgentRunContext, next):
async def process(self, context: AgentContext, next):
for attempt in range(self.max_retries):
await next(context)
if context.result and not context.result.is_error:
Expand All @@ -378,8 +378,8 @@ async def process(self, context: AgentRunContext, next):
@abstractmethod
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]],
) -> None:
"""Process an agent invocation.

Expand Down Expand Up @@ -531,7 +531,7 @@ async def process(


# Pure function type definitions for convenience
AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable

FunctionMiddlewareCallable = Callable[
Expand Down Expand Up @@ -561,7 +561,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
"""Decorator to mark a function as agent middleware.

This decorator explicitly identifies a function as agent middleware,
which processes AgentRunContext objects.
which processes AgentContext objects.

Args:
func: The middleware function to mark as agent middleware.
Expand All @@ -572,11 +572,11 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
Examples:
.. code-block:: python

from agent_framework import agent_middleware, AgentRunContext, ChatAgent
from agent_framework import agent_middleware, AgentContext, ChatAgent


@agent_middleware
async def logging_middleware(context: AgentRunContext, next):
async def logging_middleware(context: AgentContext, next):
print(f"Before: {context.agent.name}")
await next(context)
print(f"After: {context.result}")
Expand Down Expand Up @@ -752,9 +752,9 @@ def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None:

async def execute(
self,
context: AgentRunContext,
context: AgentContext,
final_handler: Callable[
[AgentRunContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]
[AgentContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]
],
) -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None:
"""Execute the agent middleware pipeline for streaming or non-streaming.
Expand All @@ -772,17 +772,17 @@ async def execute(
context.result = await context.result
return context.result

def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]:
if index >= len(self._middleware):

async def final_wrapper(c: AgentRunContext) -> None:
async def final_wrapper(c: AgentContext) -> None:
c.result = final_handler(c) # type: ignore[assignment]
if inspect.isawaitable(c.result):
c.result = await c.result

return final_wrapper

async def current_handler(c: AgentRunContext) -> None:
async def current_handler(c: AgentContext) -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))

Expand Down Expand Up @@ -1161,7 +1161,7 @@ def run(
if not pipeline.has_middlewares:
return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return]

context = AgentRunContext(
context = AgentContext(
agent=self, # type: ignore[arg-type]
messages=prepare_messages(messages), # type: ignore[arg-type]
thread=thread,
Expand Down Expand Up @@ -1194,7 +1194,7 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse
return _execute() # type: ignore[return-value]

def _middleware_handler(
self, context: AgentRunContext
self, context: AgentContext
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
return super().run( # type: ignore[misc, no-any-return]
context.messages,
Expand Down Expand Up @@ -1231,7 +1231,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType:
first_param = params[0]
if hasattr(first_param.annotation, "__name__"):
annotation_name = first_param.annotation.__name__
if annotation_name == "AgentRunContext":
if annotation_name == "AgentContext":
param_type = MiddlewareType.AGENT
elif annotation_name == "FunctionInvocationContext":
param_type = MiddlewareType.FUNCTION
Expand Down Expand Up @@ -1270,7 +1270,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType:
raise MiddlewareException(
f"Cannot determine middleware type for function {middleware.__name__}. "
f"Please either use @agent_middleware/@function_middleware/@chat_middleware decorators "
f"or specify parameter types (AgentRunContext, FunctionInvocationContext, or ChatContext)."
f"or specify parameter types (AgentContext, FunctionInvocationContext, or ChatContext)."
)


Expand Down
10 changes: 5 additions & 5 deletions python/packages/core/agent_framework/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,12 @@ async def get_current_weather(location: Annotated[str, "The city name"]) -> str:

.. code-block:: python

from agent_framework._middleware import AgentRunContext
from agent_framework._middleware import AgentContext
from agent_framework import BaseAgent

# AgentRunContext has INJECTABLE = {"agent", "result"}
# AgentContext has INJECTABLE = {"agent", "result"}
context_data = {
"type": "agent_run_context",
"type": "agent_context",
"messages": [{"role": "user", "text": "Hello"}],
"stream": False,
"metadata": {"session_id": "abc123"},
Expand All @@ -492,14 +492,14 @@ async def get_current_weather(location: Annotated[str, "The city name"]) -> str:
# Inject agent and result during middleware processing
my_agent = BaseAgent(name="test-agent")
dependencies = {
"agent_run_context": {
"agent_context": {
"agent": my_agent,
"result": None, # Will be populated during execution
}
}

# Reconstruct context with agent dependency for middleware chain
context = AgentRunContext.from_dict(context_data, dependencies=dependencies)
context = AgentContext.from_dict(context_data, dependencies=dependencies)
# MiddlewareTypes can now access context.agent and process the execution

This injection system allows the agent framework to maintain clean separation
Expand Down
30 changes: 8 additions & 22 deletions python/packages/core/tests/core/test_as_tool_kwargs_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any

from agent_framework import ChatAgent, ChatMessage, ChatResponse, Content, agent_middleware
from agent_framework._middleware import AgentRunContext
from agent_framework._middleware import AgentContext

from .conftest import MockChatClient

Expand All @@ -19,9 +19,7 @@ async def test_as_tool_forwards_runtime_kwargs(self, chat_client: MockChatClient
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
# Capture kwargs passed to the sub-agent
captured_kwargs.update(context.kwargs)
await next(context)
Expand Down Expand Up @@ -62,9 +60,7 @@ async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, chat_client
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await next(context)

Expand Down Expand Up @@ -99,9 +95,7 @@ async def test_as_tool_nested_delegation_propagates_kwargs(self, chat_client: Mo
captured_kwargs_list: list[dict[str, Any]] = []

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
# Capture kwargs at each level
captured_kwargs_list.append(dict(context.kwargs))
await next(context)
Expand Down Expand Up @@ -162,9 +156,7 @@ async def test_as_tool_streaming_mode_forwards_kwargs(self, chat_client: MockCha
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await next(context)

Expand Down Expand Up @@ -224,9 +216,7 @@ async def test_as_tool_kwargs_with_chat_options(self, chat_client: MockChatClien
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await next(context)

Expand Down Expand Up @@ -266,9 +256,7 @@ async def test_as_tool_kwargs_isolated_per_invocation(self, chat_client: MockCha
call_count = 0

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
Expand Down Expand Up @@ -318,9 +306,7 @@ async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, chat
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await next(context)

Expand Down
Loading
Loading