Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ minversion = "9.0"
addopts = ["-ra", "--tb=short"]
testpaths = ["tests"]
filterwarnings = [
"error",
# "error",
"ignore::ResourceWarning",
# FIXME: isolate the root cause for this warning
"ignore:The '.*?' attribute with value '.*?' was provided:pydantic.warnings.UnsupportedFieldAttributeWarning",
Expand Down
47 changes: 36 additions & 11 deletions scripts/api_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def main():
user = assert_successful_response(client.get("/api/user")).json()
print(json.dumps(user, indent=2))

header("info")
config = assert_successful_response(client.get("/api/config")).json()
print(json.dumps(config, indent=2))
header("agents")
agent_infos = assert_successful_response(client.get("/api/agents")).json()
print(json.dumps(agent_infos, indent=2))

agent_ids = [a["id"] for a in config["agents"]]
agent_ids = [a["id"] for a in agent_infos]

header("new thread")
thread = assert_successful_response(
Expand All @@ -57,6 +57,8 @@ def main():

header("new run in same thread with frontend tool call")

root_run_id = str(uuid.uuid4())

class CheerInput(pydantic.BaseModel):
name: str = pydantic.Field(description="Name of the user")

Expand All @@ -66,8 +68,9 @@ class CheerInput(pydantic.BaseModel):
with httpx_sse.connect_sse(
client,
"POST",
f"/api/threads/{thread['id']}/run",
f"/api/threads/{thread['id']}/runs",
json={
"id": root_run_id,
"messages": [
{
"id": str(uuid.uuid4()),
Expand Down Expand Up @@ -103,14 +106,11 @@ class CheerInput(pydantic.BaseModel):
)
print(tool_call.model_dump_json(indent=2))

thread = assert_successful_response(client.get(f"/api/threads/{thread['id']}")).json()
print(json.dumps(thread, indent=2))

header("new run in same thread with frontend tool call result")
with httpx_sse.connect_sse(
client,
"POST",
f"/api/threads/{thread['id']}/run",
f"/api/threads/{thread['id']}/runs",
json={
"messages": [
{
Expand All @@ -130,10 +130,35 @@ class CheerInput(pydantic.BaseModel):
thread = assert_successful_response(client.get(f"/api/threads/{thread['id']}")).json()
print(json.dumps(thread, indent=2))

header("list threads")
thread = assert_successful_response(client.get("/api/threads", params={"sortBy": "createdAt"})).json()
header("new run in same thread branching from root run with different frontend tool call result")
with httpx_sse.connect_sse(
client,
"POST",
f"/api/threads/{thread['id']}/runs",
json={
"parentRunId": root_run_id,
"messages": [
{
"id": str(uuid.uuid4()),
"content": json.dumps({"successful": False}),
"role": "tool",
"toolCallId": tool_call_id,
},
],
},
) as event_source:
assert_successful_response(event_source.response)
for sse in event_source.iter_sse():
event = sse.json()
print(json.dumps(event, indent=2))

thread = assert_successful_response(client.get(f"/api/threads/{thread['id']}")).json()
print(json.dumps(thread, indent=2))

header("list messages")
messages = assert_successful_response(client.get(f"/api/threads/{thread['id']}/messages")).json()
print(json.dumps(messages, indent=2))


if __name__ == "__main__":
main()
78 changes: 54 additions & 24 deletions src/_ravnar/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import base64
import uuid
from collections.abc import Callable
from typing import TYPE_CHECKING, Annotated, Any

import ag_ui.core
import ag_ui.encoder
import fastsse
import pydantic
from fastapi import Depends, HTTPException, Path, Query, status
Expand All @@ -15,7 +13,7 @@
from _ravnar import schema
from _ravnar.file_storage import FileHandler, WrappedMetadata
from _ravnar.observability import traced
from _ravnar.utils import as_awaitable, now
from _ravnar.utils import as_awaitable

tracer = trace.get_tracer(__name__)

Expand All @@ -26,8 +24,10 @@
from . import AgentHandler

ThreadsSortBy = str
RunsSortBy = str
else:
ThreadsSortBy = schema.create_str_literal("created_at", "updated_at", default="created_at")
ThreadsSortBy = schema.create_str_literal("created_at", default="created_at")
RunsSortBy = schema.create_str_literal("created_at", default="created_at")


def make_router(
Expand Down Expand Up @@ -70,46 +70,76 @@ async def get_thread(

@router.get("/{threadId}/messages")
async def get_thread_messages(
id: Annotated[str, Path(alias="threadId")],
thread_id: Annotated[str, Path(alias="threadId")],
user: schema.User = Depends(authenticated_user), # noqa: B008
) -> list[schema.AugmentedMessage]:
thread = await database.get_thread(user_id=user.id, id=id, with_messages=True)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(
thread.messages, from_attributes=True
_, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=None)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True)

@router.get("/{threadId}/runs")
async def get_runs(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
pagination: Annotated[schema.Pagination[RunsSortBy], Query()],
) -> schema.Page[schema.Run]:
return schema.Page[schema.Run].model_validate(
await database.get_runs(user_id=user.id, thread_id=thread_id, pagination=pagination),
from_attributes=True,
)

@router.sse("/{threadId}/run", methods=["POST"], response_model=schema.Event, tags=["Runs"])
@router.get("/{threadId}/runs/{runId}")
async def get_run(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
run_id: Annotated[str, Path(alias="runId")],
) -> schema.Run:
return schema.Run.model_validate(await database.get_run(id=run_id, user_id=user.id), from_attributes=True)

@router.get("/{threadId}/runs/{runId}/messages")
async def get_run_messages(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
run_id: Annotated[str, Path(alias="runId")],
) -> list[schema.AugmentedMessage]:
_, _, messages = await database.get_thread_history(user_id=user.id, thread_id=thread_id, run_id=run_id)
return pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(messages, from_attributes=True)

@router.sse("/{threadId}/runs", methods=["POST"], response_model=schema.Event, tags=["Runs"])
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only intended BC break:

  • old: POST {threadId}/run
  • new: POST {threadId}/runs

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the impetus for the change? IMHO a POST endpoint should be a verb, not a noun.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Runs are now a proper object. So this becomes a regular CRUD endpoint alongside

  • GET /api/threads/{threadId}/runs
  • GET /api/threads/{threadId}/runs/{runId}

async def create_run(
*,
user: schema.User = Depends(authenticated_user), # noqa: B008
thread_id: Annotated[str, Path(alias="threadId")],
data: schema.CreateRunData,
) -> fastsse.Response:
thread = await database.get_thread(user_id=user.id, id=thread_id, with_messages=True)

messages = pydantic.TypeAdapter(list[schema.AugmentedMessage]).validate_python(
thread.messages, from_attributes=True
thread, parent_run, parent_messages = await database.get_thread_history(
user_id=user.id, thread_id=thread_id, run_id=data.parent_run_id
)
messages.extend(data.messages)

await hydrate_files(messages, user=user, file_handler=file_handler)
augmented_messages_ta = pydantic.TypeAdapter(list[schema.AugmentedMessage])
augmented_messages = augmented_messages_ta.validate_python(parent_messages, from_attributes=True)
augmented_messages.extend(data.messages)

await hydrate_files(augmented_messages, user=user, file_handler=file_handler)

run_agent_input = ag_ui.core.RunAgentInput(
thread_id=thread.id,
run_id=str(uuid.uuid4()),
parent_run_id=None,
state=thread.state,
messages=[pydantic.TypeAdapter(ag_ui.core.Message).validate_python(m.model_dump()) for m in messages],
thread_id=thread_id,
run_id=data.id,
parent_run_id=parent_run.id if parent_run is not None else None,
state=parent_run.state if parent_run is not None else None,
messages=pydantic.TypeAdapter(list[ag_ui.core.Message]).validate_python(
augmented_messages_ta.dump_python(augmented_messages)
),
tools=data.tools,
context=data.context,
forwarded_props=data.forwarded_props,
)

async def callback(event_processor: EventProcessor) -> None:
with tracer.start_as_current_span("persist_run"):
thread.state, thread.messages = event_processor.extract()
thread.updated_at = now()
await database.update_thread(thread)
run = event_processor.extract(include_input_message_ids={m.id for m in data.messages})
await database.create_run(run)

return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback)

Expand Down
8 changes: 1 addition & 7 deletions src/_ravnar/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,7 @@ async def run(
) -> fastsse.Response:
agent = self._get_agent(agent_id)

event_processor = EventProcessor(
thread_id=run_agent_input.thread_id,
run_id=run_agent_input.run_id,
parent_run_id=run_agent_input.parent_run_id,
state=run_agent_input.state,
messages=run_agent_input.messages,
)
event_processor = EventProcessor(run_agent_input=run_agent_input)

span = tracer.start_span("AgentHandler.run")
span.set_attribute("agent_id", agent_id)
Expand Down
Loading
Loading