Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-langchain"
version = "0.3.2"
version = "0.3.3"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
13 changes: 11 additions & 2 deletions src/uipath_langchain/agent/react/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Literal

from langchain_core.messages import AIMessage, AnyMessage, ToolCall
from langgraph.types import Send
from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL

from ..exceptions import AgentNodeRoutingException
Expand Down Expand Up @@ -59,7 +60,7 @@ def create_route_agent(thinking_messages_limit: int = 0):

def route_agent(
state: AgentGraphState,
) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
) -> list[str | Send] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
"""Route after agent: handles all routing logic including control flow detection.

Routing logic:
Expand All @@ -86,7 +87,15 @@ def route_agent(
return AgentGraphNode.TERMINATE

if tool_calls:
return [tc["name"] for tc in tool_calls]
return [
Send(
tc["name"],
AgentGraphState(
messages=messages, inner_state=state.inner_state, tool_call=tc
),
)
for tc in tool_calls
]

consecutive_thinking_messages = count_consecutive_thinking_messages(messages)

Expand Down
4 changes: 4 additions & 0 deletions src/uipath_langchain/agent/react/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Annotated, Any, Optional

from langchain_core.messages import AnyMessage
from langchain_core.messages.tool import ToolCall
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from uipath.platform.attachments import Attachment
Expand Down Expand Up @@ -33,6 +34,9 @@ class AgentGraphState(BaseModel):
inner_state: Annotated[InnerAgentGraphState, merge_objects] = Field(
default_factory=InnerAgentGraphState
)
tool_call: Optional[ToolCall] = (
None # This field is used to pass tool inputs to tool nodes.
)
Comment on lines +37 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. why is this added to the state?
  2. it should be added under internal_state

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to merge it back to the state, right?



class AgentGuardrailsGraphState(AgentGraphState):
Expand Down
25 changes: 9 additions & 16 deletions src/uipath_langchain/agent/tools/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from inspect import signature
from typing import Any, Awaitable, Callable, Literal

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.types import Command
from pydantic import BaseModel

from ..react.types import AgentGraphState

# the type safety can be improved with generics
ToolWrapperType = Callable[
[BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None
Expand Down Expand Up @@ -49,7 +50,9 @@ def __init__(
self.wrapper = wrapper
self.awrapper = awrapper

def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType:
def _func(
self, state: AgentGraphState, config: RunnableConfig | None = None
) -> OutputType:
call = self._extract_tool_call(state)
if call is None:
return None
Expand All @@ -61,7 +64,7 @@ def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType:
return self._process_result(call, result)

async def _afunc(
self, state: Any, config: RunnableConfig | None = None
self, state: AgentGraphState, config: RunnableConfig | None = None
Copy link
Contributor

@andreitava-uip andreitava-uip Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the entire state here, including the input args, for which there is the runtime dynamically created type CompleteAgentGraphState. Unfortunately since it's created at runtime we cannot annotate with it.

Generally for nodes langgraph validates the state using the annotation info, that being said, now that I think about it more, I think here it's always passing the entire state no matter what.

We should also probably configure AgentGraphState to allow extra fields on model validation... hmm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with Any. Both of them are fine because there's no type enforcement at runtime anyway.

) -> OutputType:
call = self._extract_tool_call(state)
if call is None:
Expand All @@ -73,20 +76,10 @@ async def _afunc(
result = await self.tool.ainvoke(call["args"])
return self._process_result(call, result)

def _extract_tool_call(self, state: Any) -> ToolCall | None:
def _extract_tool_call(self, state: AgentGraphState) -> ToolCall | None:
"""Extract the tool call from the state messages."""

if not hasattr(state, "messages"):
raise ValueError("State does not have messages key")

last_message = state.messages[-1]
if not isinstance(last_message, AIMessage):
raise ValueError("Last message in message stack is not an AIMessage.")

for tool_call in last_message.tool_calls:
if tool_call["name"] == self.tool.name:
return tool_call
return None
return state.tool_call

def _process_result(
self, call: ToolCall, result: dict[str, Any] | Command[Any] | None
Expand All @@ -101,7 +94,7 @@ def _process_result(
return {"messages": [message]}

def _filter_state(
self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType
self, state: AgentGraphState, wrapper: ToolWrapperType | AsyncToolWrapperType
) -> BaseModel:
"""Filter the state to the expected model type."""
model_type = list(signature(wrapper).parameters.values())[2].annotation
Expand Down
25 changes: 3 additions & 22 deletions tests/agent/tools/test_tool_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for tool_node.py module."""

from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
from langchain_core.messages import AIMessage, HumanMessage
Expand Down Expand Up @@ -55,6 +55,7 @@ class MockState(BaseModel):
messages: list[Any] = []
user_id: str = "test_user"
session_id: str = "test_session"
tool_call: Optional[ToolCall] = None


def mock_wrapper(
Expand Down Expand Up @@ -101,7 +102,7 @@ def mock_state(self):
"id": "test_call_id",
}
ai_message = AIMessage(content="Using tool", tool_calls=[tool_call])
return MockState(messages=[ai_message])
return MockState(messages=[ai_message], tool_call=tool_call)

@pytest.fixture
def empty_state(self):
Expand Down Expand Up @@ -200,26 +201,6 @@ def test_no_tool_calls_returns_none(self, mock_tool, empty_state):

assert result is None

def test_non_ai_message_raises_error(self, mock_tool, non_ai_state):
"""Test that non-AI messages raise ValueError."""
node = UiPathToolNode(mock_tool)

with pytest.raises(
ValueError, match="Last message in message stack is not an AIMessage"
):
node._func(non_ai_state)

def test_mismatched_tool_name_returns_none(self, mock_tool, mock_state):
"""Test that mismatched tool names return None."""
# Change the tool call name to something different
mock_state.messages[-1].tool_calls[0]["name"] = "different_tool"

node = UiPathToolNode(mock_tool)

result = node._func(mock_state)

assert result is None

def test_state_filtering(self, mock_tool, mock_state):
"""Test that state is properly filtered for wrapper functions."""
node = UiPathToolNode(mock_tool, wrapper=mock_wrapper)
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.