Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,24 @@ async def load_artifact(
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/metadata",
response_model=list[ArtifactVersion],
response_model_exclude_none=True,
)
async def list_artifact_versions_metadata(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
) -> list[ArtifactVersion]:
return await self.artifact_service.list_artifact_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
response_model_exclude_none=True,
Expand Down Expand Up @@ -1378,6 +1396,31 @@ async def save_artifact(
)
return artifact_version

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}/metadata",
response_model=ArtifactVersion,
response_model_exclude_none=True,
)
async def get_artifact_version_metadata(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version_id: int,
) -> ArtifactVersion:
artifact_version = await self.artifact_service.get_artifact_version(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version_id,
)
if not artifact_version:
raise HTTPException(
status_code=404, detail="Artifact version not found"
)
return artifact_version

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
Expand Down
36 changes: 36 additions & 0 deletions src/google/adk/cli/conformance/adk_web_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import httpx

from ...artifacts.base_artifact_service import ArtifactVersion
from ...events.event import Event
from ...sessions.session import Session
from ..adk_web_server import RunAgentRequest
Expand Down Expand Up @@ -265,3 +266,38 @@ async def run_agent(
yield Event.model_validate(event_data)
else:
logger.debug("Non data line received: %s", line)

async def get_artifact_version_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version: int,
) -> ArtifactVersion:
"""Retrieve metadata for a specific artifact version."""
async with self._get_client() as client:
response = await client.get((
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/{version}/metadata"
))
response.raise_for_status()
return ArtifactVersion.model_validate(response.json())

async def list_artifact_versions_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
) -> list[ArtifactVersion]:
"""List metadata for all versions of an artifact."""
async with self._get_client() as client:
response = await client.get((
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/metadata"
))
response.raise_for_status()
return [ArtifactVersion.model_validate(item) for item in response.json()]
79 changes: 79 additions & 0 deletions tests/unittests/cli/conformance/test_adk_web_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.cli.adk_web_server import RunAgentRequest
from google.adk.cli.conformance.adk_web_server_client import AdkWebServerClient
from google.adk.events.event import Event
Expand Down Expand Up @@ -224,6 +225,84 @@ def mock_stream(*_args, **_kwargs):
assert events[1].invocation_id == "test_invocation_2"


@pytest.mark.asyncio
async def test_get_artifact_version_metadata():
client = AdkWebServerClient()
mock_response = MagicMock()
mock_response.json.return_value = {
"version": 2,
"canonicalUri": (
"artifact://apps/app/users/user/sessions/session/"
"artifacts/report/versions/2"
),
"customMetadata": {"foo": "bar"},
"createTime": 123.4,
"mimeType": "text/plain",
}

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client

metadata = await client.get_artifact_version_metadata(
app_name="app",
user_id="user",
session_id="session",
artifact_name="report",
version=2,
)

assert isinstance(metadata, ArtifactVersion)
assert metadata.version == 2
assert metadata.custom_metadata == {"foo": "bar"}
mock_client.get.assert_called_once_with(
"/apps/app/users/user/sessions/session/artifacts/report/versions/2/metadata"
)
mock_response.raise_for_status.assert_called_once()


@pytest.mark.asyncio
async def test_list_artifact_versions_metadata():
client = AdkWebServerClient()
mock_response = MagicMock()
mock_response.json.return_value = [
{
"version": 0,
"canonicalUri": "artifact://.../versions/0",
"customMetadata": {},
"createTime": 100.0,
},
{
"version": 1,
"canonicalUri": "artifact://.../versions/1",
"customMetadata": {"foo": "bar"},
"createTime": 200.0,
"mimeType": "application/json",
},
]

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client

metadata_list = await client.list_artifact_versions_metadata(
app_name="app",
user_id="user",
session_id="session",
artifact_name="report",
)

assert len(metadata_list) == 2
assert all(isinstance(item, ArtifactVersion) for item in metadata_list)
assert metadata_list[1].custom_metadata == {"foo": "bar"}
mock_client.get.assert_called_once_with(
"/apps/app/users/user/sessions/session/artifacts/report/versions/metadata"
)
mock_response.raise_for_status.assert_called_once()


@pytest.mark.asyncio
async def test_close():
client = AdkWebServerClient()
Expand Down
114 changes: 114 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,48 @@ async def save_artifact(
})
return version

def add_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
artifact: types.Part,
custom_metadata: Optional[dict[str, Any]] = None,
canonical_uri: Optional[str] = None,
mime_type: Optional[str] = None,
) -> int:
"""Synchronous helper for tests to add artifacts."""
key = _artifact_key(app_name, user_id, session_id, filename)
entries = artifacts.setdefault(key, [])
version = len(entries)
artifact_version = ArtifactVersion(
version=version,
canonical_uri=(
canonical_uri
or _canonical_uri(
app_name, user_id, session_id, filename, version
)
),
custom_metadata=custom_metadata or {},
)
if mime_type:
artifact_version.mime_type = mime_type
elif artifact.inline_data is not None:
artifact_version.mime_type = artifact.inline_data.mime_type
elif artifact.text is not None:
artifact_version.mime_type = "text/plain"
elif artifact.file_data is not None:
artifact_version.mime_type = artifact.file_data.mime_type

entries.append({
"version": version,
"artifact": artifact,
"metadata": artifact_version,
})
return version

async def load_artifact(
self, app_name, user_id, session_id, filename, version=None
):
Expand Down Expand Up @@ -318,6 +360,15 @@ async def list_versions(self, app_name, user_id, session_id, filename):
return []
return [entry["version"] for entry in artifacts[key]]

async def list_artifact_versions(
self, app_name, user_id, session_id, filename
):
"""List all artifact versions with metadata."""
key = _artifact_key(app_name, user_id, session_id, filename)
if key not in artifacts:
return []
return [entry["metadata"] for entry in artifacts[key]]

async def delete_artifact(self, app_name, user_id, session_id, filename):
"""Delete an artifact."""
key = _artifact_key(app_name, user_id, session_id, filename)
Expand Down Expand Up @@ -980,6 +1031,69 @@ def test_save_artifact_returns_500_on_unexpected_error(
assert response.json()["detail"] == "unexpected failure"


def test_get_artifact_version_metadata(
test_app, create_test_session, mock_artifact_service
):
"""Test retrieving metadata for a specific artifact version."""
info = create_test_session
mock_artifact_service.add_artifact(
app_name=info["app_name"],
user_id=info["user_id"],
session_id=info["session_id"],
filename="report.txt",
artifact=types.Part(text="hello"),
custom_metadata={"foo": "bar"},
mime_type="text/plain",
)

url = (
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
f"{info['session_id']}/artifacts/report.txt/versions/0/metadata"
)
response = test_app.get(url)

assert response.status_code == 200
data = response.json()
assert data["version"] == 0
assert data["customMetadata"] == {"foo": "bar"}
assert data["mimeType"] == "text/plain"


def test_list_artifact_versions_metadata(
test_app, create_test_session, mock_artifact_service
):
"""Test listing metadata for all versions of an artifact."""
info = create_test_session
mock_artifact_service.add_artifact(
app_name=info["app_name"],
user_id=info["user_id"],
session_id=info["session_id"],
filename="report.txt",
artifact=types.Part(text="v0"),
)
mock_artifact_service.add_artifact(
app_name=info["app_name"],
user_id=info["user_id"],
session_id=info["session_id"],
filename="report.txt",
artifact=types.Part(text="v1"),
custom_metadata={"foo": "bar"},
)

url = (
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
f"{info['session_id']}/artifacts/report.txt/versions/metadata"
)
response = test_app.get(url)

assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
assert data[1]["version"] == 1
assert data[1]["customMetadata"] == {"foo": "bar"}


def test_create_eval_set(test_app, test_session_info):
"""Test creating an eval set."""
url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id"
Expand Down