Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ class CreateEvalSetRequest(common.BaseModel):
eval_set: EvalSet


class ListEvalSetsResponse(common.BaseModel):
eval_set_ids: list[str]


class AdkWebServer:
"""Helper class for setting up and running the ADK web server on FastAPI.

Expand Down Expand Up @@ -512,18 +516,38 @@ async def create_eval_set_legacy(
)

@app.get(
"/apps/{app_name}/eval_sets",
"/apps/{app_name}/eval-sets",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def list_eval_sets(app_name: str) -> list[str]:
async def list_eval_sets(app_name: str) -> ListEvalSetsResponse:
"""Lists all eval sets for the given app."""
eval_sets = []
try:
return self.eval_sets_manager.list_eval_sets(app_name)
eval_sets = self.eval_sets_manager.list_eval_sets(app_name)
except NotFoundError as e:
logger.warning(e)
return []

return ListEvalSetsResponse(eval_set_ids=eval_sets)

@deprecated(
"Please use list_eval_sets instead. This will be removed in future"
" releases."
)
@app.get(
"/apps/{app_name}/eval_sets",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def list_eval_sets_legacy(app_name: str) -> list[str]:
list_eval_sets_response = await list_eval_sets(app_name)
return list_eval_sets_response.eval_set_ids

@app.post(
"/apps/{app_name}/eval-sets/{eval_set_id}/add-session",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
Expand Down Expand Up @@ -583,6 +607,11 @@ async def list_evals_in_eval_set(

return sorted([x.eval_id for x in eval_set_data.eval_cases])

@app.get(
"/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
@app.get(
"/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
response_model_exclude_none=True,
Expand All @@ -606,6 +635,11 @@ async def get_eval(
),
)

@app.put(
"/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
@app.put(
"/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
response_model_exclude_none=True,
Expand Down Expand Up @@ -638,6 +672,10 @@ async def update_eval(
except NotFoundError as nfe:
raise HTTPException(status_code=404, detail=str(nfe)) from nfe

@app.delete(
"/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}",
tags=[TAG_EVALUATION],
)
@app.delete(
"/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
tags=[TAG_EVALUATION],
Expand Down