Skip to content

Commit 91edca2

Browse files
feat: init context registry
1 parent 4f94efc commit 91edca2

9 files changed

Lines changed: 445 additions & 257 deletions

File tree

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Registry for resource types that contribute init-time context.
2+
3+
Resource modules self-register by calling ``register_init_context_provider``
4+
at module level. The INIT node calls ``gather_init_context`` to collect
5+
additional context from all registered providers, without needing to know
6+
which resource types participate.
7+
"""
8+
9+
import logging
10+
from typing import Protocol, Sequence
11+
12+
from uipath.agent.models.agent import BaseAgentResourceConfig
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class InitContextProvider(Protocol):
18+
"""Contract for a resource type's init-time context builder."""
19+
20+
async def __call__(
21+
self,
22+
resources: Sequence[BaseAgentResourceConfig],
23+
) -> str | None: ...
24+
25+
26+
_registry: dict[str, InitContextProvider] = {}
27+
28+
29+
def register_init_context_provider(
30+
name: str,
31+
provider: InitContextProvider,
32+
) -> None:
33+
"""Register a provider that contributes init-time context.
34+
35+
Args:
36+
name: Identifier for logging and deduplication.
37+
provider: Async callable matching ``InitContextProvider``.
38+
"""
39+
if name in _registry:
40+
raise ValueError(f"Init context provider '{name}' is already registered")
41+
_registry[name] = provider
42+
logger.debug("Registered init context provider: %s", name)
43+
44+
45+
async def gather_init_context(
46+
resources: Sequence[BaseAgentResourceConfig],
47+
) -> str | None:
48+
"""Call all registered providers and merge their context contributions.
49+
50+
Args:
51+
resources: The agent's resource configs.
52+
53+
Returns:
54+
Merged context string, or None if no provider contributed.
55+
"""
56+
parts: list[str] = []
57+
for name, provider in _registry.items():
58+
try:
59+
result = await provider(resources)
60+
if result:
61+
parts.append(result)
62+
logger.info(
63+
"Init context provider '%s' contributed %d chars",
64+
name,
65+
len(result),
66+
)
67+
except Exception:
68+
logger.exception("Init context provider '%s' failed; skipping", name)
69+
return "\n\n".join(parts) if parts else None

src/uipath_langchain/agent/react/init_node.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,19 @@ def create_init_node(
2424
resources_for_init: AgentResources | None = None,
2525
):
2626
async def graph_state_init(state: Any) -> Any:
27-
# --- Data Fabric schema fetch (INIT-time) ---
28-
schema_context: str | None = None
27+
# --- Gather init-time context from registered providers ---
28+
additional_context: str | None = None
2929
if resources_for_init:
30-
from uipath_langchain.agent.tools.datafabric_tool import (
31-
fetch_entity_schemas,
32-
format_schemas_for_context,
33-
get_datafabric_entity_identifiers_from_resources,
34-
)
30+
from .init_context_registry import gather_init_context
3531

36-
entity_identifiers = get_datafabric_entity_identifiers_from_resources(
37-
resources_for_init
38-
)
39-
if entity_identifiers:
40-
logger.info(
41-
"Fetching Data Fabric schemas for %d identifier(s)",
42-
len(entity_identifiers),
43-
)
44-
entities = await fetch_entity_schemas(entity_identifiers)
45-
schema_context = format_schemas_for_context(entities)
32+
additional_context = await gather_init_context(resources_for_init)
4633

4734
# --- Resolve messages ---
4835
resolved_messages: Sequence[SystemMessage | HumanMessage] | Overwrite
4936
if callable(messages):
50-
if schema_context:
37+
if additional_context:
5138
resolved_messages = list(
52-
messages(state, additional_context=schema_context)
39+
messages(state, additional_context=additional_context)
5340
)
5441
else:
5542
resolved_messages = list(messages(state))

src/uipath_langchain/agent/tools/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from .context_tool import create_context_tool
44
from .datafabric_tool import (
5-
create_datafabric_tools,
65
fetch_entity_schemas,
76
format_schemas_for_context,
87
get_datafabric_contexts,
@@ -23,7 +22,6 @@
2322
"create_tools_from_resources",
2423
"create_tool_node",
2524
"create_context_tool",
26-
"create_datafabric_tools",
2725
"open_mcp_tools",
2826
"create_process_tool",
2927
"create_integration_tool",

src/uipath_langchain/agent/tools/context_tool.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from uipath.agent.models.agent import (
1313
AgentContextResourceConfig,
1414
AgentContextRetrievalMode,
15+
AgentContextType,
1516
AgentToolArgumentArgumentProperties,
1617
AgentToolArgumentProperties,
1718
)
@@ -134,21 +135,25 @@ def is_static_query(resource: AgentContextResourceConfig) -> bool:
134135
return resource.settings.query.variant.lower() == "static"
135136

136137

137-
def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
138-
tool_name = sanitize_tool_name(resource.name)
138+
def create_context_tool(
139+
resource: AgentContextResourceConfig,
140+
) -> StructuredTool | BaseTool:
141+
if resource.context_type == AgentContextType.DATA_FABRIC_ENTITY_SET:
142+
from .datafabric_tool import create_datafabric_query_tool
143+
144+
return create_datafabric_query_tool()
145+
139146
assert resource.settings is not None
147+
tool_name = sanitize_tool_name(resource.name)
140148
retrieval_mode = resource.settings.retrieval_mode.lower()
149+
141150
if retrieval_mode == AgentContextRetrievalMode.DEEP_RAG.value.lower():
142151
return handle_deep_rag(tool_name, resource)
143-
elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower():
152+
153+
if retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower():
144154
return handle_batch_transform(tool_name, resource)
145-
elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower():
146-
raise ValueError(
147-
"Data Fabric context should be handled via create_datafabric_tools(), "
148-
"not create_context_tool()"
149-
)
150-
else:
151-
return handle_semantic_search(tool_name, resource)
155+
156+
return handle_semantic_search(tool_name, resource)
152157

153158

154159
def handle_semantic_search(
Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,50 @@
11
"""Data Fabric tool module for entity-based SQL queries."""
22

3+
import logging
4+
from typing import Sequence
5+
6+
from uipath.agent.models.agent import BaseAgentResourceConfig
7+
8+
from uipath_langchain.agent.react.init_context_registry import (
9+
register_init_context_provider,
10+
)
11+
312
from .datafabric_tool import (
4-
create_datafabric_tools,
13+
create_datafabric_query_tool,
514
fetch_entity_schemas,
6-
format_schemas_for_context,
715
get_datafabric_contexts,
816
get_datafabric_entity_identifiers_from_resources,
917
)
18+
from .schema_context import format_schemas_for_context
1019

1120
__all__ = [
12-
"create_datafabric_tools",
21+
"create_datafabric_query_tool",
1322
"fetch_entity_schemas",
1423
"format_schemas_for_context",
1524
"get_datafabric_contexts",
1625
"get_datafabric_entity_identifiers_from_resources",
1726
]
27+
28+
_logger = logging.getLogger(__name__)
29+
30+
31+
# --- Init-time context self-registration ---
32+
33+
34+
async def _datafabric_init_context_provider(
35+
resources: Sequence[BaseAgentResourceConfig],
36+
) -> str | None:
37+
"""Fetch and format DataFabric entity schemas for system prompt injection."""
38+
entity_identifiers = get_datafabric_entity_identifiers_from_resources(resources)
39+
if not entity_identifiers:
40+
return None
41+
42+
_logger.info(
43+
"Fetching Data Fabric schemas for %d identifier(s)",
44+
len(entity_identifiers),
45+
)
46+
entities = await fetch_entity_schemas(entity_identifiers)
47+
return format_schemas_for_context(entities)
48+
49+
50+
register_init_context_provider("datafabric", _datafabric_init_context_provider)

0 commit comments

Comments
 (0)