Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions bot/tests/test_openviking_api_key_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from types import SimpleNamespace

import pytest

from vikingbot.agent.tools.ov_file import VikingSearchTool
from vikingbot.agent.tools.ov_file import VikingGrepTool, VikingSearchTool
from vikingbot.config.schema import SessionKey
from vikingbot.hooks.base import HookContext
from vikingbot.hooks.builtins.openviking_hooks import OpenVikingCompactHook
Expand Down Expand Up @@ -38,6 +37,15 @@ def session(self, _session_id):
async def admin_list_accounts(self):
return []

async def find(self, *_args, **_kwargs):
return []

async def search(self, *_args, **_kwargs):
return {"memories": [], "resources": [], "skills": []}

async def grep(self, *_args, **_kwargs):
return {"matches": []}

async def close(self):
return None

Expand Down Expand Up @@ -140,6 +148,9 @@ class _FakeClient:
def __init__(self):
self.calls = []

def should_sender_fanout(self):
return False

async def commit(self, session_id, messages, user_id=None):
self.calls.append((session_id, user_id, len(messages)))
return {"success": "committed"}
Expand Down Expand Up @@ -263,6 +274,49 @@ async def _accounts():
)


def test_openviking_grep_schema_requires_single_string_pattern():
tool = VikingGrepTool()

assert tool.parameters["properties"]["pattern"]["type"] == "string"


@pytest.mark.asyncio
async def test_openviking_grep_passes_admin_user_id(monkeypatch):
tool = VikingGrepTool()
calls = []

class _FakeClient:
admin_user_id = "admin"

async def grep(self, uri, pattern, case_insensitive=False, user_id=None):
calls.append((uri, pattern, case_insensitive, user_id))
return {
"matches": [
{
"uri": "viking://resources/doc.md",
"line": 3,
"content": "hello admin scoped grep",
}
]
}

async def _fake_get_client(_tool_context):
return _FakeClient()

monkeypatch.setattr(tool, "_get_client", _fake_get_client)

result = await tool.execute(
SimpleNamespace(workspace_id="workspace"),
uri="viking://resources/",
pattern="hello",
case_insensitive=True,
)

assert calls == [("viking://resources/", "hello", True, "admin")]
assert "Found 1 match for pattern 'hello':" in result
assert "viking://resources/doc.md" in result


@pytest.mark.asyncio
async def test_openviking_search_uses_policy_scoped_user_namespace(monkeypatch):
monkeypatch.setattr(ov_server_module, "load_config", lambda: _make_config("root"))
Expand All @@ -280,13 +334,17 @@ async def _accounts():
}
]

async def _search(query, target_uri=None, limit=20):
async def _search(query, target_uri=None, limit=20, user_id=None):
calls.append(target_uri)
return {"memories": [{"uri": target_uri, "abstract": "a", "score": 0.9, "is_leaf": True}]}

async def _fake_get_client(_tool_context):
return client

monkeypatch.setattr(client.client, "admin_list_accounts", _accounts)
monkeypatch.setattr(client.client, "search", _search)
monkeypatch.setattr(tool, "_get_client", lambda _tool_context: client)
monkeypatch.setattr(client, "search", _search)
monkeypatch.setattr(tool, "_get_client", _fake_get_client)
await client._load_namespace_policy()

tool_context = SimpleNamespace(workspace_id="workspace", memory_user_ids=["sender-1"])
result = await tool.execute(tool_context, query="hello")
Expand All @@ -303,14 +361,19 @@ async def test_openviking_search_user_key_mode_uses_current_user_namespace(monke

calls = []

async def _search(query, target_uri=None, limit=20):
async def _search(query, target_uri=None, limit=20, user_id=None):
calls.append(target_uri)
return {"memories": [{"uri": target_uri, "abstract": "a", "score": 0.9, "is_leaf": True}]}

monkeypatch.setattr(client.client, "search", _search)
monkeypatch.setattr(tool, "_get_client", lambda _tool_context: client)
async def _fake_get_client(_tool_context):
return client

monkeypatch.setattr(client, "search", _search)
monkeypatch.setattr(tool, "_get_client", _fake_get_client)

tool_context = SimpleNamespace(workspace_id="workspace", memory_user_ids=["sender-1", "sender-2"])
tool_context = SimpleNamespace(
workspace_id="workspace", memory_user_ids=["sender-1", "sender-2"]
)
result = await tool.execute(tool_context, query="hello")

assert "viking://user/memories/" in result
Expand Down
149 changes: 67 additions & 82 deletions bot/vikingbot/agent/tools/ov_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from abc import ABC
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional

import httpx
from loguru import logger
Expand Down Expand Up @@ -89,6 +89,7 @@ def name(self) -> str:
def description(self) -> str:
return (
"Using query to search for resources (knowledge, code, files, workflow, etc.) in OpenViking. "
"Result: Only URIs and summaries are included here. To view the full content, use openviking_multi_read tool."
"This operation performs semantic retrieval, not full character matching. Please avoid repeated calls with similar queries as much as possible."
"bad-case: after searching with ‘Nate Joanna dog playdate 3:00 pm', another search was performed using 'Nate Joanna dog playdate'."
)
Expand Down Expand Up @@ -249,37 +250,38 @@ async def execute(
) -> str:
try:
client = await self._get_client(tool_context)
search_client = client.admin_user_client or client.client
admin_user_id = client.admin_user_id

# If no target_uri specified, use memory_user_ids to search specific user memories
if not target_uri and tool_context.memory_user_ids:
all_results = []
user_ids = tool_context.memory_user_ids
if client._is_user_key_mode():
user_ids = [None]

for user_id in user_ids:
user_uri = client._memory_target_uri(user_id)
logger.info(f"openviking_search: searching {user_uri} for query: {query}")
results = await search_client.search(query, target_uri=user_uri, limit=20)
if results:
memories = [
item
for item in self._extract_search_items(results)
if item.get("type") == "memory"
]
all_results.extend(memories)

if not all_results:
return f"No results found for query: {query}"
user_ids = tool_context.memory_user_ids if client.should_sender_fanout() else [None]
grouped_items = {
"memory": [],
"resource": [],
"skill": [],
}

for memory_user_id in user_ids:
results = await client.search(
query,
target_uri=client._memory_target_uri(memory_user_id),
limit=20,
user_id=admin_user_id,
)
filtered_items = self._filter_search_items(results, min_score=min_score)
for item_type, items in filtered_items.items():
grouped_items[item_type].extend(items)

grouped_items = self._filter_search_items(all_results, min_score=min_score)
total = sum(len(items) for items in grouped_items.values())
if total == 0:
return f"No results found for query: {query}"
return self._format_search_items_json(grouped_items, min_score=min_score)

results = await search_client.search(query, target_uri=target_uri, limit=20)
results = await client.search(
query,
target_uri=target_uri,
limit=20,
user_id=admin_user_id,
)

if not results:
return f"No results found for query: {query}"
Expand Down Expand Up @@ -352,7 +354,7 @@ async def execute(


class VikingGrepTool(OVFileTool):
"""Tool to search Viking resources using regex patterns."""
"""Tool to search Viking resources using a regex pattern."""

@property
def name(self) -> str:
Expand All @@ -361,7 +363,8 @@ def name(self) -> str:
@property
def description(self) -> str:
return (
"Search Viking resources using regex patterns (like grep). Supports multiple patterns to search concurrently."
"Search Viking resources using a regex pattern (like grep)."
"Result: Only URIs and summaries are included here. To view the full content, use openviking_multi_read tool."
"Please avoid repeated calls with similar queries as much as possible."
)

Expand All @@ -375,9 +378,8 @@ def parameters(self) -> dict[str, Any]:
"description": "The whole Viking URI to search within (e.g., viking://resources/)",
},
"pattern": {
"type": "array",
"items": {"type": "string"},
"description": "Regex pattern or array of regex patterns to search for",
"type": "string",
"description": "Regex pattern to search for",
},
"case_insensitive": {
"type": "boolean",
Expand All @@ -392,71 +394,51 @@ async def execute(
self,
tool_context: "ToolContext",
uri: str,
pattern: Union[str, list[str]],
pattern: str,
case_insensitive: bool = False,
**kwargs: Any,
) -> str:
try:
client = await self._get_client(tool_context)
patterns = [pattern] if isinstance(pattern, str) else pattern
result = await client.grep(
uri,
pattern,
case_insensitive=case_insensitive,
user_id=client.admin_user_id,
)
if isinstance(result, dict):
matches = result.get("matches", [])
else:
matches = getattr(result, "matches", [])

# Limit concurrent requests to avoid overwhelming the server and memory
max_concurrent = 10
semaphore = asyncio.Semaphore(max_concurrent)
if not matches:
return f"No matches found for pattern: '{pattern}'"

async def run_grep(p: str) -> tuple[str, list[Any]]:
async with semaphore:
try:
result = await client.grep(uri, p, case_insensitive=case_insensitive)
if isinstance(result, dict):
matches = result.get("matches", [])
else:
matches = getattr(result, "matches", [])
return (p, matches)
except Exception as e:
logger.warning(f"Error searching for pattern '{p}': {e}")
return (p, [])
merged_results: dict[str, list[tuple[int, str]]] = {}

tasks = [run_grep(p) for p in patterns]
results = await asyncio.gather(*tasks)
for match in matches:
if isinstance(match, dict):
match_uri = match.get("uri", "unknown")
line = match.get("line", "?")
content = match.get("content", "")
else:
match_uri = getattr(match, "uri", "unknown")
line = getattr(match, "line", "?")
content = getattr(match, "content", "")

# Merge results by URI
merged_results: dict[str, list[tuple[int, str, str]]] = {}
total_matches = 0
if match_uri not in merged_results:
merged_results[match_uri] = []
merged_results[match_uri].append((line, content))

for p, matches in results:
if not matches:
continue
total_matches += len(matches)
for match in matches:
if isinstance(match, dict):
match_uri = match.get("uri", "unknown")
line = match.get("line", "?")
content = match.get("content", "")
else:
match_uri = getattr(match, "uri", "unknown")
line = getattr(match, "line", "?")
content = getattr(match, "content", "")

if match_uri not in merged_results:
merged_results[match_uri] = []
merged_results[match_uri].append((line, content, p))

if not merged_results:
pattern_str = ", ".join(f"'{p}'" for p in patterns)
return f"No matches found for patterns: {pattern_str}"

# Format output
result_lines = [
f"Found {total_matches} match{'es' if total_matches != 1 else ''} across {len(patterns)} pattern{'s' if len(patterns) != 1 else ''}:"
f"Found {len(matches)} match{'es' if len(matches) != 1 else ''} for pattern '{pattern}':"
]

for match_uri, matches in merged_results.items():
# Sort matches by line number
matches.sort(key=lambda x: int(x[0]) if str(x[0]).isdigit() else 0)
for match_uri, uri_matches in merged_results.items():
uri_matches.sort(key=lambda x: int(x[0]) if str(x[0]).isdigit() else 0)
result_lines.append(f"\n📄 {match_uri}")
for line, content, pattern_name in matches:
result_lines.append(f" Line {line} (pattern: '{pattern_name}'):")
for line, content in uri_matches:
result_lines.append(f" Line {line}:")
result_lines.append(f" {content}")

return "\n".join(result_lines)
Expand All @@ -473,7 +455,10 @@ def name(self) -> str:

@property
def description(self) -> str:
return "Find Viking resources using glob patterns (like **/*.md, *.py)."
return (
"Find Viking resources using glob patterns (like **/*.md, *.py)."
"Result: Only URIs and summaries are included here. To view the full content, use openviking_multi_read tool."
)

@property
def parameters(self) -> dict[str, Any]:
Expand All @@ -498,7 +483,7 @@ async def execute(
) -> str:
try:
client = await self._get_client(tool_context)
result = await client.glob(pattern, uri=uri or None)
result = await client.glob(pattern, uri=uri or None, user_id=client.admin_user_id)

if isinstance(result, dict):
matches = result.get("matches", [])
Expand Down
Loading
Loading