Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/_ravnar/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import Depends

from _ravnar import schema
from _ravnar.auth import User
from _ravnar.config import StorageConfig

from .agents import make_router as make_agents_router
Expand All @@ -20,14 +21,20 @@ def make_router(
*,
storage_config: StorageConfig,
agent_handler: AgentHandler,
authenticated_user: Callable[..., Any],
authorized_user_with: Callable[..., Any],
) -> schema.APIRouter:
router = schema.APIRouter(tags=["API"], dependencies=[Depends(authenticated_user)])
router = schema.APIRouter(
tags=["API"],
# This ensures that every endpoint on this router or its sub-routers
# can only be accessed by authenticated users. The actual authorization
# check happens on the specific endpoint.
dependencies=[Depends(authorized_user_with())],
)

@router.get("/user")
async def get_user(
user: schema.User = Depends(authenticated_user), # noqa: B008
) -> schema.User:
user: User = Depends(authorized_user_with()), # noqa: B008
) -> User:
return user

@router.get("/config")
Expand All @@ -42,12 +49,12 @@ async def get_config() -> schema.APIConfig:
_make_stateful_router(
storage_config=storage_config,
agent_handler=agent_handler,
authenticated_user=authenticated_user,
authorized_user_with=authorized_user_with,
)
)

router.include_router(
make_agents_router(agent_handler=agent_handler, authenticated_user=authenticated_user), prefix="/agents"
make_agents_router(agent_handler=agent_handler, authorized_user_with=authorized_user_with), prefix="/agents"
)

return router
Expand All @@ -57,7 +64,7 @@ def _make_stateful_router(
*,
storage_config: StorageConfig,
agent_handler: AgentHandler,
authenticated_user: Callable[..., Any],
authorized_user_with: Callable[..., Any],
) -> schema.APIRouter:
from _ravnar.database import Database
from _ravnar.file_storage import FileHandler
Expand All @@ -72,15 +79,15 @@ def _make_stateful_router(
)

router.include_router(
make_files_router(file_handler=file_handler, authenticated_user=authenticated_user),
make_files_router(file_handler=file_handler, authorized_user_with=authorized_user_with),
prefix="/files",
)
router.include_router(
make_threads_router(
database=database,
file_handler=file_handler,
agent_handler=agent_handler,
authenticated_user=authenticated_user,
authorized_user_with=authorized_user_with,
),
prefix="/threads",
)
Expand Down
28 changes: 20 additions & 8 deletions src/_ravnar/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,32 @@
from fastapi import Depends, Path

from _ravnar import schema
from _ravnar.auth import User

if TYPE_CHECKING:
from _ravnar.core import AgentHandler


def make_router(*, agent_handler: AgentHandler, authenticated_user: Callable[..., Any]) -> schema.APIRouter:
router = schema.APIRouter(tags=["Agents"], dependencies=[Depends(authenticated_user)])
def make_router(*, agent_handler: AgentHandler, authorized_user_with: Callable[..., Any]) -> schema.APIRouter:
router = schema.APIRouter(tags=["Agents"])

@router.get("")
async def list_agents() -> list[schema.AgentInfo]:
async def list_agents(
user: User = Depends(authorized_user_with("agents:read")), # noqa: B008
) -> list[schema.AgentInfo]:
return agent_handler.infos()

@router.sse("/{agentId}/run", methods=["POST"], response_model=schema.Event, tags=["Runs"])
async def create_stateless_run(
*, agent_id: Annotated[str, Path(alias="agentId")], run_agent_input: ag_ui.core.RunAgentInput
*,
agent_id: Annotated[str, Path(alias="agentId")],
run_agent_input: ag_ui.core.RunAgentInput,
user: User = Depends(authorized_user_with("agents:read")), # noqa: B008
) -> fastsse.Response:
return await agent_handler.run(agent_id, run_agent_input)

if agent_handler.dynamic_enabled:
_make_dynamic_agents_router(router, agent_handler=agent_handler, authenticated_user=authenticated_user)
_make_dynamic_agents_router(router, agent_handler=agent_handler, authorized_user_with=authorized_user_with)

return router

Expand All @@ -36,15 +42,18 @@ def _make_dynamic_agents_router(
router: schema.APIRouter,
*,
agent_handler: AgentHandler,
authenticated_user: Callable[..., Any],
authorized_user_with: Callable[..., Any],
) -> None:
description = (
"Only available if dynamic agents are enabled. "
"Can be checked with [`GET /api/config`](#/API/get_config_api_config_get)."
)

@router.post("", description=description)
async def register_agent(data: schema.RegisterAgentData) -> schema.AgentInfo:
async def register_agent(
data: schema.RegisterAgentData,
user: User = Depends(authorized_user_with("agents:write")), # noqa: B008
) -> schema.AgentInfo:
agent = data.agent()
await agent_handler.add_agent(data.id, agent)
return schema.AgentInfo(
Expand All @@ -54,5 +63,8 @@ async def register_agent(data: schema.RegisterAgentData) -> schema.AgentInfo:
)

@router.delete("/{agentId}", description=description)
async def unregister_agent(agent_id: Annotated[str, Path(alias="agentId")]) -> None:
async def unregister_agent(
agent_id: Annotated[str, Path(alias="agentId")],
user: User = Depends(authorized_user_with("agents:delete")), # noqa: B008
) -> None:
await agent_handler.remove_agent(agent_id)
12 changes: 6 additions & 6 deletions src/_ravnar/api/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

from fastapi import APIRouter, Body, Depends, Response

from _ravnar import schema
from _ravnar.auth import User
from _ravnar.file_storage import FileHandler, FileInputContent, convert_file_to_input_content


def make_router(*, file_handler: FileHandler, authenticated_user: Callable[..., Any]) -> APIRouter:
def make_router(*, file_handler: FileHandler, authorized_user_with: Callable[..., Any]) -> APIRouter:
router = APIRouter(tags=["Files"])

@router.post("")
async def upload_file(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("files:write")), # noqa: B008
file_input_content: Annotated[FileInputContent, Body()],
) -> FileInputContent:
file, _ = await file_handler.add(file_input_content, user_id=user.id)
Expand All @@ -25,15 +25,15 @@ async def upload_file(
@router.get("/{id}")
async def get_file(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("files:read")), # noqa: B008
id: uuid.UUID,
) -> FileInputContent:
return convert_file_to_input_content(await file_handler.get(id, user_id=user.id))

@router.get("/{id}/content")
async def read_file(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("files:read")), # noqa: B008
id: uuid.UUID,
) -> Response:
mime_type, content = await file_handler.read(id, user_id=user.id)
Expand All @@ -46,7 +46,7 @@ async def read_file(
@router.delete("/{id}")
async def delete_file(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("files:delete")), # noqa: B008
id: uuid.UUID,
) -> None:
await file_handler.delete(id, user_id=user.id)
Expand Down
30 changes: 16 additions & 14 deletions src/_ravnar/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opentelemetry import trace

from _ravnar import schema
from _ravnar.auth import User, assert_permissions
from _ravnar.file_storage import FileHandler, WrappedMetadata
from _ravnar.observability import traced
from _ravnar.utils import as_awaitable
Expand All @@ -35,14 +36,14 @@ def make_router(
database: Database,
file_handler: FileHandler,
agent_handler: AgentHandler,
authenticated_user: Callable[..., Any],
authorized_user_with: Callable[..., Any],
) -> schema.APIRouter:
router = schema.APIRouter(tags=["Threads"], dependencies=[Depends(authenticated_user)])
router = schema.APIRouter(tags=["Threads"])

@router.post("")
async def create_thread(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:write")), # noqa: B008
data: schema.CreateThreadData,
) -> schema.Thread:
agent_handler.assert_available(data.agent_id)
Expand All @@ -54,7 +55,7 @@ async def create_thread(
@router.get("")
async def get_threads(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
pagination: Annotated[schema.Pagination[ThreadsSortBy], Query()],
) -> schema.Page[schema.Thread]:
return schema.Page[schema.Thread].model_validate(
Expand All @@ -64,22 +65,22 @@ async def get_threads(
@router.get("/{threadId}")
async def get_thread(
id: Annotated[str, Path(alias="threadId")],
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
) -> schema.Thread:
return schema.Thread.model_validate(await database.get_thread(user_id=user.id, id=id), from_attributes=True)

@router.get("/{threadId}/messages")
async def get_thread_messages(
thread_id: Annotated[str, Path(alias="threadId")],
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
) -> list[schema.AugmentedMessage]:
_, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=None)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True)

@router.get("/{threadId}/runs")
async def get_runs(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
pagination: Annotated[schema.Pagination[RunsSortBy], Query()],
) -> schema.Page[schema.Run]:
Expand All @@ -91,7 +92,7 @@ async def get_runs(
@router.get("/{threadId}/runs/{runId}")
async def get_run(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
run_id: Annotated[str, Path(alias="runId")],
) -> schema.Run:
Expand All @@ -100,7 +101,7 @@ async def get_run(
@router.get("/{threadId}/runs/{runId}/messages")
async def get_run_messages(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:read")), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
run_id: Annotated[str, Path(alias="runId")],
) -> list[schema.AugmentedMessage]:
Expand All @@ -110,7 +111,7 @@ async def get_run_messages(
@router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"])
async def create_run(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:write")), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
data: schema.CreateRunData,
) -> fastsse.Response:
Expand Down Expand Up @@ -147,9 +148,10 @@ async def callback(event_processor: EventProcessor) -> None:
async def hydrate_files(
messages: list[schema.AugmentedMessage],
*,
user: schema.User,
user: User,
file_handler: FileHandler,
) -> None:
assert_permissions(user, "files:read", "files:write")
for m in messages:
if not isinstance(m, schema.AugmentedUserMessage):
continue
Expand All @@ -173,7 +175,7 @@ async def hydrate_files(
@router.post("/{threadId}/rename")
async def rename_thread(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:write")), # noqa: B008
id: Annotated[str, Path(alias="threadId")],
data: schema.RenameThreadData,
) -> schema.Thread:
Expand All @@ -184,15 +186,15 @@ async def rename_thread(
@router.delete("")
async def delete_threads(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:delete")), # noqa: B008
data: schema.DeleteThreadsData,
) -> None:
await database.delete_threads(user_id=user.id, ids=data.ids)

@router.delete("/{threadId}")
async def delete_thread(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
user: User = Depends(authorized_user_with("threads:delete")), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
) -> None:
await database.delete_threads(user_id=user.id, ids=[thread_id])
Expand Down
Loading
Loading