Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 177 additions & 64 deletions src/uipath_langchain/agent/tools/context_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Context tool creation for semantic index retrieval."""

import uuid
from typing import Any
from typing import Any, Optional, Type

from langchain_core.documents import Document
from langchain_core.tools import StructuredTool
Expand All @@ -26,6 +26,13 @@
from .utils import sanitize_tool_name


def is_static_query(resource: AgentContextResourceConfig) -> bool:
"""Check if the resource configuration uses a static query variant."""
if resource.settings.query is None or resource.settings.query.variant is None:
return False
return resource.settings.query.variant.lower() == "static"


def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
tool_name = sanitize_tool_name(resource.name)
retrieval_mode = resource.settings.retrieval_mode.lower()
Expand All @@ -40,34 +47,58 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
def handle_semantic_search(
tool_name: str, resource: AgentContextResourceConfig
) -> StructuredTool:
ensure_valid_fields(resource)

# needed for type checking
assert resource.settings.query is not None
assert resource.settings.query.variant is not None

retriever = ContextGroundingRetriever(
index_name=resource.index_name,
folder_path=resource.folder_path,
number_of_results=resource.settings.result_count,
)

class ContextInputSchemaModel(BaseModel):
query: str = Field(
..., description="The query to search for in the knowledge base"
)

class ContextOutputSchemaModel(BaseModel):
documents: list[Document] = Field(
..., description="List of retrieved documents."
)

input_model = ContextInputSchemaModel
output_model = ContextOutputSchemaModel

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(query: str) -> dict[str, Any]:
return {"documents": await retriever.ainvoke(query)}
if is_static_query(resource):
static_query_value = resource.settings.query.value
assert static_query_value is not None
input_model = None

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model,
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn() -> dict[str, Any]:
return {"documents": await retriever.ainvoke(static_query_value)}

else:
# Dynamic query - requires query parameter
class ContextInputSchemaModel(BaseModel):
query: str = Field(
..., description="The query to search for in the knowledge base"
)

input_model = ContextInputSchemaModel

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(query: str) -> dict[str, Any]:
return {"documents": await retriever.ainvoke(query)}

return StructuredToolWithOutputType(
name=tool_name,
Expand All @@ -82,36 +113,69 @@ def handle_deep_rag(
tool_name: str, resource: AgentContextResourceConfig
) -> StructuredTool:
ensure_valid_fields(resource)

# needed for type checking
assert resource.settings.query is not None
assert resource.settings.query.value is not None
assert resource.settings.query.variant is not None

index_name = resource.index_name
prompt = resource.settings.query.value
if not resource.settings.citation_mode:
raise ValueError("Citation mode is required for Deep RAG")
citation_mode = CitationMode(resource.settings.citation_mode.value)

input_model = None
output_model = DeepRagResponse

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model,
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn() -> dict[str, Any]:
# TODO: add glob pattern support
return interrupt(
CreateDeepRag(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=prompt,
citation_mode=citation_mode,
if is_static_query(resource):
# Static query - no input parameter needed
static_prompt = resource.settings.query.value
assert static_prompt is not None
input_model = None

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model,
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn() -> dict[str, Any]:
# TODO: add glob pattern support
return interrupt(
CreateDeepRag(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=static_prompt,
citation_mode=citation_mode,
)
)

else:
# Dynamic query - requires query parameter
class DeepRagInputSchemaModel(BaseModel):
query: str = Field(
...,
description="Describe the task: what to research across documents, what to synthesize, and how to cite sources",
)

input_model = DeepRagInputSchemaModel

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(query: str) -> dict[str, Any]:
# TODO: add glob pattern support
return interrupt(
CreateDeepRag(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=query,
citation_mode=citation_mode,
)
)

return StructuredToolWithOutputType(
name=tool_name,
Expand All @@ -129,11 +193,9 @@ def handle_batch_transform(

# needed for type checking
assert resource.settings.query is not None
assert resource.settings.query.value is not None
assert resource.settings.query.variant is not None

index_name = resource.index_name
prompt = resource.settings.query.value

index_folder_path = resource.folder_path
if not resource.settings.web_search_grounding:
raise ValueError("Web search grounding field is required for Batch Transform")
Expand All @@ -157,35 +219,82 @@ def handle_batch_transform(
)
)

class BatchTransformSchemaModel(BaseModel):
destination_path: str = Field(
...,
description="The relative file path destination for the modified csv file",
)

input_model = BatchTransformSchemaModel
output_model = BatchTransformResponse

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(destination_path: str) -> dict[str, Any]:
# TODO: storage_bucket_folder_path_prefix support
return interrupt(
CreateBatchTransform(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=prompt,
destination_path=destination_path,
index_folder_path=index_folder_path,
enable_web_search_grounding=enable_web_search_grounding,
output_columns=batch_transform_output_columns,
input_model: Optional[Type[BaseModel]]

if is_static_query(resource):
# Static query - only destination_path parameter needed
static_prompt = resource.settings.query.value
assert static_prompt is not None

class StaticBatchTransformSchemaModel(BaseModel):
destination_path: str = Field(
default="output.csv",
description="The relative file path destination for the modified csv file",
)

input_model = StaticBatchTransformSchemaModel

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(
destination_path: str = "output.csv",
) -> dict[str, Any]:
# TODO: storage_bucket_folder_path_prefix support
return interrupt(
CreateBatchTransform(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=static_prompt,
destination_path=destination_path,
index_folder_path=index_folder_path,
enable_web_search_grounding=enable_web_search_grounding,
output_columns=batch_transform_output_columns,
)
)

else:
# Dynamic query - requires both query and destination_path parameters
class DynamicBatchTransformSchemaModel(BaseModel):
query: str = Field(
...,
description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns",
)
destination_path: str = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this going to be configurable from UI? if not we should have a default value here. the LLM might not know what destination_path to provide and refuse to process the request

Copy link
Collaborator

@radu-mocanu radu-mocanu Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually encountered this earlier.
image
Managed to fix it with some prompt engineering but we should never run into this if we wish to maintain parity with low code (where destination path is not configurable).

If there is not going to be an exposed configuration for this in the UI, let's provide a default value here (like output.csv or similar).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree. Added a default

default="output.csv",
description="The relative file path destination for the modified csv file",
)

input_model = DynamicBatchTransformSchemaModel

@mockable(
name=resource.name,
description=resource.description,
input_schema=input_model.model_json_schema(),
output_schema=output_model.model_json_schema(),
example_calls=[], # Examples cannot be provided for context.
)
async def context_tool_fn(
query: str, destination_path: str = "output.csv"
) -> dict[str, Any]:
# TODO: storage_bucket_folder_path_prefix support
return interrupt(
CreateBatchTransform(
name=f"task-{uuid.uuid4()}",
index_name=index_name,
prompt=query,
destination_path=destination_path,
index_folder_path=index_folder_path,
enable_web_search_grounding=enable_web_search_grounding,
output_columns=batch_transform_output_columns,
)
)

return StructuredToolWithOutputType(
name=tool_name,
Expand All @@ -199,5 +308,9 @@ async def context_tool_fn(destination_path: str) -> dict[str, Any]:
def ensure_valid_fields(resource_config: AgentContextResourceConfig):
if not resource_config.settings.query:
raise ValueError("Query object is required")
if not resource_config.settings.query.value:
raise ValueError("Query prompt is required")

if not resource_config.settings.query.variant:
raise ValueError("Query variant is required")

if is_static_query(resource_config) and not resource_config.settings.query.value:
raise ValueError("Static query requires a query value to be set")
Loading
Loading