-
Notifications
You must be signed in to change notification settings - Fork 29
fix dynamic and static options #394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| """Context tool creation for semantic index retrieval.""" | ||
|
|
||
| import uuid | ||
| from typing import Any | ||
| from typing import Any, Optional, Type | ||
|
|
||
| from langchain_core.documents import Document | ||
| from langchain_core.tools import StructuredTool | ||
|
|
@@ -26,6 +26,13 @@ | |
| from .utils import sanitize_tool_name | ||
|
|
||
|
|
||
| def is_static_query(resource: AgentContextResourceConfig) -> bool: | ||
| """Check if the resource configuration uses a static query variant.""" | ||
| if resource.settings.query is None or resource.settings.query.variant is None: | ||
| return False | ||
| return resource.settings.query.variant.lower() == "static" | ||
|
|
||
|
|
||
| def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: | ||
| tool_name = sanitize_tool_name(resource.name) | ||
| retrieval_mode = resource.settings.retrieval_mode.lower() | ||
|
|
@@ -40,34 +47,58 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: | |
| def handle_semantic_search( | ||
| tool_name: str, resource: AgentContextResourceConfig | ||
| ) -> StructuredTool: | ||
| ensure_valid_fields(resource) | ||
|
|
||
| # needed for type checking | ||
| assert resource.settings.query is not None | ||
| assert resource.settings.query.variant is not None | ||
|
|
||
| retriever = ContextGroundingRetriever( | ||
| index_name=resource.index_name, | ||
| folder_path=resource.folder_path, | ||
| number_of_results=resource.settings.result_count, | ||
| ) | ||
|
|
||
| class ContextInputSchemaModel(BaseModel): | ||
| query: str = Field( | ||
| ..., description="The query to search for in the knowledge base" | ||
| ) | ||
|
|
||
| class ContextOutputSchemaModel(BaseModel): | ||
| documents: list[Document] = Field( | ||
| ..., description="List of retrieved documents." | ||
| ) | ||
|
|
||
| input_model = ContextInputSchemaModel | ||
| output_model = ContextOutputSchemaModel | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn(query: str) -> dict[str, Any]: | ||
| return {"documents": await retriever.ainvoke(query)} | ||
| if is_static_query(resource): | ||
| static_query_value = resource.settings.query.value | ||
| assert static_query_value is not None | ||
| input_model = None | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model, | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn() -> dict[str, Any]: | ||
| return {"documents": await retriever.ainvoke(static_query_value)} | ||
|
|
||
| else: | ||
| # Dynamic query - requires query parameter | ||
| class ContextInputSchemaModel(BaseModel): | ||
| query: str = Field( | ||
| ..., description="The query to search for in the knowledge base" | ||
| ) | ||
|
|
||
| input_model = ContextInputSchemaModel | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn(query: str) -> dict[str, Any]: | ||
| return {"documents": await retriever.ainvoke(query)} | ||
|
|
||
| return StructuredToolWithOutputType( | ||
| name=tool_name, | ||
|
|
@@ -82,36 +113,69 @@ def handle_deep_rag( | |
| tool_name: str, resource: AgentContextResourceConfig | ||
| ) -> StructuredTool: | ||
| ensure_valid_fields(resource) | ||
|
|
||
| # needed for type checking | ||
| assert resource.settings.query is not None | ||
| assert resource.settings.query.value is not None | ||
| assert resource.settings.query.variant is not None | ||
|
|
||
| index_name = resource.index_name | ||
| prompt = resource.settings.query.value | ||
| if not resource.settings.citation_mode: | ||
| raise ValueError("Citation mode is required for Deep RAG") | ||
| citation_mode = CitationMode(resource.settings.citation_mode.value) | ||
|
|
||
| input_model = None | ||
| output_model = DeepRagResponse | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model, | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn() -> dict[str, Any]: | ||
| # TODO: add glob pattern support | ||
| return interrupt( | ||
| CreateDeepRag( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=prompt, | ||
| citation_mode=citation_mode, | ||
| if is_static_query(resource): | ||
| # Static query - no input parameter needed | ||
| static_prompt = resource.settings.query.value | ||
| assert static_prompt is not None | ||
| input_model = None | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model, | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn() -> dict[str, Any]: | ||
| # TODO: add glob pattern support | ||
| return interrupt( | ||
| CreateDeepRag( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=static_prompt, | ||
| citation_mode=citation_mode, | ||
| ) | ||
| ) | ||
|
|
||
| else: | ||
| # Dynamic query - requires query parameter | ||
| class DeepRagInputSchemaModel(BaseModel): | ||
| query: str = Field( | ||
| ..., | ||
| description="Describe the task: what to research across documents, what to synthesize, and how to cite sources", | ||
| ) | ||
|
|
||
| input_model = DeepRagInputSchemaModel | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn(query: str) -> dict[str, Any]: | ||
| # TODO: add glob pattern support | ||
| return interrupt( | ||
| CreateDeepRag( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=query, | ||
| citation_mode=citation_mode, | ||
| ) | ||
| ) | ||
|
|
||
| return StructuredToolWithOutputType( | ||
| name=tool_name, | ||
|
|
@@ -129,11 +193,9 @@ def handle_batch_transform( | |
|
|
||
| # needed for type checking | ||
| assert resource.settings.query is not None | ||
| assert resource.settings.query.value is not None | ||
| assert resource.settings.query.variant is not None | ||
|
|
||
| index_name = resource.index_name | ||
| prompt = resource.settings.query.value | ||
|
|
||
| index_folder_path = resource.folder_path | ||
| if not resource.settings.web_search_grounding: | ||
| raise ValueError("Web search grounding field is required for Batch Transform") | ||
|
|
@@ -157,35 +219,82 @@ def handle_batch_transform( | |
| ) | ||
| ) | ||
|
|
||
| class BatchTransformSchemaModel(BaseModel): | ||
| destination_path: str = Field( | ||
| ..., | ||
| description="The relative file path destination for the modified csv file", | ||
| ) | ||
|
|
||
| input_model = BatchTransformSchemaModel | ||
| output_model = BatchTransformResponse | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn(destination_path: str) -> dict[str, Any]: | ||
| # TODO: storage_bucket_folder_path_prefix support | ||
| return interrupt( | ||
| CreateBatchTransform( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=prompt, | ||
| destination_path=destination_path, | ||
| index_folder_path=index_folder_path, | ||
| enable_web_search_grounding=enable_web_search_grounding, | ||
| output_columns=batch_transform_output_columns, | ||
| input_model: Optional[Type[BaseModel]] | ||
|
|
||
| if is_static_query(resource): | ||
| # Static query - only destination_path parameter needed | ||
| static_prompt = resource.settings.query.value | ||
| assert static_prompt is not None | ||
|
|
||
| class StaticBatchTransformSchemaModel(BaseModel): | ||
| destination_path: str = Field( | ||
| default="output.csv", | ||
| description="The relative file path destination for the modified csv file", | ||
| ) | ||
|
|
||
| input_model = StaticBatchTransformSchemaModel | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn( | ||
| destination_path: str = "output.csv", | ||
| ) -> dict[str, Any]: | ||
| # TODO: storage_bucket_folder_path_prefix support | ||
| return interrupt( | ||
| CreateBatchTransform( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=static_prompt, | ||
| destination_path=destination_path, | ||
| index_folder_path=index_folder_path, | ||
| enable_web_search_grounding=enable_web_search_grounding, | ||
| output_columns=batch_transform_output_columns, | ||
| ) | ||
| ) | ||
|
|
||
| else: | ||
| # Dynamic query - requires both query and destination_path parameters | ||
| class DynamicBatchTransformSchemaModel(BaseModel): | ||
| query: str = Field( | ||
| ..., | ||
| description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns", | ||
| ) | ||
| destination_path: str = Field( | ||
|
Collaborator
Is this going to be configurable from the UI? If not, we should have a default value here: the LLM might not know what `destination_path` to provide and could refuse to process the request.
Collaborator
I actually encountered this earlier. If there is not going to be an exposed configuration for this in the UI, let's provide a default value here (like …).
Contributor
Author
Agreed. Added a default. |
||
| default="output.csv", | ||
| description="The relative file path destination for the modified csv file", | ||
| ) | ||
|
|
||
| input_model = DynamicBatchTransformSchemaModel | ||
|
|
||
| @mockable( | ||
| name=resource.name, | ||
| description=resource.description, | ||
| input_schema=input_model.model_json_schema(), | ||
| output_schema=output_model.model_json_schema(), | ||
| example_calls=[], # Examples cannot be provided for context. | ||
| ) | ||
| async def context_tool_fn( | ||
| query: str, destination_path: str = "output.csv" | ||
| ) -> dict[str, Any]: | ||
| # TODO: storage_bucket_folder_path_prefix support | ||
| return interrupt( | ||
| CreateBatchTransform( | ||
| name=f"task-{uuid.uuid4()}", | ||
| index_name=index_name, | ||
| prompt=query, | ||
| destination_path=destination_path, | ||
| index_folder_path=index_folder_path, | ||
| enable_web_search_grounding=enable_web_search_grounding, | ||
| output_columns=batch_transform_output_columns, | ||
| ) | ||
| ) | ||
|
|
||
| return StructuredToolWithOutputType( | ||
| name=tool_name, | ||
|
|
@@ -199,5 +308,9 @@ async def context_tool_fn(destination_path: str) -> dict[str, Any]: | |
| def ensure_valid_fields(resource_config: AgentContextResourceConfig): | ||
| if not resource_config.settings.query: | ||
| raise ValueError("Query object is required") | ||
| if not resource_config.settings.query.value: | ||
| raise ValueError("Query prompt is required") | ||
|
|
||
| if not resource_config.settings.query.variant: | ||
| raise ValueError("Query variant is required") | ||
|
|
||
| if is_static_query(resource_config) and not resource_config.settings.query.value: | ||
| raise ValueError("Static query requires a query value to be set") | ||

Uh oh!
There was an error while loading. Please reload this page.