Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 281 additions & 17 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,20 +1066,21 @@ def update_agent_engine_mock():

@pytest.mark.usefixtures("google_auth_mock")
class TestAgentEngines:
def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
aiplatform.init(

def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
staging_bucket=_TEST_STAGING_BUCKET,
)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.parametrize(
@pytest.mark.parametrize(
"env_vars,expected_env_vars",
[
({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}),
Expand All @@ -1101,7 +1102,7 @@ def teardown_method(self):
),
],
)
def test_create_default_telemetry_enablement(
def test_create_default_telemetry_enablement(
self,
create_agent_engine_mock: mock.Mock,
cloud_storage_create_bucket_mock: mock.Mock,
Expand All @@ -1111,18 +1112,18 @@ def test_create_default_telemetry_enablement(
env_vars: dict[str, str],
expected_env_vars: dict[str, str],
):
agent_engines.create(
agent_engines.create(
agent_engine=agent_engines.AdkApp(agent=_TEST_AGENT),
env_vars=env_vars,
)
deployment_spec = create_agent_engine_mock.call_args.kwargs[
deployment_spec = create_agent_engine_mock.call_args.kwargs[
"reasoning_engine"
].spec.deployment_spec
assert _utils.to_dict(deployment_spec)["env"] == [
assert _utils.to_dict(deployment_spec)["env"] == [
{"name": key, "value": value} for key, value in expected_env_vars.items()
]

@pytest.mark.parametrize(
@pytest.mark.parametrize(
"env_vars,expected_env_vars",
[
({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}),
Expand All @@ -1144,7 +1145,7 @@ def test_create_default_telemetry_enablement(
),
],
)
def test_update_default_telemetry_enablement(
def test_update_default_telemetry_enablement(
self,
update_agent_engine_mock: mock.Mock,
cloud_storage_create_bucket_mock: mock.Mock,
Expand All @@ -1155,15 +1156,278 @@ def test_update_default_telemetry_enablement(
env_vars: dict[str, str],
expected_env_vars: dict[str, str],
):
agent_engines.update(
agent_engines.update(
resource_name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
description="foobar", # avoid "At least one of ... must be specified" errors.
env_vars=env_vars,
)
update_agent_engine_mock.assert_called_once()
deployment_spec = update_agent_engine_mock.call_args.kwargs[
update_agent_engine_mock.assert_called_once()
deployment_spec = update_agent_engine_mock.call_args.kwargs[
"request"
].reasoning_engine.spec.deployment_spec
assert _utils.to_dict(deployment_spec)["env"] == [
assert _utils.to_dict(deployment_spec)["env"] == [
{"name": key, "value": value} for key, value in expected_env_vars.items()
]


class TestAdkAppMtls:
"""Test cases for mTLS functionality in AdkApp."""

def test_use_client_cert_effective_with_should_use_client_cert(self):
"""Verifies that it respects the google-auth mTLS enablement check."""
with mock.patch.object(
adk_template.mtls,
"should_use_client_cert",
return_value=True,
create=True,
):
assert adk_template._use_client_cert_effective() is True

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"})
def test_use_client_cert_effective_with_env_var_true(self):
"""Verifies that it falls back to the environment variable if google-auth check fails."""
with mock.patch.object(
adk_template.mtls,
"should_use_client_cert",
side_effect=AttributeError,
create=True,
):
assert adk_template._use_client_cert_effective() is True

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"})
def test_use_client_cert_effective_with_env_var_false(self):
"""Verifies that it respects the environment variable being set to false."""
with mock.patch.object(
adk_template.mtls,
"should_use_client_cert",
side_effect=AttributeError,
create=True,
):
assert adk_template._use_client_cert_effective() is False

def test_get_api_endpoint_default(self):
"""Verifies the default telemetry endpoint is returned when no mTLS is configured."""
assert (
adk_template._get_api_endpoint()
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"})
def test_get_api_endpoint_always_with_cert(self):
"""Verifies the mTLS endpoint is used when forced and a certificate is available."""
assert (
adk_template._get_api_endpoint(client_cert_source=b"cert")
== adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT
)

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"})
def test_get_api_endpoint_always_no_cert(self):
"""Verifies it falls back to regular endpoint even if forced if no certificate is provided."""
assert (
adk_template._get_api_endpoint()
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)

@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"})
def test_get_api_endpoint_never(self):
"""Verifies the regular endpoint is used when mTLS is explicitly disabled."""
assert (
adk_template._get_api_endpoint(client_cert_source=b"cert")
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
@mock.patch(
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
)
def test_default_instrumentor_builder_with_mtls(
self,
mock_exporter,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the instrumentor builder with mTLS enabled."""
# Mocking to enable mTLS
with mock.patch.object(
adk_template, "_use_client_cert_effective", return_value=True
):
with mock.patch.object(
adk_template.mtls, "has_default_client_cert_source", return_value=True
):
with mock.patch.object(
adk_template.mtls,
"default_client_cert_source",
return_value=lambda: b"cert",
):
adk_template._default_instrumentor_builder(
_TEST_PROJECT_ID, enable_tracing=True
)

# Verify the session was configured for mTLS
mock_session_cls.return_value.configure_mtls_channel.assert_called_once()
# Verify the exporter was initialized with the mTLS endpoint
mock_exporter.assert_called_once()
assert (
mock_exporter.call_args.kwargs["endpoint"]
== adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT
)

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
def test_warn_if_telemetry_api_disabled_with_mtls(
self,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the telemetry API check with mTLS enabled."""
mock_session = mock_session_cls.return_value
mock_session.post.return_value = mock.Mock(text="")

# Mocking to enable mTLS
with mock.patch.object(
adk_template, "_use_client_cert_effective", return_value=True
):
with mock.patch.object(
adk_template.mtls, "has_default_client_cert_source", return_value=True
):
with mock.patch.object(
adk_template.mtls,
"default_client_cert_source",
return_value=lambda: b"cert",
):
adk_template._warn_if_telemetry_api_disabled()

# Verify mTLS channel was configured for the check request
mock_session.configure_mtls_channel.assert_called_once()
# Verify the check was performed against the mTLS endpoint
mock_session.post.assert_called_once_with(
adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, data=None
)

@mock.patch.dict(
os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "invalid_value"}
)
def test_get_api_endpoint_invalid_env(self):
"""Verifies it defaults to AUTO and warns on invalid env var."""
with mock.patch.object(adk_template, "_warn") as mock_warn:
assert (
adk_template._get_api_endpoint()
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)
mock_warn.assert_called_once()

@mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "not_a_bool"}
)
def test_use_client_cert_effective_invalid_env(self):
"""Verifies it warns on invalid boolean env var."""
with mock.patch.object(
adk_template.mtls,
"should_use_client_cert",
side_effect=AttributeError,
create=True,
):
with mock.patch.object(adk_template, "_warn") as mock_warn:
assert adk_template._use_client_cert_effective() is False
mock_warn.assert_called_once()

def test_use_client_cert_effective_with_should_use_client_cert_false(self):
"""Verifies that it respects google-auth returning False for mTLS."""
with mock.patch.object(
adk_template.mtls,
"should_use_client_cert",
return_value=False,
create=True,
):
assert adk_template._use_client_cert_effective() is False

def test_get_api_endpoint_auto_with_cert(self):
"""Verifies the mTLS endpoint is used in AUTO mode when a cert is available."""
# AUTO is the default, so we just pass a cert
assert (
adk_template._get_api_endpoint(client_cert_source=b"cert")
== adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT
)

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
@mock.patch(
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
)
def test_default_instrumentor_builder_no_mtls(
self,
mock_exporter,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the instrumentor builder with mTLS disabled."""
with mock.patch.object(
adk_template, "_use_client_cert_effective", return_value=False
):
adk_template._default_instrumentor_builder(
_TEST_PROJECT_ID, enable_tracing=True
)

# Verify mTLS channel was NOT configured
mock_session_cls.return_value.configure_mtls_channel.assert_not_called()
# Verify the exporter was initialized with the regular endpoint
mock_exporter.assert_called_once()
assert (
mock_exporter.call_args.kwargs["endpoint"]
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
def test_warn_if_telemetry_api_disabled_no_mtls(
self,
mock_session_cls,
mock_auth_default,
):
"""Integration test for the telemetry API check with mTLS disabled."""
mock_session = mock_session_cls.return_value
mock_session.post.return_value = mock.Mock(text="")

with mock.patch.object(
adk_template, "_use_client_cert_effective", return_value=False
):
adk_template._warn_if_telemetry_api_disabled()

# Verify mTLS channel was NOT configured
mock_session.configure_mtls_channel.assert_not_called()
# Verify the check was performed against the regular endpoint
mock_session.post.assert_called_once_with(
adk_template._DEFAULT_TELEMETRY_ENDPOINT, data=None
)

@mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT))
@mock.patch.object(adk_template.requests_auth, "AuthorizedSession")
@mock.patch(
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
)
def test_default_instrumentor_builder_mtls_no_cert_source(
self,
mock_exporter,
mock_session_cls,
mock_auth_default,
):
"""Tests that it falls back to regular endpoint if mTLS is on but no cert is found."""
with mock.patch.object(
adk_template, "_use_client_cert_effective", return_value=True
):
with mock.patch.object(
adk_template.mtls,
"has_default_client_cert_source",
return_value=False,
):
adk_template._default_instrumentor_builder(
_TEST_PROJECT_ID, enable_tracing=True
)

# Channel is configured, but endpoint remains default due to missing cert source
mock_session_cls.return_value.configure_mtls_channel.assert_called_once()
assert (
mock_exporter.call_args.kwargs["endpoint"]
== adk_template._DEFAULT_TELEMETRY_ENDPOINT
)
Loading
Loading