Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions contributing/samples/bigquery/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@
# Define an appropriate credential type
CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2

# Define an appropriate application name
BIGQUERY_AGENT_NAME = "adk_sample_bigquery_agent"


# Define BigQuery tool config with write mode set to allowed. Note that this is
# only to demonstrate the full capability of the BigQuery tools. In production
# you may want to change to BLOCKED (default write mode, effectively makes the
# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a
# BigQuery session) write mode.
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
tool_config = BigQueryToolConfig(
write_mode=WriteMode.ALLOWED, application_name=BIGQUERY_AGENT_NAME
)

if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initiaze the tools to do interactive OAuth
Expand Down Expand Up @@ -64,7 +69,7 @@
# debug CLI
root_agent = LlmAgent(
model="gemini-2.0-flash",
name="bigquery_agent",
name=BIGQUERY_AGENT_NAME,
description=(
"Agent to answer questions about BigQuery data and models and execute"
" SQL queries."
Expand Down
9 changes: 7 additions & 2 deletions src/google/adk/tools/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@


def get_bigquery_client(
*, project: Optional[str], credentials: Credentials
*,
project: Optional[str],
credentials: Credentials,
user_agent: Optional[str] = None,
) -> bigquery.Client:
"""Get a BigQuery client."""

client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT)
user_agent = f"{USER_AGENT} {user_agent}" if user_agent else USER_AGENT

client_info = google.api_core.client_info.ClientInfo(user_agent=user_agent)

bigquery_client = bigquery.Client(
project=project, credentials=credentials, client_info=client_info
Expand Down
21 changes: 20 additions & 1 deletion src/google/adk/tools/bigquery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from __future__ import annotations

from enum import Enum
from typing import Optional

from pydantic import BaseModel
from pydantic import field_validator

from ...utils.feature_decorator import experimental

Expand Down Expand Up @@ -58,4 +60,21 @@ class BigQueryToolConfig(BaseModel):
max_query_result_rows: int = 50
"""Maximum number of rows to return from a query.

By default, the query result will be limited to 50 rows."""
By default, the query result will be limited to 50 rows.
"""

application_name: Optional[str] = None
"""Name of the application using the BigQuery tools.

By default, no particular application name will be set in the BigQuery
interaction. But if the the tool user (agent builder) wants to differentiate
their application/agent for tracking or support purpose, they can set this field.
"""

@field_validator('application_name')
@classmethod
def validate_application_name(cls, v):
"""Validate the application name."""
if v and ' ' in v:
raise ValueError('Application name should not contain spaces.')
return v
37 changes: 29 additions & 8 deletions src/google/adk/tools/bigquery/metadata_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
from google.cloud import bigquery

from . import client
from .config import BigQueryToolConfig


def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
def list_dataset_ids(
project_id: str, credentials: Credentials, settings: BigQueryToolConfig
) -> list[str]:
"""List BigQuery dataset ids in a Google Cloud project.

Args:
Expand All @@ -45,7 +48,9 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
"""
try:
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
project=project_id,
credentials=credentials,
user_agent=settings.application_name,
)

datasets = []
Expand All @@ -60,7 +65,10 @@ def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:


def get_dataset_info(
project_id: str, dataset_id: str, credentials: Credentials
project_id: str,
dataset_id: str,
credentials: Credentials,
settings: BigQueryToolConfig,
) -> dict:
"""Get metadata information about a BigQuery dataset.

Expand Down Expand Up @@ -111,7 +119,9 @@ def get_dataset_info(
"""
try:
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
project=project_id,
credentials=credentials,
user_agent=settings.application_name,
)
dataset = bq_client.get_dataset(
bigquery.DatasetReference(project_id, dataset_id)
Expand All @@ -125,7 +135,10 @@ def get_dataset_info(


def list_table_ids(
project_id: str, dataset_id: str, credentials: Credentials
project_id: str,
dataset_id: str,
credentials: Credentials,
settings: BigQueryToolConfig,
) -> list[str]:
"""List table ids in a BigQuery dataset.

Expand All @@ -144,7 +157,9 @@ def list_table_ids(
"""
try:
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
project=project_id,
credentials=credentials,
user_agent=settings.application_name,
)

tables = []
Expand All @@ -161,7 +176,11 @@ def list_table_ids(


def get_table_info(
project_id: str, dataset_id: str, table_id: str, credentials: Credentials
project_id: str,
dataset_id: str,
table_id: str,
credentials: Credentials,
settings: BigQueryToolConfig,
) -> dict:
"""Get metadata information about a BigQuery table.

Expand Down Expand Up @@ -260,7 +279,9 @@ def get_table_info(
"""
try:
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
project=project_id,
credentials=credentials,
user_agent=settings.application_name,
)
return bq_client.get_table(
bigquery.TableReference(
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/tools/bigquery/query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def execute_sql(
try:
# Get BigQuery client
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
project=project_id,
credentials=credentials,
user_agent=settings.application_name,
)

# BigQuery connection properties where applicable
Expand Down
38 changes: 32 additions & 6 deletions tests/unittests/tools/bigquery/test_bigquery_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from __future__ import annotations

import os
import re
from unittest import mock

import google.adk
from google.adk.tools.bigquery.client import get_bigquery_client
from google.auth.exceptions import DefaultCredentialsError
from google.oauth2.credentials import Credentials
Expand Down Expand Up @@ -109,8 +109,8 @@ def test_bigquery_client_project_set_with_env():
assert client.project == "test-gcp-project"


def test_bigquery_client_user_agent():
"""Test BigQuery client user agent."""
def test_bigquery_client_user_agent_default():
"""Test BigQuery client default user agent."""
with mock.patch(
"google.cloud.bigquery.client.Connection", autospec=True
) as mock_connection:
Expand All @@ -123,7 +123,33 @@ def test_bigquery_client_user_agent():
# Verify that the tracking user agent was set
client_info_arg = mock_connection.call_args[1].get("client_info")
assert client_info_arg is not None
assert re.search(
r"adk-bigquery-tool google-adk/([0-9A-Za-z._\-+/]+)",
client_info_arg.user_agent,
expected_user_agents = {
"adk-bigquery-tool",
f"google-adk/{google.adk.__version__}",
}
actual_user_agents = set(client_info_arg.user_agent.split())
assert expected_user_agents.issubset(actual_user_agents)


def test_bigquery_client_user_agent_custom():
"""Test BigQuery client custom user agent."""
with mock.patch(
"google.cloud.bigquery.client.Connection", autospec=True
) as mock_connection:
# Trigger the BigQuery client creation
get_bigquery_client(
project="test-gcp-project",
credentials=mock.create_autospec(Credentials, instance=True),
user_agent="custom_user_agent",
)

# Verify that the tracking user agent was set
client_info_arg = mock_connection.call_args[1].get("client_info")
assert client_info_arg is not None
expected_user_agents = {
"adk-bigquery-tool",
f"google-adk/{google.adk.__version__}",
"custom_user_agent",
}
actual_user_agents = set(client_info_arg.user_agent.split())
assert expected_user_agents.issubset(actual_user_agents)
Loading