Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_get_uri_data_invalid_uri() -> None:
def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
"""Test A2A parts to contents conversion."""

agent = A2AAgent(name="Test Agent", client=MockA2AClient(), _http_client=None)
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)

# Create A2A parts
parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))]
Expand Down Expand Up @@ -475,7 +475,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:

mock_a2a_client = MagicMock()

agent = A2AAgent(client=mock_a2a_client, _http_client=None)
agent = A2AAgent(client=mock_a2a_client, http_client=None)

# This should not raise any errors
async with agent:
Expand All @@ -485,7 +485,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:
def test_prepare_message_for_a2a_with_multiple_contents() -> None:
"""Test conversion of Message with multiple contents."""

agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

# Create message with multiple content types
message = Message(
Expand Down Expand Up @@ -513,7 +513,7 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
def test_prepare_message_for_a2a_forwards_context_id() -> None:
"""Test conversion of Message preserves context_id without duplicating it in metadata."""

agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

message = Message(
role="user",
Expand All @@ -530,7 +530,7 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
def test_parse_contents_from_a2a_with_data_part() -> None:
"""Test conversion of A2A DataPart."""

agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

# Create DataPart
data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"}))
Expand All @@ -546,7 +546,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None:

def test_parse_contents_from_a2a_unknown_part_kind() -> None:
"""Test error handling for unknown A2A part kind."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

# Create a mock part with unknown kind
mock_part = MagicMock()
Expand All @@ -559,7 +559,7 @@ def test_parse_contents_from_a2a_unknown_part_kind() -> None:
def test_prepare_message_for_a2a_with_hosted_file() -> None:
"""Test conversion of Message with HostedFileContent to A2A message."""

agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

# Create message with hosted file content
message = Message(
Expand All @@ -585,7 +585,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
"""Test conversion of A2A FilePart with hosted file URI back to UriContent."""

agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)

# Create FilePart with hosted file URI (simulating what A2A would send back)
file_part = Part(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,16 +436,17 @@ def _to_chat_agent_from_agent(

# Merge tools: convert agent's hosted tools + user-provided function tools
merged_tools = self._merge_tools(agent.tools, provided_tools)
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
merged_default_options.setdefault("model_id", agent.model)

return Agent( # type: ignore[return-value]
client=client,
id=agent.id,
name=agent.name,
description=agent.description,
instructions=agent.instructions,
model_id=agent.model,
tools=merged_tools,
default_options=default_options, # type: ignore[arg-type]
default_options=cast(Any, merged_default_options),
middleware=middleware,
context_providers=context_providers,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import sys
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, Generic
from typing import Any, Generic, cast

from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
Expand Down Expand Up @@ -396,16 +396,17 @@ def _to_chat_agent_from_details(
# from_azure_ai_tools converts hosted tools (MCP, code interpreter, file search, web search)
# but function tools need the actual implementations from provided_tools
merged_tools = self._merge_tools(details.definition.tools, provided_tools)
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
merged_default_options.setdefault("model_id", details.definition.model)

return Agent( # type: ignore[return-value]
client=client,
id=details.id,
name=details.name,
description=details.description,
instructions=details.definition.instructions,
model_id=details.definition.model,
tools=merged_tools,
default_options=default_options, # type: ignore[arg-type]
default_options=cast(Any, merged_default_options),
middleware=middleware,
context_providers=context_providers,
)
Expand Down
74 changes: 71 additions & 3 deletions python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload

from agent_framework import (
AgentMiddlewareTypes,
Expand Down Expand Up @@ -584,7 +584,7 @@ def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentRes
return AgentResponse.from_updates(updates, value=structured_output)

@overload
def run(
def run( # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
Expand All @@ -595,7 +595,7 @@ def run(
) -> Awaitable[AgentResponse[Any]]: ...

@overload
def run(
def run( # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
Expand Down Expand Up @@ -747,3 +747,71 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options
response = await agent.run("Hello!")
print(response.text)
"""

@overload # type: ignore[override]
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...

@overload # type: ignore[override]
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def run( # pyright: ignore[reportIncompatibleMethodOverride] # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the Claude agent with telemetry enabled."""
super_run = cast(
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
super().run,
)
return super_run(
messages=messages,
stream=stream,
session=session,
middleware=middleware,
options=options,
tools=tools,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
Loading
Loading