Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ AML_DB__DRIVER="sqlite"
AML_DB__DATABASE="implementations/aml_investigation/data/aml_transactions.db"
AML_DB__QUERY__MODE="ro"

# Report Generation Database Configuration
REPORT_GENERATION_DB__DRIVER="sqlite"
REPORT_GENERATION_DB__DATABASE="implementations/report_generation/data/OnlineRetail.db"
REPORT_GENERATION_DB__QUERY__MODE="ro"

# Report Generation (all optional, defaults are in implementations/report_generation/env_vars.py)
REPORT_GENERATION_OUTPUT_PATH="..."
REPORT_GENERATION_DB_PATH="..."
REPORT_GENERATION_LANGFUSE_PROJECT_NAME="..."
103 changes: 44 additions & 59 deletions aieng-eval-agents/aieng/agent_evals/async_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
"""

import logging
import sqlite3
from pathlib import Path
from typing import Any

from aieng.agent_evals.configs import Configs
from aieng.agent_evals.tools import ReadOnlySqlDatabase
from langfuse import Langfuse
from openai import AsyncOpenAI

Expand All @@ -18,46 +16,6 @@
logger = logging.getLogger(__name__)


class SQLiteConnection:
"""SQLite connection."""

def __init__(self, db_path: Path) -> None:
"""Initialize the SQLite connection.

Parameters
----------
db_path : Path
The path to the SQLite database.
"""
self.db_path = db_path
self.connection = sqlite3.connect(db_path)

def execute(self, query: str) -> list[Any] | str:
"""Execute a SQLite query.

Parameters
----------
query : str
The SQLite query to execute.

Returns
-------
list[Any] | str
The result of the query. Will return the result of
`execute(query).fetchall()`.
Returns a string with an error message if the query fails.
"""
try:
return self.connection.execute(query).fetchall()
except Exception as e:
logger.exception(f"Error executing query: {e}")
return [str(e)]

def close(self) -> None:
"""Close the SQLite connection."""
self.connection.close()


class AsyncClientManager:
"""Manages async client lifecycle with lazy initialization and cleanup.

Expand Down Expand Up @@ -105,7 +63,8 @@ def __init__(self, configs: Configs | None = None) -> None:
"""
self._configs: Configs | None = configs
self._openai_client: AsyncOpenAI | None = None
self._sqlite_connection: SQLiteConnection | None = None
self._aml_db: ReadOnlySqlDatabase | None = None
self._report_generation_db: ReadOnlySqlDatabase | None = None
self._langfuse_client: Langfuse | None = None
self._otel_instrumented: bool = False
self._initialized: bool = False
Expand Down Expand Up @@ -139,23 +98,45 @@ def openai_client(self) -> AsyncOpenAI:
self._initialized = True
return self._openai_client

def sqlite_connection(self, db_path: Path) -> SQLiteConnection:
"""Get or create SQLite session.
def report_generation_db(self, agent_name: str = "ReportGenerationAgent") -> ReadOnlySqlDatabase:
"""Get or create Report Generation database connection.

Parameters
----------
db_path : Path
The path to the SQLite database.
Returns
-------
ReadOnlySqlDatabase
The Report Generation database connection instance.
"""
if self._report_generation_db is None:
if self.configs.report_generation_db is None:
raise ValueError("Report Generation database configuration is missing.")

self._report_generation_db = ReadOnlySqlDatabase(
connection_uri=self.configs.report_generation_db.build_uri(),
agent_name=agent_name,
)
self._initialized = True

return self._report_generation_db

def aml_db(self, agent_name: str = "FraudInvestigationAnalyst") -> ReadOnlySqlDatabase:
"""Get or create AML database connection.

Returns
-------
SQLiteConnection
The SQLite connection instance.
ReadOnlySqlDatabase
The AML database connection instance.
"""
if self._sqlite_connection is None or self._sqlite_connection.db_path != db_path:
self._sqlite_connection = SQLiteConnection(db_path)
if self._aml_db is None:
if self.configs.aml_db is None:
raise ValueError("AML database configuration is missing.")

self._aml_db = ReadOnlySqlDatabase(
connection_uri=self.configs.aml_db.build_uri(),
agent_name=agent_name,
)
self._initialized = True
return self._sqlite_connection

return self._aml_db

@property
def langfuse_client(self) -> Langfuse:
Expand Down Expand Up @@ -202,16 +183,20 @@ def otel_instrumented(self, value: bool) -> None:
async def close(self) -> None:
"""Close all initialized async clients.

This method closes the OpenAI client, SQLite connection, and Langfuse
This method closes the OpenAI client, database connections, and Langfuse
client if they have been initialized.
"""
if self._openai_client is not None:
await self._openai_client.close()
self._openai_client = None

if self._sqlite_connection is not None:
self._sqlite_connection.close()
self._sqlite_connection = None
if self._aml_db is not None:
self._aml_db.close()
self._aml_db = None

if self._report_generation_db is not None:
self._report_generation_db.close()
self._report_generation_db = None

if self._langfuse_client is not None:
self._langfuse_client.flush()
Expand Down
5 changes: 5 additions & 0 deletions aieng-eval-agents/aieng/agent_evals/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class Configs(BaseSettings):
description="Anti-Money Laundering database configuration. Used by the Fraud Investigation Agent.",
)

report_generation_db: DatabaseConfig | None = Field(
default=None,
description="Database configuration for the Report Generation Agent.",
)

# === Core LLM Settings ===
openai_base_url: str = Field(
default="https://generativelanguage.googleapis.com/v1beta/openai/",
Expand Down
8 changes: 7 additions & 1 deletion aieng-eval-agents/aieng/agent_evals/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ def set_up_langfuse_otlp_env_vars():
"""
configs = Configs()

langfuse_key = f"{configs.langfuse_public_key}:{configs.langfuse_secret_key}".encode()
if configs.langfuse_secret_key:
langfuse_auth_key = configs.langfuse_secret_key.get_secret_value()
else:
logger.error("Langfuse secret key is not set. Monitoring may not be enabled.")
langfuse_auth_key = ""

langfuse_key = f"{configs.langfuse_public_key}:{langfuse_auth_key}".encode()
langfuse_auth = base64.b64encode(langfuse_key).decode()

os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = configs.langfuse_host + "/api/public/otel"
Expand Down
10 changes: 5 additions & 5 deletions aieng-eval-agents/aieng/agent_evals/report_generation/agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
Definitions for the report generation agent.

The database connection to the report generation database is obtained
from the environment variable `REPORT_GENERATION_DB__DATABASE`.

Example
-------
>>> from aieng.agent_evals.report_generation.agent import get_report_generation_agent
>>> from aieng.agent_evals.report_generation.prompts import MAIN_AGENT_INSTRUCTIONS
>>> agent = get_report_generation_agent(
>>> instructions=MAIN_AGENT_INSTRUCTIONS,
>>> sqlite_db_path=Path("data/OnlineRetail.db"),
>>> reports_output_path=Path("reports/"),
>>> langfuse_project_name="Report Generation",
>>> )
Expand All @@ -32,7 +34,6 @@

def get_report_generation_agent(
instructions: str,
sqlite_db_path: Path,
reports_output_path: Path,
langfuse_project_name: str | None,
) -> Agent:
Expand All @@ -43,8 +44,6 @@ def get_report_generation_agent(
----------
instructions : str
The instructions for the agent.
sqlite_db_path : Path
The path to the SQLite database.
reports_output_path : Path
The path to the reports output directory.
langfuse_project_name : str | None
Expand All @@ -69,7 +68,8 @@ def get_report_generation_agent(
model=client_manager.configs.default_worker_model,
instruction=instructions,
tools=[
client_manager.sqlite_connection(sqlite_db_path).execute,
client_manager.report_generation_db().execute,
client_manager.report_generation_db().get_schema_info,
report_file_writer.write_xlsx,
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
>>> from aieng.agent_evals.report_generation.evaluation import evaluate
>>> evaluate(
>>> dataset_name="OnlineRetailReportEval",
>>> sqlite_db_path=Path("data/OnlineRetail.db"),
>>> reports_output_path=Path("reports/"),
>>> langfuse_project_name="Report Generation",
>>> )
Expand Down Expand Up @@ -60,18 +59,18 @@ class EvaluatorResponse(BaseModel):

async def evaluate(
dataset_name: str,
sqlite_db_path: Path,
reports_output_path: Path,
langfuse_project_name: str,
) -> None:
"""Evaluate the report generation agent against a Langfuse dataset.

The database connection to the report generation database is obtained
from the environment variable `REPORT_GENERATION_DB__DATABASE`.

Parameters
----------
dataset_name : str
Name of the Langfuse dataset to evaluate against.
sqlite_db_path : Path
The path to the SQLite database.
reports_output_path : Path
The path to the reports output directory.
langfuse_project_name : str
Expand All @@ -88,7 +87,6 @@ async def evaluate(
# We need this task so we can pass parameters to the agent, since
# the agent has to be instantiated inside the task function
report_generation_task = ReportGenerationTask(
sqlite_db_path=sqlite_db_path,
reports_output_path=reports_output_path,
langfuse_project_name=langfuse_project_name,
)
Expand Down Expand Up @@ -118,22 +116,18 @@ class ReportGenerationTask:

def __init__(
self,
sqlite_db_path: Path,
reports_output_path: Path,
langfuse_project_name: str,
):
"""Initialize the task for an report generation agent evaluation.

Parameters
----------
sqlite_db_path : Path
The path to the SQLite database.
reports_output_path : Path
The path to the reports output directory.
langfuse_project_name : str
The name of the Langfuse project to use for tracing.
"""
self.sqlite_db_path = sqlite_db_path
self.reports_output_path = reports_output_path
self.langfuse_project_name = langfuse_project_name

Expand All @@ -154,7 +148,6 @@ async def run(self, *, item: LocalExperimentItem | DatasetItemClient, **kwargs:
# Run the report generation agent
report_generation_agent = get_report_generation_agent(
instructions=MAIN_AGENT_INSTRUCTIONS,
sqlite_db_path=self.sqlite_db_path,
reports_output_path=self.reports_output_path,
langfuse_project_name=self.langfuse_project_name,
)
Expand Down
2 changes: 0 additions & 2 deletions aieng-eval-agents/aieng/agent_evals/tools/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ def execute(self, query: str) -> str:
------
PermissionError
If the query attempts to perform a write operation.
Exception
For any database execution errors.

Notes
-----
Expand Down
15 changes: 5 additions & 10 deletions implementations/aml_investigation/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,14 @@
def _get_db() -> ReadOnlySqlDatabase:
"""Lazily construct the read-only database tool from environment configuration."""
client_manager = AsyncClientManager().get_instance()
if client_manager.configs.aml_db is None:
raise ValueError("AML database configuration is missing.")

return ReadOnlySqlDatabase(
connection_uri=client_manager.configs.aml_db.build_uri(),
agent_name="FraudInvestigationAnalyst",
)
return client_manager.aml_db()


def _try_close_db() -> None:
async def _try_close_db() -> None:
"""Close the lazily initialized database tool if it was created."""
if _get_db.cache_info().currsize:
_get_db().close()
client_manager = AsyncClientManager().get_instance()
await client_manager.close()
_get_db.cache_clear()


Expand Down Expand Up @@ -294,7 +289,7 @@ async def _main() -> None:
logger.info(" TP=%d FP=%d", tp, fp)
logger.info(" FN=%d TN=%d", fn, tn)
finally:
_try_close_db()
await _try_close_db()


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions implementations/report_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ uv run --env-file .env python -m implementations.report_generation.data.import_o

Replace `<path_to_the_csv_file>` with the path where the dataset's .CSV file is saved on your machine.

***NOTE:*** You can configure the location the database is saved by setting the path to
an environment variable named `REPORT_GENERATION_DB_PATH`.
***NOTE:*** The location where the database is saved is determined by the environment
variable named `REPORT_GENERATION_DB__DATABASE`.

## Running the Demo UI

Expand Down
Loading