-
Notifications
You must be signed in to change notification settings - Fork 29
fix(ToolNode): pass correct inputs to tool nodes #389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,14 +4,15 @@ | |
| from inspect import signature | ||
| from typing import Any, Awaitable, Callable, Literal | ||
|
|
||
| from langchain_core.messages.ai import AIMessage | ||
| from langchain_core.messages.tool import ToolCall, ToolMessage | ||
| from langchain_core.runnables.config import RunnableConfig | ||
| from langchain_core.tools import BaseTool | ||
| from langgraph._internal._runnable import RunnableCallable | ||
| from langgraph.types import Command | ||
| from pydantic import BaseModel | ||
|
|
||
| from ..react.types import AgentGraphState | ||
|
|
||
| # the type safety can be improved with generics | ||
| ToolWrapperType = Callable[ | ||
| [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None | ||
|
|
@@ -49,7 +50,9 @@ def __init__( | |
| self.wrapper = wrapper | ||
| self.awrapper = awrapper | ||
|
|
||
| def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: | ||
| def _func( | ||
| self, state: AgentGraphState, config: RunnableConfig | None = None | ||
| ) -> OutputType: | ||
| call = self._extract_tool_call(state) | ||
| if call is None: | ||
| return None | ||
|
|
@@ -61,7 +64,7 @@ def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: | |
| return self._process_result(call, result) | ||
|
|
||
| async def _afunc( | ||
| self, state: Any, config: RunnableConfig | None = None | ||
| self, state: AgentGraphState, config: RunnableConfig | None = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need the entire state here, including the input args, for which there is the runtime dynamically created type Generally for nodes langgraph validates the state using the annotation info, that being said, now that I think about it more, I think here it's always passing the entire state no matter what. We should also probably configure AgentGraphState to allow extra fields on model validation... hmm
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm OK with |
||
| ) -> OutputType: | ||
| call = self._extract_tool_call(state) | ||
| if call is None: | ||
|
|
@@ -73,20 +76,10 @@ async def _afunc( | |
| result = await self.tool.ainvoke(call["args"]) | ||
| return self._process_result(call, result) | ||
|
|
||
| def _extract_tool_call(self, state: Any) -> ToolCall | None: | ||
| def _extract_tool_call(self, state: AgentGraphState) -> ToolCall | None: | ||
| """Extract the tool call from the state messages.""" | ||
|
|
||
| if not hasattr(state, "messages"): | ||
| raise ValueError("State does not have messages key") | ||
|
|
||
| last_message = state.messages[-1] | ||
| if not isinstance(last_message, AIMessage): | ||
| raise ValueError("Last message in message stack is not an AIMessage.") | ||
|
|
||
| for tool_call in last_message.tool_calls: | ||
| if tool_call["name"] == self.tool.name: | ||
| return tool_call | ||
| return None | ||
| return state.tool_call | ||
|
|
||
| def _process_result( | ||
| self, call: ToolCall, result: dict[str, Any] | Command[Any] | None | ||
|
|
@@ -101,7 +94,7 @@ def _process_result( | |
| return {"messages": [message]} | ||
|
|
||
| def _filter_state( | ||
| self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType | ||
| self, state: AgentGraphState, wrapper: ToolWrapperType | AsyncToolWrapperType | ||
| ) -> BaseModel: | ||
| """Filter the state to the expected model type.""" | ||
| model_type = list(signature(wrapper).parameters.values())[2].annotation | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't want to merge it back to the state, right?