generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 647
feat(hooks): Add hook decorator #1581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mkmeral
wants to merge
12
commits into
strands-agents:main
Choose a base branch
from
mkmeral:fix/hook-decorator-review-fixes
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+840
−9
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
3eb4f65
feat(hooks): add @hook decorator for simplified hook definitions
strands-agent 80b4b04
fix(hooks): fix mypy type errors for hook decorator
strands-agent ca704f3
feat(hooks): add automatic agent injection to @hook decorator
strands-agent ca56c91
test(hooks): add comprehensive tests for @hook decorator coverage
strands-agent 70689a0
fix(hooks): address review comments for @hook decorator
9013876
docs: simplify docstrings, remove implementation details
a07be6a
refactor: remove unused logger, simplify class docstring
733156f
refactor: remove agent injection, simplify @hook decorator
b576d43
docs: update PR description
20643c7
Merge branch 'main' into fix/hook-decorator-review-fixes
mkmeral 89ed654
chore: delete description
mkmeral 3de3688
refactor(hooks): simplify @hook to type-hints only, fix public API
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,289 @@ | ||
| """Hook decorator for defining hooks as functions. | ||
|
|
||
| This module provides the @hook decorator that transforms Python functions into | ||
| HookProvider implementations with automatic event type detection from type hints. | ||
|
|
||
| Example: | ||
| ```python | ||
| from strands import Agent, hook | ||
| from strands.hooks import BeforeToolCallEvent | ||
|
|
||
| @hook | ||
| def log_tool_calls(event: BeforeToolCallEvent) -> None: | ||
| '''Log all tool calls before execution.''' | ||
| print(f"Tool: {event.tool_use}") | ||
|
|
||
| agent = Agent(hooks=[log_tool_calls]) | ||
| ``` | ||
| """ | ||
|
|
||
| import functools | ||
| import inspect | ||
| import types | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from typing import ( | ||
| Any, | ||
| Generic, | ||
| TypeVar, | ||
| Union, | ||
| cast, | ||
| get_args, | ||
| get_origin, | ||
| get_type_hints, | ||
| ) | ||
|
|
||
| from .registry import BaseHookEvent, HookCallback, HookProvider, HookRegistry | ||
|
|
||
| TEvent = TypeVar("TEvent", bound=BaseHookEvent) | ||
|
|
||
|
|
||
| @dataclass | ||
| class HookMetadata: | ||
| """Metadata extracted from a decorated hook function. | ||
|
|
||
| Attributes: | ||
| name: The name of the hook function. | ||
| description: Description extracted from the function's docstring. | ||
| event_types: List of event types this hook handles. | ||
| is_async: Whether the hook function is async. | ||
| """ | ||
|
|
||
| name: str | ||
| description: str | ||
| event_types: list[type[BaseHookEvent]] | ||
| is_async: bool | ||
|
|
||
|
|
||
| class FunctionHookMetadata: | ||
| """Helper class to extract and manage function metadata for hook decoration.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| func: Callable[..., Any], | ||
| ) -> None: | ||
| """Initialize with the function to process. | ||
|
|
||
| Args: | ||
| func: The function to extract metadata from. | ||
| """ | ||
| self.func = func | ||
| self.signature = inspect.signature(func) | ||
|
|
||
| # Validate and extract event types | ||
| self._event_types = self._resolve_event_types() | ||
| self._validate_event_types() | ||
|
|
||
| def _resolve_event_types(self) -> list[type[BaseHookEvent]]: | ||
| """Resolve event types from type hints. | ||
|
|
||
| Returns: | ||
| List of event types this hook handles. | ||
|
|
||
| Raises: | ||
| ValueError: If no event type can be determined. | ||
| """ | ||
| # Try to extract from type hints | ||
| try: | ||
| type_hints = get_type_hints(self.func) | ||
| except Exception: | ||
| # get_type_hints can fail for various reasons (forward refs, etc.) | ||
| type_hints = {} | ||
|
|
||
| # Find the first parameter's type hint (should be the event) | ||
| # Skip 'self' and 'cls' for class methods | ||
| params = list(self.signature.parameters.values()) | ||
| event_params = [p for p in params if p.name not in ("self", "cls")] | ||
|
|
||
| if not event_params: | ||
| raise ValueError( | ||
| f"Hook function '{self.func.__name__}' must have at least one parameter " | ||
| "for the event with a type hint." | ||
| ) | ||
|
|
||
| first_param = event_params[0] | ||
| event_type = type_hints.get(first_param.name) | ||
|
|
||
| if event_type is None: | ||
| # Check annotation directly (for cases where get_type_hints fails) | ||
| if first_param.annotation is not inspect.Parameter.empty: | ||
| event_type = first_param.annotation | ||
| else: | ||
| raise ValueError( | ||
| f"Hook function '{self.func.__name__}' must have a type hint for the event parameter." | ||
| ) | ||
|
|
||
| # Handle Union types (e.g., BeforeToolCallEvent | AfterToolCallEvent) | ||
| return self._extract_event_types_from_annotation(event_type) | ||
|
|
||
| def _is_union_type(self, annotation: Any) -> bool: | ||
| """Check if annotation is a Union type (typing.Union or types.UnionType).""" | ||
| origin = get_origin(annotation) | ||
| if origin is Union: | ||
| return True | ||
|
|
||
| # Python 3.10+ uses types.UnionType for `A | B` syntax | ||
| if isinstance(annotation, types.UnionType): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
| def _extract_event_types_from_annotation(self, annotation: Any) -> list[type[BaseHookEvent]]: | ||
| """Extract event types from a type annotation.""" | ||
| # Handle Union types (Union[A, B] or A | B) | ||
| if self._is_union_type(annotation): | ||
| args = get_args(annotation) | ||
| event_types = [] | ||
| for arg in args: | ||
| # Skip NoneType in Optional[X] | ||
| if arg is type(None): | ||
| continue | ||
| if isinstance(arg, type) and issubclass(arg, BaseHookEvent): | ||
| event_types.append(arg) | ||
| else: | ||
| raise ValueError(f"All types in Union must be subclasses of BaseHookEvent, got {arg}") | ||
| return event_types | ||
|
|
||
| # Single type | ||
| if isinstance(annotation, type) and issubclass(annotation, BaseHookEvent): | ||
| return [annotation] | ||
|
|
||
| raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {annotation}") | ||
|
|
||
| def _validate_event_types(self) -> None: | ||
| """Validate that all event types are valid.""" | ||
| if not self._event_types: | ||
| raise ValueError(f"Hook function '{self.func.__name__}' must handle at least one event type.") | ||
|
|
||
| for event_type in self._event_types: | ||
| if not isinstance(event_type, type) or not issubclass(event_type, BaseHookEvent): | ||
| raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {event_type}") | ||
|
|
||
| def extract_metadata(self) -> HookMetadata: | ||
| """Extract metadata from the function to create hook specification.""" | ||
| return HookMetadata( | ||
| name=self.func.__name__, | ||
| description=inspect.getdoc(self.func) or self.func.__name__, | ||
| event_types=self._event_types, | ||
| is_async=inspect.iscoroutinefunction(self.func), | ||
| ) | ||
|
|
||
| @property | ||
| def event_types(self) -> list[type[BaseHookEvent]]: | ||
| """Get the event types this hook handles.""" | ||
| return self._event_types | ||
|
|
||
|
|
||
| class DecoratedFunctionHook(HookProvider, Generic[TEvent]): | ||
| """A HookProvider that wraps a function decorated with @hook.""" | ||
|
|
||
| _func: Callable[[TEvent], Any] | ||
| _metadata: FunctionHookMetadata | ||
| _hook_metadata: HookMetadata | ||
|
|
||
| def __init__( | ||
| self, | ||
| func: Callable[[TEvent], Any], | ||
| metadata: FunctionHookMetadata, | ||
| ): | ||
| """Initialize the decorated function hook. | ||
|
|
||
| Args: | ||
| func: The original function being decorated. | ||
| metadata: The FunctionHookMetadata object with extracted function information. | ||
| """ | ||
| self._func = func | ||
| self._metadata = metadata | ||
| self._hook_metadata = metadata.extract_metadata() | ||
|
|
||
| # Preserve function metadata | ||
| functools.update_wrapper(wrapper=self, wrapped=self._func) | ||
|
|
||
| def __get__(self, instance: Any, obj_type: type[Any] | None = None) -> "DecoratedFunctionHook[TEvent]": | ||
| """Descriptor protocol implementation for proper method binding.""" | ||
| if instance is not None and not inspect.ismethod(self._func): | ||
| # Create a bound method | ||
| bound_func = self._func.__get__(instance, instance.__class__) | ||
| return DecoratedFunctionHook(bound_func, self._metadata) | ||
|
|
||
| return self | ||
|
|
||
| def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: | ||
| """Register callback functions for specific event types.""" | ||
| callback = cast(HookCallback[BaseHookEvent], self._func) | ||
| for event_type in self._metadata.event_types: | ||
| registry.add_callback(event_type, callback) | ||
|
|
||
| def __call__(self, event: TEvent) -> Any: | ||
| """Allow direct invocation for testing.""" | ||
| return self._func(event) | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| """Get the name of the hook.""" | ||
| return self._hook_metadata.name | ||
|
|
||
| @property | ||
| def description(self) -> str: | ||
| """Get the description of the hook.""" | ||
| return self._hook_metadata.description | ||
|
|
||
| @property | ||
| def event_types(self) -> list[type[BaseHookEvent]]: | ||
| """Get the event types this hook handles.""" | ||
| return self._hook_metadata.event_types | ||
|
|
||
| @property | ||
| def is_async(self) -> bool: | ||
| """Check if this hook is async.""" | ||
| return self._hook_metadata.is_async | ||
|
|
||
| def __repr__(self) -> str: | ||
| """Return a string representation of the hook.""" | ||
| event_names = [e.__name__ for e in self._hook_metadata.event_types] | ||
| return f"DecoratedFunctionHook({self._hook_metadata.name}, events={event_names})" | ||
|
|
||
|
|
||
| # Type variable for the decorated function | ||
| F = TypeVar("F", bound=Callable[..., Any]) | ||
|
|
||
|
|
||
| def hook( | ||
| func: F | None = None, | ||
| ) -> DecoratedFunctionHook[Any] | Callable[[F], DecoratedFunctionHook[Any]]: | ||
| """Decorator that transforms a function into a HookProvider. | ||
|
|
||
| The decorated function can be passed directly to Agent(hooks=[...]). | ||
| Event types are automatically detected from the function's type hints. | ||
|
|
||
| Args: | ||
| func: The function to decorate. | ||
|
|
||
| Returns: | ||
| A DecoratedFunctionHook that implements HookProvider. | ||
|
|
||
| Raises: | ||
| ValueError: If no event type can be determined from type hints. | ||
| ValueError: If event types are not subclasses of BaseHookEvent. | ||
|
|
||
| Example: | ||
| ```python | ||
| from strands import Agent, hook | ||
| from strands.hooks import BeforeToolCallEvent | ||
|
|
||
| @hook | ||
| def log_tool_calls(event: BeforeToolCallEvent) -> None: | ||
| print(f"Tool: {event.tool_use}") | ||
|
|
||
| agent = Agent(hooks=[log_tool_calls]) | ||
| ``` | ||
| """ | ||
|
|
||
| def decorator(f: F) -> DecoratedFunctionHook[Any]: | ||
| hook_meta = FunctionHookMetadata(f) | ||
| return DecoratedFunctionHook(f, hook_meta) | ||
|
|
||
| if func is None: | ||
| return decorator | ||
|
|
||
| return decorator(func) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Important: The
__get__descriptor creates a newDecoratedFunctionHookwith the same_metadatainstance, but that metadata was created from the original unbound function.When
self._metadatais shared between the class-level hook and instance-bound hooks, the event types and other metadata remain correct. However, theself._metadata.funcstill references the original unbound function, not the bound method.This works because
self._func(which is the bound method) is used for execution, but could be confusing during debugging. Consider whether the newDecoratedFunctionHookshould create its ownFunctionHookMetadatawith the bound function, or if this is an acceptable tradeoff for simplicity.