Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-langchain"
version = "0.3.2"
version = "0.3.3"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
77 changes: 24 additions & 53 deletions src/uipath_langchain/agent/guardrails/actions/filter_action.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any

from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from uipath.core.guardrails.guardrails import FieldReference, FieldSource
from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
Expand All @@ -11,6 +11,7 @@

from ...exceptions import AgentTerminationException
from ...react.types import AgentGuardrailsGraphState
from ...react.utils import extract_tool_call_from_state
from .base_action import GuardrailAction, GuardrailActionNode


Expand Down Expand Up @@ -143,66 +144,36 @@ def _filter_tool_input_fields(
if not has_input_fields:
return {}

msgs = state.messages.copy()
if not msgs:
return {}
tool_call_id = getattr(state, "tool_call_id", None)
tool_call, message = extract_tool_call_from_state(
state, tool_name, tool_call_id, return_message=True
)

# Find the AIMessage with tool calls
# At PRE_EXECUTION, this is always the last message
ai_message = None
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_message = msg
break
if tool_call is None:
return {}

if ai_message is None:
args = tool_call["args"]
if not args or not isinstance(args, dict):
return {}

# Find and filter the tool call with matching name
# Type assertion: we know ai_message is AIMessage from the check above
assert isinstance(ai_message, AIMessage)
tool_calls = list(ai_message.tool_calls)
# Filter out the specified input fields
filtered_args = args.copy()
modified = False

for tool_call in tool_calls:
call_name = (
tool_call.get("name")
if isinstance(tool_call, dict)
else getattr(tool_call, "name", None)
)

if call_name == tool_name:
# Get the current args
args = (
tool_call.get("args")
if isinstance(tool_call, dict)
else getattr(tool_call, "args", None)
)
for field_ref in fields_to_filter:
# Only filter input fields
if field_ref.source == FieldSource.INPUT and field_ref.path in filtered_args:
del filtered_args[field_ref.path]
modified = True

if args and isinstance(args, dict):
# Filter out the specified input fields
filtered_args = args.copy()
for field_ref in fields_to_filter:
# Only filter input fields
if (
field_ref.source == FieldSource.INPUT
and field_ref.path in filtered_args
):
del filtered_args[field_ref.path]
modified = True

# Update the tool call with filtered args
if isinstance(tool_call, dict):
tool_call["args"] = filtered_args
else:
tool_call.args = filtered_args

break
if modified:
tool_call["args"] = filtered_args
message.tool_calls = [
tool_call if tool_call["id"] == tc["id"] else tc
for tc in message.tool_calls
]

if modified:
ai_message.tool_calls = tool_calls
return Command(update={"messages": msgs})
return Command(update={"messages": [message]})

return {}

Expand Down
12 changes: 11 additions & 1 deletion src/uipath_langchain/agent/react/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from uipath.platform.guardrails import BaseGuardrail

from ..guardrails.actions import GuardrailAction
from .aggregator_node import (
create_aggregator_node,
)
from .guardrails.guardrails_subgraph import (
create_agent_init_guardrails_subgraph,
create_agent_terminate_guardrails_subgraph,
Expand Down Expand Up @@ -107,6 +110,10 @@ def create_agent(
)
builder.add_node(AgentGraphNode.TERMINATE, terminate_with_guardrails_subgraph)

# Add aggregator node
aggregator_node = create_aggregator_node()
builder.add_node(AgentGraphNode.AGGREGATOR, aggregator_node)

builder.add_edge(START, AgentGraphNode.INIT)

llm_node = create_llm_node(model, llm_tools, config.thinking_messages_limit)
Expand All @@ -125,7 +132,10 @@ def create_agent(
)

for tool_name in tool_node_names:
builder.add_edge(tool_name, AgentGraphNode.AGENT)
builder.add_edge(tool_name, AgentGraphNode.AGGREGATOR)

# Aggregator goes back to agent
builder.add_edge(AgentGraphNode.AGGREGATOR, AgentGraphNode.AGENT)

builder.add_edge(AgentGraphNode.TERMINATE, END)

Expand Down
125 changes: 125 additions & 0 deletions src/uipath_langchain/agent/react/aggregator_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Aggregator node for merging substates back into main state."""

from typing import Any

from langchain_core.messages import AIMessage, AnyMessage
from langgraph.types import Overwrite

from uipath_langchain.agent.react.types import AgentGraphState, InnerAgentGraphState


def _aggregate_messages(
original_messages: list[AnyMessage], substate_messages: dict[str, list[AnyMessage]]
) -> list[AnyMessage]:
aggregated_by_id: dict[str, AnyMessage] = {}
original_order: list[str] = []

for msg in original_messages:
aggregated_by_id[msg.id] = msg
original_order.append(msg.id)

new_messages: list[AnyMessage] = []

for tool_call_id, substate_msgs in substate_messages.items():
for msg in substate_msgs:
if msg.id in aggregated_by_id:
# existing message
original_msg = aggregated_by_id[msg.id]
if (
isinstance(msg, AIMessage)
and msg.tool_calls
and len(msg.tool_calls) > 0
):
updated_tool_call = next(
(tc for tc in msg.tool_calls if tc["id"] == tool_call_id), None
)
if updated_tool_call:
# update the specific tool call in the original message
new_tool_calls = [
updated_tool_call if tc["id"] == tool_call_id else tc
for tc in original_msg.tool_calls
]
aggregated_by_id[msg.id].tool_calls = new_tool_calls
else:
# new message, add it
new_messages.append(msg)

result = []
for msg_id in original_order:
result.append(aggregated_by_id[msg_id])
result.extend(new_messages)

return result


def create_aggregator_node() -> callable:
"""Create an aggregator node that merges substates back into main state."""

def aggregator_node(state: AgentGraphState) -> dict[str, Any] | Overwrite:
"""
Aggregate substates back into main state.

If substates is empty, no-op and continue.
If substates is non-empty:
- for messages, leave placeholder for message aggregation logic
- for each field in inner state, get its reducer and apply updates
- lastly, overwrite the state and clear substates
"""
if not state.substates:
return {}

# message aggregation
substate_messages = {}
for tool_call_id, substate in state.substates.items():
if "messages" in substate:
substate_messages[tool_call_id] = substate["messages"]

aggregated_messages = _aggregate_messages(state.messages, substate_messages)

# inner state fields aggregation
aggregated_inner_dict = state.inner_state.model_dump()

inner_state_fields = InnerAgentGraphState.model_fields
for substate in state.substates.values():
if "inner_state" in substate:
substate_inner_data = substate["inner_state"]

if isinstance(substate_inner_data, InnerAgentGraphState):
substate_inner_dict = substate_inner_data.model_dump()
else:
substate_inner_dict = substate_inner_data

# for each field, apply reducer if defined
for field_name, field_info in inner_state_fields.items():
if field_name in substate_inner_dict:
substate_field_value = substate_inner_dict[field_name]
current_field_value = aggregated_inner_dict[field_name]

if field_info.metadata and callable(field_info.metadata[-1]):
reducer_func = field_info.metadata[-1]
merged_value = reducer_func(
current_field_value, substate_field_value
)
else:
# no reducer, just replace
merged_value = substate_field_value

aggregated_inner_dict[field_name] = merged_value

aggregated_inner_state = InnerAgentGraphState.model_validate(
aggregated_inner_dict
)

state.messages = aggregated_messages
state.inner_state = aggregated_inner_state
state.substates = {}

# return overwrite command to replace the state
return {
**state.model_dump(exclude={"messages", "inner_state", "substates"}),
"messages": Overwrite(aggregated_messages),
"inner_state": Overwrite(aggregated_inner_state),
"substates": Overwrite({}),
}

return aggregator_node
30 changes: 27 additions & 3 deletions src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
)
from uipath_langchain.agent.guardrails.types import ExecutionStage
from uipath_langchain.agent.react.types import (
AgentGraphNode,
AgentGraphState,
AgentGuardrailsGraphState,
SubgraphOutputModel,
)

_VALIDATOR_ALLOWED_STAGES = {
Expand All @@ -33,6 +35,21 @@
}


def _tool_call_state_handler(state: AgentGuardrailsGraphState) -> dict[str, Any]:
"""Handle tool call state by moving contents to substates if tool_call is present."""
if state.tool_call_id is not None:
# Move current state contents to substates under tool_call_id
return {
"substates": {
state.tool_call_id: {
"messages": state.messages,
"inner_state": state.inner_state,
}
}
}
return {}


def _filter_guardrails_by_stage(
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
stage: ExecutionStage,
Expand Down Expand Up @@ -83,7 +100,7 @@ def _create_guardrails_subgraph(
"""
inner_name, inner_node = main_inner_node

subgraph = StateGraph(AgentGuardrailsGraphState)
subgraph = StateGraph(AgentGuardrailsGraphState, output_schema=SubgraphOutputModel)

subgraph.add_node(inner_name, inner_node)

Expand All @@ -105,6 +122,10 @@ def _create_guardrails_subgraph(
else:
subgraph.add_edge(START, inner_name)

# Always add the tool call state handler node at the end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is generic for any scope (agent, llm or tool). Won't your changes add the TOOL_CALL_STATE_HANDLER for agent and llm guardrail subgraphs as well?

Copy link
Contributor Author

@andreitava-uip andreitava-uip Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would, but the llm and agent subgraphs are expected to have tool_call_id=None, which will make the node a no-op.
We can of course also conditionally add this node only for tool guardrails, that would make it more optimized.

tool_call_handler_name = AgentGraphNode.TOOL_CALL_STATE_HANDLER
subgraph.add_node(tool_call_handler_name, _tool_call_state_handler)

# Add post execution guardrail nodes
if ExecutionStage.POST_EXECUTION in execution_stages:
post_guardrails = _filter_guardrails_by_stage(
Expand All @@ -116,12 +137,15 @@ def _create_guardrails_subgraph(
scope,
ExecutionStage.POST_EXECUTION,
node_factory,
END,
tool_call_handler_name,
inner_name,
)
subgraph.add_edge(inner_name, first_post_exec_guardrail_node)
else:
subgraph.add_edge(inner_name, END)
subgraph.add_edge(inner_name, tool_call_handler_name)

# Always connect tool call handler to END
subgraph.add_edge(tool_call_handler_name, END)

return subgraph.compile()

Expand Down
13 changes: 10 additions & 3 deletions src/uipath_langchain/agent/react/llm_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""LLM node for ReAct Agent graph."""

from typing import Literal, Sequence
from typing import Any, Literal, Sequence

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage
Expand Down Expand Up @@ -53,8 +53,15 @@ def create_llm_node(
base_llm = model.bind_tools(bindable_tools) if bindable_tools else model
tool_choice_required_value = _get_required_tool_choice_by_model(model)

async def llm_node(state: AgentGraphState):
messages: list[AnyMessage] = state.messages
async def llm_node(state: Any):
# we need to use Any here because LangGraph has weird edge behavior
# if the type annotation for the state in the edge function is Any/BaseModel/dict/etc aka not a specific model
# then LangGraph will pass the **same** state that was passed to the previous node
# meaning if we want the full state in the edge, we need to pass the full state here as well
# unfortunately, using AgentGraphState in the annotation and relying on extra="allow" does not work
# so we are doing the validation manually here
agent_state = AgentGraphState.model_validate(state, from_attributes=True)
messages: list[AnyMessage] = agent_state.messages

consecutive_thinking_messages = count_consecutive_thinking_messages(messages)

Expand Down
Loading
Loading