123 changes: 123 additions & 0 deletions aieng-eval-agents/aieng/agent_evals/langfuse.py
@@ -11,11 +11,15 @@
from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.configs import Configs
from aieng.agent_evals.progress import track_with_progress
from langfuse import Langfuse
from langfuse.api.resources.commons.types.observations_view import ObservationsView
from langfuse.api.resources.observations.types.observations_views import ObservationsViews
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
from tenacity import retry, stop_after_attempt, wait_exponential


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
@@ -370,3 +374,122 @@ def _normalize_dataset_record(item: Any, record_number: int) -> dict[str, Any]:
"expected_output": item["expected_output"],
"metadata": metadata,
}


def report_usage_scores(
trace_id: str,
token_threshold: int = 0,
latency_threshold: int = 0,
cost_threshold: float = 0,
) -> None:
"""Report usage scores to Langfuse for a given trace ID.

WARNING: Due to the nature of the Langfuse API, this function may hang
while trying to fetch the observations.

Parameters
----------
trace_id: str
The ID of the trace to report the usage scores for.
token_threshold: int
The token threshold to report the score for.
if the token count is greater than the threshold, the score
will be reported as 0.
Optional, default to 0 (no reporting).
latency_threshold: int
The latency threshold in seconds to report the score for.
if the latency is greater than the threshold, the score
will be reported as 0.
Optional, default to 0 (no reporting).
cost_threshold: float
The cost threshold to report the score for.
if the cost is greater than the threshold, the score
will be reported as 0.
Optional, default to 0 (no reporting).
"""
langfuse_client = AsyncClientManager.get_instance().langfuse_client
observations = _get_observations_with_retry(trace_id, langfuse_client)

if token_threshold > 0:
total_tokens = sum(_obs_attr(observation, "totalTokens") for observation in observations.data)
if total_tokens <= token_threshold:
score = 1
comment = "Token count is less than or equal to the threshold."
else:
score = 0
comment = "Token count is greater than the threshold."

logger.info("Reporting score for token count")
langfuse_client.create_score(
name="Token Count",
value=score,
trace_id=trace_id,
comment=comment,
metadata={
"total_tokens": total_tokens,
"token_threshold": token_threshold,
},
)

if latency_threshold > 0:
total_latency = sum(_obs_attr(observation, "latency") for observation in observations.data)
if total_latency <= latency_threshold:
score = 1
comment = "Latency is less than or equal to the threshold."
else:
score = 0
comment = "Latency is greater than the threshold."

logger.info("Reporting score for latency")
langfuse_client.create_score(
name="Latency",
value=score,
trace_id=trace_id,
comment=comment,
metadata={
"total_latency": total_latency,
"latency_threshold": latency_threshold,
},
)

if cost_threshold > 0:
total_cost = sum(_obs_attr(observation, "calculated_total_cost") for observation in observations.data)
if total_cost <= cost_threshold:
score = 1
comment = "Cost is less than or equal to the threshold."
else:
score = 0
comment = "Cost is greater than the threshold."

logger.info("Reporting score for cost")
langfuse_client.create_score(
name="Cost",
value=score,
trace_id=trace_id,
comment=comment,
metadata={
"total_cost": total_cost,
"cost_threshold": cost_threshold,
},
)

langfuse_client.flush()


def _obs_attr(observation: ObservationsView, attribute: str) -> Any:
"""Get the value of an attribute from an observation."""
attribute_value = getattr(observation, attribute)
if attribute_value == 0:
logger.error(f"Observation attribute value for {attribute} is 0")
return 0
if attribute_value is None:
logger.error(f"Observation attribute value for {attribute} is None")
return 0
return attribute_value


@retry(stop=stop_after_attempt(10), wait=wait_exponential(multiplier=1, min=5, max=15))
def _get_observations_with_retry(trace_id: str, langfuse_client: Langfuse) -> ObservationsViews:
"""Get the observations for a given trace ID with retry/backoff."""
logger.info(f"Getting observations for trace {trace_id}...")
return langfuse_client.api.observations.get_many(trace_id=trace_id, type="GENERATION")
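
For reference, a minimal sketch of how this helper might be invoked once a trace has completed; the trace ID and threshold values below are illustrative placeholders, not values taken from this PR.

```python
from aieng.agent_evals.langfuse import report_usage_scores

# Hypothetical trace ID and budgets, purely for illustration.
report_usage_scores(
    trace_id="1234abcd",      # a finished Langfuse trace
    token_threshold=10_000,   # score 0 if the trace used more than 10k tokens
    latency_threshold=60,     # score 0 if total generation latency exceeds 60 s
    cost_threshold=0.50,      # score 0 if the calculated cost exceeds the budget
)
```

Any threshold left at its default of 0 is skipped entirely, so callers can opt into only the checks they care about.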
@@ -3,7 +3,7 @@
Example
-------
>>> from aieng.agent_evals.report_generation.evaluation import evaluate
>>> from aieng.agent_evals.report_generation.evaluation.offline import evaluate
>>> evaluate(
>>> dataset_name="OnlineRetailReportEval",
>>> reports_output_path=Path("reports/"),
@@ -0,0 +1,78 @@
"""Functions to report online evaluation of the report generation agent to Langfuse."""

import logging

from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.report_generation.agent import EventParser, EventType
from google.adk.events.event import Event


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def report_final_response_score(event: Event, string_match: str = "") -> None:
"""Report a score to Langfuse if the event is a final response.

The score will be reported as 1 if the final response is valid
and contains the string match. Otherwise, the score will be reported as 0.

This must be called within the context of an active Langfuse trace.

Parameters
----------
event : Event
The event to check.
string_match : str
The string to match in the final response.
Optional, defaults to an empty string.

Raises
------
ValueError
If the event is not a final response.
"""
if not event.is_final_response():
raise ValueError("Event is not a final response")

langfuse_client = AsyncClientManager.get_instance().langfuse_client
trace_id = langfuse_client.get_current_trace_id()

if trace_id is None:
raise ValueError("Langfuse trace ID is None.")

parsed_events = EventParser.parse(event)
for parsed_event in parsed_events:
if parsed_event.type == EventType.FINAL_RESPONSE:
if string_match in parsed_event.text:
score = 1
comment = "Final response contains the string match."
else:
score = 0
comment = "Final response does not contains the string match."

logger.info("Reporting score for valid final response")
langfuse_client.create_score(
name="Valid Final Response",
value=score,
trace_id=trace_id,
comment=comment,
metadata={
"final_response": parsed_event.text,
"string_match": string_match,
},
)
langfuse_client.flush()
return

logger.info("Reporting score for invalid final response")
langfuse_client.create_score(
name="Valid Final Response",
value=0,
trace_id=trace_id,
comment="Final response not found in the event",
metadata={
"string_match": string_match,
},
)
langfuse_client.flush()
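
As a usage sketch (the runner, session, and message objects here are placeholders standing in for the ADK plumbing shown in demo.py), the scorer is meant to be called on the final event while a Langfuse trace is active:

```python
from aieng.agent_evals.report_generation.evaluation.online import report_final_response_score


async def handle_events(runner, user_id: str, session_id: str, message) -> None:
    """Sketch: stream agent events and score the final response."""
    async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=message):
        if event.is_final_response():
            # Score 1 if the final answer links to a generated report file,
            # 0 otherwise; this is the same marker demo.py checks for.
            report_final_response_score(event, string_match="](gradio_api/file=")
```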
24 changes: 24 additions & 0 deletions implementations/report_generation/demo.py
@@ -8,13 +8,16 @@

import asyncio
import logging
import threading
from functools import partial
from typing import Any, AsyncGenerator

import click
import gradio as gr
from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.langfuse import report_usage_scores
from aieng.agent_evals.report_generation.agent import get_report_generation_agent
from aieng.agent_evals.report_generation.evaluation.online import report_final_response_score
from aieng.agent_evals.report_generation.prompts import MAIN_AGENT_INSTRUCTIONS
from dotenv import load_dotenv
from google.adk.runners import Runner
@@ -65,6 +68,9 @@ async def agent_session_handler(
langfuse_project_name=get_langfuse_project_name() if enable_trace else None,
)

# Get the Langfuse client for online reporting
langfuse_client = AsyncClientManager.get_instance().langfuse_client

# Construct an in-memory session for the agent to maintain
# conversation history across multiple turns of a chat
# This makes it possible to ask follow-up questions that refer to
@@ -92,6 +98,24 @@ if len(turn_messages) > 0:
if len(turn_messages) > 0:
yield turn_messages

if event.is_final_response():
# Report the final response evaluation to Langfuse
report_final_response_score(event, string_match="](gradio_api/file=")

# Run usage scoring in a thread so it doesn't block the UI
thread = threading.Thread(
target=report_usage_scores,
kwargs={
"trace_id": langfuse_client.get_current_trace_id(),
"token_threshold": 10000,
"latency_threshold": 60,
},
daemon=True,
)
thread.start()

langfuse_client.flush()


@click.command()
@click.option("--enable-trace", required=False, default=True, help="Whether to enable tracing with Langfuse.")
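Because the usage scorer can block while Langfuse catches up, the demo hands it off rather than awaiting it. A stripped-down sketch of that fire-and-forget pattern (the threshold values mirror the demo; the function name is made up):

```python
import threading

from aieng.agent_evals.langfuse import report_usage_scores


def score_trace_in_background(trace_id: str) -> None:
    """Kick off usage scoring without blocking the chat handler."""
    threading.Thread(
        target=report_usage_scores,
        kwargs={"trace_id": trace_id, "token_threshold": 10_000, "latency_threshold": 60},
        daemon=True,  # do not keep the process alive just for scoring
    ).start()
```

Inside an async handler, `asyncio.to_thread(report_usage_scores, ...)` would be an equivalent choice that keeps the event loop free while still allowing the caller to await completion.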
2 changes: 1 addition & 1 deletion implementations/report_generation/evaluate.py
@@ -11,7 +11,7 @@
import asyncio

import click
from aieng.agent_evals.report_generation.evaluation import evaluate
from aieng.agent_evals.report_generation.evaluation.offline import evaluate
from dotenv import load_dotenv

from implementations.report_generation.data.langfuse_upload import DEFAULT_EVALUATION_DATASET_NAME
12 changes: 9 additions & 3 deletions implementations/report_generation/gradio_utils.py
@@ -1,5 +1,6 @@
"""Utility functions for the report generation agent."""

import json
import logging

from aieng.agent_evals.report_generation.agent import EventParser, EventType
@@ -40,11 +41,15 @@ def agent_event_to_gradio_messages(event: Event) -> list[ChatMessage]:
)
)
elif parsed_event.type == EventType.TOOL_CALL:
formatted_arguments = json.dumps(parsed_event.arguments, indent=2).replace("\\n", "\n")
output.append(
ChatMessage(
role="assistant",
content=f"```\n{parsed_event.arguments}\n```",
metadata={"title": f"🛠️ Used tool `{parsed_event.text}`"},
content=f"```\n{formatted_arguments}\n```",
metadata={
"title": f"🛠️ Used tool `{parsed_event.text}`",
"status": "done", # This makes it collapsed by default
},
)
)
elif parsed_event.type == EventType.THOUGHT:
@@ -56,10 +61,11 @@ def agent_event_to_gradio_messages(event: Event) -> list[ChatMessage]:
)
)
elif parsed_event.type == EventType.TOOL_RESPONSE:
formatted_arguments = json.dumps(parsed_event.arguments, indent=2).replace("\\n", "\n")
output.append(
ChatMessage(
role="assistant",
content=f"```\n{parsed_event.arguments}\n```",
content=f"```\n{formatted_arguments}\n```",
metadata={
"title": f"📝 Tool call output: `{parsed_event.text}`",
"status": "done", # This makes it collapsed by default
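The gradio_utils change swaps the raw argument dict for pretty-printed JSON in the chat transcript. A small self-contained illustration (the argument dict is made up):

```python
import json

# Made-up tool arguments; in the app these come from the parsed agent event.
arguments = {"query": "SELECT *\nFROM orders", "limit": 10}

# json.dumps escapes the embedded newline to "\n"; the final replace turns it
# back into a real line break so multi-line strings render readably in Gradio.
formatted_arguments = json.dumps(arguments, indent=2).replace("\\n", "\n")
print(formatted_arguments)
```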