113 changes: 101 additions & 12 deletions src/google/adk/cli/adk_web_server.py
@@ -173,7 +173,18 @@ class AddSessionToEvalSetRequest(common.BaseModel):


class RunEvalRequest(common.BaseModel):
eval_ids: list[str] # if empty, then all evals in the eval set are run.
eval_ids: list[str] = Field(
deprecated=True,
default_factory=list,
description="This field is deprecated, use eval_case_ids instead.",
)
eval_case_ids: list[str] = Field(
default_factory=list,
description=(
"List of eval case ids to evaluate. if empty, then all eval cases in"
" the eval set are run."
),
)
eval_metrics: list[EvalMetric]
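
A minimal sketch, not part of this diff, of how a caller might populate the updated request model. The EvalMetric constructor fields (metric_name, threshold) and the metric name used here are assumptions for illustration.

# Sketch only: assumes EvalMetric exposes metric_name and threshold fields.
request = RunEvalRequest(
    eval_case_ids=["eval_case_1", "eval_case_2"],  # preferred field
    eval_metrics=[EvalMetric(metric_name="response_match_score", threshold=0.8)],
)

# The deprecated field still validates, so existing callers keep working; the
# handler below falls back to eval_ids only when eval_case_ids is empty.
legacy_request = RunEvalRequest(
    eval_ids=["eval_case_1"],
    eval_metrics=[EvalMetric(metric_name="response_match_score", threshold=0.8)],
)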


@@ -195,6 +206,10 @@ class RunEvalResult(common.BaseModel):
session_id: str


class RunEvalResponse(common.BaseModel):
run_eval_results: list[RunEvalResult]


class GetEventGraphResult(common.BaseModel):
dot_src: str

@@ -207,6 +222,22 @@ class ListEvalSetsResponse(common.BaseModel):
eval_set_ids: list[str]


class EvalResult(EvalSetResult):
"""This class has no field intentionally.

The goal here is to just give a new name to the class to align with the API
endpoint.
"""


class ListEvalResultsResponse(common.BaseModel):
eval_result_ids: list[str]


class ListMetricsInfoResponse(common.BaseModel):
metrics_info: list[MetricInfo]


class AdkWebServer:
"""Helper class for setting up and running the ADK web server on FastAPI.

@@ -690,14 +721,30 @@ async def delete_eval(
except NotFoundError as nfe:
raise HTTPException(status_code=404, detail=str(nfe)) from nfe

@deprecated(
"Please use run_eval instead. This will be removed in future releases."
)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def run_eval(
async def run_eval_legacy(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
run_eval_response = await run_eval(
app_name=app_name, eval_set_id=eval_set_id, req=req
)
return run_eval_response.run_eval_results

@app.post(
"/apps/{app_name}/eval-sets/{eval_set_id}/run",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> RunEvalResponse:
"""Runs an eval given the details in the eval request."""
# Create a mapping from eval set file to all the evals that need to be
# run.
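
For context, a hypothetical client-side call against the new route. The server address, app and eval set names, and the camelCase request keys are assumptions; the camelCase response key mirrors the style asserted in the updated test further down.

import requests

BASE_URL = "http://localhost:8000"  # assumed local dev server address

response = requests.post(
    f"{BASE_URL}/apps/my_app/eval-sets/my_eval_set/run",
    json={
        # An empty list means every eval case in the set is run.
        "evalCaseIds": ["eval_case_1"],
        "evalMetrics": [{"metricName": "response_match_score", "threshold": 0.8}],
    },
)
response.raise_for_status()
for result in response.json().get("runEvalResults", []):
    print(result)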
Expand Down Expand Up @@ -727,7 +774,7 @@ async def run_eval(
inference_request = InferenceRequest(
app_name=app_name,
eval_set_id=eval_set.eval_set_id,
eval_case_ids=req.eval_ids,
eval_case_ids=req.eval_case_ids or req.eval_ids,
inference_config=InferenceConfig(),
)
inference_results = await _collect_inferences(
Expand Down Expand Up @@ -760,18 +807,41 @@ async def run_eval(
)
)

return run_eval_results
return RunEvalResponse(run_eval_results=run_eval_results)

@app.get(
"/apps/{app_name}/eval_results/{eval_result_id}",
"/apps/{app_name}/eval-results/{eval_result_id}",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def get_eval_result(
app_name: str,
eval_result_id: str,
) -> EvalSetResult:
) -> EvalResult:
"""Gets the eval result for the given eval id."""
try:
eval_set_result = self.eval_set_results_manager.get_eval_set_result(
app_name, eval_result_id
)
return EvalResult(**eval_set_result.model_dump())
except ValueError as ve:
raise HTTPException(status_code=404, detail=str(ve)) from ve
except ValidationError as ve:
raise HTTPException(status_code=500, detail=str(ve)) from ve

@deprecated(
"Please use get_eval_result instead. This will be removed in future"
" releases."
)
@app.get(
"/apps/{app_name}/eval_results/{eval_result_id}",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def get_eval_result_legacy(
app_name: str,
eval_result_id: str,
) -> EvalSetResult:
try:
return self.eval_set_results_manager.get_eval_set_result(
app_name, eval_result_id
@@ -782,27 +852,46 @@ async def get_eval_result(
raise HTTPException(status_code=500, detail=str(ve)) from ve

@app.get(
"/apps/{app_name}/eval_results",
"/apps/{app_name}/eval-results",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def list_eval_results(app_name: str) -> list[str]:
async def list_eval_results(app_name: str) -> ListEvalResultsResponse:
"""Lists all eval results for the given app."""
return self.eval_set_results_manager.list_eval_set_results(app_name)
eval_result_ids = self.eval_set_results_manager.list_eval_set_results(
app_name
)
return ListEvalResultsResponse(eval_result_ids=eval_result_ids)

@deprecated(
"Please use list_eval_results instead. This will be removed in future"
" releases."
)
@app.get(
"/apps/{app_name}/eval_results",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def list_eval_results_legacy(app_name: str) -> list[str]:
list_eval_results_response = await list_eval_results(app_name)
return list_eval_results_response.eval_result_ids

@app.get(
"/apps/{app_name}/eval_metrics",
"/apps/{app_name}/metrics-info",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def list_eval_metrics(app_name: str) -> list[MetricInfo]:
async def list_metrics_info(app_name: str) -> ListMetricsInfoResponse:
"""Lists all eval metrics for the given app."""
try:
from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY

# Right now we ignore the app_name as eval metrics are not tied to the
# app_name, but they could be in the future.
return DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
metrics_info = (
DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
)
return ListMetricsInfoResponse(metrics_info=metrics_info)
except ModuleNotFoundError as e:
logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e)
raise HTTPException(
14 changes: 8 additions & 6 deletions tests/unittests/cli/test_fast_api.py
@@ -845,18 +845,20 @@ def verify_eval_case_result(actual_eval_case_result):
assert data == [f"{info['app_name']}_test_eval_set_id_eval_result"]


def test_list_eval_metrics(test_app):
"""Test listing eval metrics."""
url = "/apps/test_app/eval_metrics"
def test_list_metrics_info(test_app):
"""Test listing metrics info."""
url = "/apps/test_app/metrics-info"
response = test_app.get(url)

# Verify the response
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
metrics_info_key = "metricsInfo"
assert metrics_info_key in data
assert isinstance(data[metrics_info_key], list)
# Add more assertions based on the expected metrics
assert len(data) > 0
for metric in data:
assert len(data[metrics_info_key]) > 0
for metric in data[metrics_info_key]:
assert "metricName" in metric
assert "description" in metric
assert "metricValueInfo" in metric