Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions openviking/server/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def _sanitize_floats(obj: Any) -> Any:

router = APIRouter(prefix="/api/v1/search", tags=["search"])
TimeField = Literal["updated_at", "created_at"]
FindMode = Literal["auto", "fast", "deep"]


def _resolve_retriever_mode(mode: FindMode) -> Optional[str]:
if mode == "auto":
return None

from openviking.retrieve.hierarchical_retriever import RetrieverMode

if mode == "fast":
return RetrieverMode.QUICK
return RetrieverMode.THINKING


def _resolve_search_limit(limit: int, node_limit: Optional[int]) -> int:
Expand Down Expand Up @@ -68,6 +80,7 @@ class FindRequest(BaseModel):
score_threshold: Optional[float] = None
filter: Optional[Dict[str, Any]] = None
include_provenance: bool = False
mode: FindMode = "auto"

since: Optional[str] = None
until: Optional[str] = None
Expand Down Expand Up @@ -126,17 +139,22 @@ async def find(
request.until,
request.time_field,
)
retriever_mode = _resolve_retriever_mode(request.mode)
find_kwargs = {
"query": request.query,
"ctx": _ctx,
"target_uri": request.target_uri,
"limit": actual_limit,
"score_threshold": request.score_threshold,
"filter": effective_filter,
}
if retriever_mode is not None:
find_kwargs["mode"] = retriever_mode

execution = await run_operation(
operation="search.find",
telemetry=request.telemetry,
fn=lambda: service.search.find(
query=request.query,
ctx=_ctx,
target_uri=request.target_uri,
limit=actual_limit,
score_threshold=request.score_threshold,
filter=effective_filter,
),
fn=lambda: service.search.find(**find_kwargs),
)
result = execution.result
if hasattr(result, "to_dict"):
Expand Down
22 changes: 14 additions & 8 deletions openviking/service/search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def find(
limit: int = 10,
score_threshold: Optional[float] = None,
filter: Optional[Dict] = None,
mode: Optional[str] = None,
) -> Any:
"""Semantic search without session context.

Expand All @@ -98,18 +99,23 @@ async def find(
limit: Max results
score_threshold: Score threshold
filter: Metadata filters
mode: Optional retriever mode override

Returns:
FindResult
"""
_ensure_non_empty_query(query)
viking_fs = self._ensure_initialized()
result = await viking_fs.find(
query=query,
ctx=ctx,
target_uri=target_uri,
limit=limit,
score_threshold=score_threshold,
filter=filter,
)
find_kwargs = {
"query": query,
"ctx": ctx,
"target_uri": target_uri,
"limit": limit,
"score_threshold": score_threshold,
"filter": filter,
}
if mode is not None:
find_kwargs["mode"] = mode

result = await viking_fs.find(**find_kwargs)
return result
19 changes: 12 additions & 7 deletions openviking/storage/viking_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,7 @@ async def find(
score_threshold: Optional[float] = None,
filter: Optional[Dict] = None,
ctx: Optional[RequestContext] = None,
mode: Optional[str] = None,
):
"""Semantic search.

Expand All @@ -1015,6 +1016,7 @@ async def find(
limit: Return count
score_threshold: Score threshold
filter: Metadata filter
mode: Optional retriever mode override

Returns:
FindResult
Expand Down Expand Up @@ -1064,13 +1066,16 @@ async def find(
f"[VikingFS.find] Calling retriever.retrieve with ctx.account_id={real_ctx.account_id}, ctx.user={real_ctx.user}"
)

result = await retriever.retrieve(
typed_query,
ctx=real_ctx,
limit=limit,
score_threshold=score_threshold,
scope_dsl=filter,
)
retrieve_kwargs = {
"ctx": real_ctx,
"limit": limit,
"score_threshold": score_threshold,
"scope_dsl": filter,
}
if mode is not None:
retrieve_kwargs["mode"] = mode

result = await retriever.retrieve(typed_query, **retrieve_kwargs)

# Convert QueryResult to FindResult
memories, resources, skills = [], [], []
Expand Down
148 changes: 146 additions & 2 deletions tests/misc/test_vikingfs_find_without_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@
"""Regression test for VikingFS.find without rerank configuration."""

import contextvars
from unittest.mock import MagicMock
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import httpx
import pytest
from fastapi import FastAPI

from openviking.retrieve.hierarchical_retriever import RetrieverMode
from openviking.server.auth import get_request_context
from openviking.server.identity import RequestContext, Role
from openviking.server.routers import search as search_router
from openviking.service.search_service import SearchService
from openviking.storage.viking_fs import VikingFS
from openviking_cli.retrieve.types import ContextType, MatchedContext, QueryResult
from openviking_cli.session.user_id import UserIdentifier


def _ctx() -> RequestContext:
return RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER)
return RequestContext(
user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER
)


def _make_viking_fs() -> VikingFS:
Expand All @@ -32,6 +41,26 @@ def _make_viking_fs() -> VikingFS:
return fs


def _make_search_app(monkeypatch, captured: dict | None = None) -> FastAPI:
app = FastAPI()
app.include_router(search_router.router)
app.dependency_overrides[get_request_context] = _ctx

if captured is not None:

async def fake_find(**kwargs):
captured.update(kwargs)
return {"items": []}

monkeypatch.setattr(
search_router,
"get_service",
lambda: SimpleNamespace(search=SimpleNamespace(find=fake_find)),
)

return app


@pytest.mark.asyncio
async def test_find_works_without_rerank_config(monkeypatch) -> None:
fs = _make_viking_fs()
Expand Down Expand Up @@ -89,3 +118,118 @@ async def retrieve(self, typed_query, ctx, limit, score_threshold, scope_dsl):
assert captured["score_threshold"] == 0.2
assert captured["scope_dsl"] == {"category": "doc"}
fs._ensure_access.assert_called_once_with("viking://resources/docs", request_ctx)


@pytest.mark.parametrize(
("request_mode", "expected_mode"),
[
("fast", RetrieverMode.QUICK),
("deep", RetrieverMode.THINKING),
],
)
@pytest.mark.asyncio
async def test_find_endpoint_passes_mode_to_search_service(
monkeypatch, request_mode: str, expected_mode: str
) -> None:
captured = {}
app = _make_search_app(monkeypatch, captured)

async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
resp = await client.post(
"/api/v1/search/find",
json={"query": "sample", "mode": request_mode},
)

assert resp.status_code == 200
assert captured["mode"] == expected_mode


@pytest.mark.parametrize("payload", [{}, {"mode": "auto"}])
@pytest.mark.asyncio
async def test_find_endpoint_auto_mode_omits_retriever_mode(
monkeypatch, payload
) -> None:
captured = {}
app = _make_search_app(monkeypatch, captured)

async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
resp = await client.post(
"/api/v1/search/find",
json={"query": "sample", **payload},
)

assert resp.status_code == 200
assert "mode" not in captured


@pytest.mark.asyncio
async def test_find_endpoint_rejects_invalid_mode(monkeypatch) -> None:
app = _make_search_app(monkeypatch)

async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
resp = await client.post(
"/api/v1/search/find",
json={"query": "sample", "mode": "turbo"},
)

assert resp.status_code == 422
assert resp.json()["detail"]


@pytest.mark.asyncio
async def test_search_service_passes_find_mode_to_vikingfs() -> None:
viking_fs = MagicMock()
viking_fs.find = AsyncMock(return_value={"items": []})
service = SearchService(viking_fs)

await service.find("guide", ctx=_ctx(), mode=RetrieverMode.QUICK)

assert viking_fs.find.await_args.kwargs["mode"] == RetrieverMode.QUICK


@pytest.mark.asyncio
async def test_search_service_omits_find_mode_by_default() -> None:
viking_fs = MagicMock()
viking_fs.find = AsyncMock(return_value={"items": []})
service = SearchService(viking_fs)

await service.find("guide", ctx=_ctx())

assert "mode" not in viking_fs.find.await_args.kwargs


@pytest.mark.parametrize("mode", [RetrieverMode.QUICK, RetrieverMode.THINKING])
@pytest.mark.asyncio
async def test_find_passes_mode_to_retriever(monkeypatch, mode: str) -> None:
fs = _make_viking_fs()
captured = {}

class FakeRetriever:
def __init__(self, storage, embedder, rerank_config):
pass

async def retrieve(self, typed_query, **kwargs):
captured["mode"] = kwargs["mode"]
return QueryResult(
query=typed_query,
matched_contexts=[],
searched_directories=[],
)

monkeypatch.setattr(
"openviking.retrieve.hierarchical_retriever.HierarchicalRetriever",
FakeRetriever,
)

await fs.find("guide", ctx=_ctx(), mode=mode)

assert captured["mode"] == mode