Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/google/adk/utils/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Callable
from typing import get_args
from typing import get_origin
from typing import get_type_hints
from typing import Union

# Re-export aclosing for backward compatibility
Expand Down Expand Up @@ -80,7 +81,12 @@ def find_context_parameter(func: Callable[..., Any]) -> str | None:
signature = inspect.signature(func)
except (ValueError, TypeError):
return None
try:
hints = get_type_hints(func)
except (NameError, TypeError, AttributeError):
hints = {}
for name, param in signature.parameters.items():
if _is_context_type(param.annotation):
annotation = hints.get(name, param.annotation)
if _is_context_type(annotation):
return name
return None
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

from typing import Any
from typing import Dict
from typing import Optional

from google.adk.tools import _automatic_function_calling_util
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.adk.utils.context_utils import find_context_parameter
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types

Expand Down Expand Up @@ -177,3 +181,41 @@ def test_function() -> str:
# VERTEX_AI should have response schema for string return (stored as string)
assert declaration.response is not None
assert declaration.response.type == types.Type.STRING


def test_find_context_parameter_detects_custom_name_with_future_annotations():
"""Test custom param name detection with deferred annotations."""

# pylint: disable-next=unused-argument
def my_tool(ctx: ToolContext) -> str:
"""A tool with a custom-named context parameter."""
return ''

assert find_context_parameter(my_tool) == 'ctx'


def test_find_context_parameter_optional_ctx_with_future_annotations():
"""Test Optional[ToolContext] detection with deferred annotations."""

# pylint: disable-next=unused-argument
def my_tool(ctx: Optional[ToolContext] = None) -> str:
"""A tool with an optional context parameter."""
return ''

assert find_context_parameter(my_tool) == 'ctx'


def test_function_tool_excludes_custom_context_param_with_future_annotations():
"""Test schema omits context param with deferred annotations."""

# pylint: disable-next=unused-argument
def my_tool(query: str, ctx: ToolContext) -> str:
"""A tool with a custom-named context parameter."""
return query

tool = FunctionTool(my_tool)
declaration = tool._get_declaration() # pylint: disable=protected-access

assert declaration.name == 'my_tool'
assert set(declaration.parameters.properties.keys()) == {'query'}
assert 'ctx' not in declaration.parameters.properties