Skip to content

Commit 84e5c57

Browse files
committed
fix(tools): support secure regional discovery engine endpoints
1 parent a61ccf3 commit 84e5c57

2 files changed

Lines changed: 267 additions & 4 deletions

File tree

src/google/adk/tools/discovery_engine_search_tool.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import re
1718
from typing import Any
1819
from typing import Optional
1920

@@ -25,6 +26,75 @@
2526

2627
from .function_tool import FunctionTool
2728

29+
_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com"
30+
_GLOBAL_LOCATION = "global"
31+
_LOCATION_PATTERN = re.compile(
32+
r"/locations/([a-z0-9-]+)(?:/|$)", flags=re.IGNORECASE
33+
)
34+
_VALID_LOCATION_PATTERN = re.compile(r"^[a-z0-9-]+$")
35+
36+
37+
def _normalize_location(location: str, location_type: str) -> str:
38+
"""Normalizes and validates a location value."""
39+
normalized_location = location.strip().lower()
40+
if not normalized_location:
41+
raise ValueError(f"{location_type} must not be empty if specified.")
42+
if not _VALID_LOCATION_PATTERN.fullmatch(normalized_location):
43+
raise ValueError(
44+
f"{location_type} must contain only letters, digits, and hyphens."
45+
)
46+
return normalized_location
47+
48+
49+
def _extract_resource_location(resource_id: str) -> Optional[str]:
50+
"""Extracts and validates location from a resource id."""
51+
if "/locations/" not in resource_id.lower():
52+
return None
53+
54+
location_match = _LOCATION_PATTERN.search(resource_id)
55+
if not location_match:
56+
raise ValueError("Invalid location in data_store_id or search_engine_id.")
57+
return _normalize_location(location_match.group(1), "resource location")
58+
59+
60+
def _resolve_location(resource_id: str, location: Optional[str]) -> str:
61+
"""Resolves the Discovery Engine location to use for the endpoint."""
62+
inferred_location = _extract_resource_location(resource_id)
63+
64+
if location is not None:
65+
normalized_location = _normalize_location(location, "location")
66+
if inferred_location and normalized_location != inferred_location:
67+
raise ValueError(
68+
"location must match the location in data_store_id or "
69+
"search_engine_id."
70+
)
71+
return normalized_location
72+
73+
if inferred_location:
74+
return inferred_location
75+
return _GLOBAL_LOCATION
76+
77+
78+
def _build_client_options(
79+
resource_id: str,
80+
quota_project_id: Optional[str],
81+
location: Optional[str],
82+
) -> Optional[client_options.ClientOptions]:
83+
"""Builds client options for Discovery Engine requests."""
84+
client_options_kwargs = {}
85+
resolved_location = _resolve_location(resource_id, location)
86+
87+
if resolved_location != _GLOBAL_LOCATION:
88+
client_options_kwargs["api_endpoint"] = (
89+
f"{resolved_location}-{_DEFAULT_ENDPOINT}"
90+
)
91+
if quota_project_id:
92+
client_options_kwargs["quota_project_id"] = quota_project_id
93+
94+
if not client_options_kwargs:
95+
return None
96+
return client_options.ClientOptions(**client_options_kwargs)
97+
2898

2999
class DiscoveryEngineSearchTool(FunctionTool):
30100
"""Tool for searching the discovery engine."""
@@ -38,6 +108,7 @@ def __init__(
38108
search_engine_id: Optional[str] = None,
39109
filter: Optional[str] = None,
40110
max_results: Optional[int] = None,
111+
location: Optional[str] = None,
41112
):
42113
"""Initializes the DiscoveryEngineSearchTool.
43114
@@ -51,6 +122,9 @@ def __init__(
51122
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
52123
filter: The filter to be applied to the search request. Default is None.
53124
max_results: The maximum number of results to return. Default is None.
125+
location: Optional endpoint location override.
126+
Examples: "global", "us", "eu". If not specified, location is inferred
127+
from `data_store_id` or `search_engine_id` and defaults to "global".
54128
"""
55129
super().__init__(self.discovery_engine_search)
56130
if (data_store_id is None and search_engine_id is None) or (
@@ -71,13 +145,15 @@ def __init__(
71145
self._search_engine_id = search_engine_id
72146
self._filter = filter
73147
self._max_results = max_results
148+
self._location = location
74149

75150
credentials, _ = google.auth.default()
76151
quota_project_id = getattr(credentials, "quota_project_id", None)
77-
options = (
78-
client_options.ClientOptions(quota_project_id=quota_project_id)
79-
if quota_project_id
80-
else None
152+
resource_id = data_store_id or search_engine_id or ""
153+
options = _build_client_options(
154+
resource_id=resource_id,
155+
quota_project_id=quota_project_id,
156+
location=location,
81157
)
82158
self._discovery_engine_client = discoveryengine.SearchServiceClient(
83159
credentials=credentials, client_options=options

tests/unittests/tools/test_discovery_engine_search_tool.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,193 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
7979
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
8080
)
8181

82+
@pytest.mark.parametrize(
83+
("tool_kwargs", "expected_endpoint"),
84+
[
85+
(
86+
{
87+
"data_store_id": (
88+
"projects/test/locations/eu/collections/default_collection/"
89+
"dataStores/test_data_store"
90+
)
91+
},
92+
"eu-discoveryengine.googleapis.com",
93+
),
94+
(
95+
{
96+
"search_engine_id": (
97+
"projects/test/locations/us/collections/default_collection/"
98+
"engines/test_search_engine"
99+
)
100+
},
101+
"us-discoveryengine.googleapis.com",
102+
),
103+
(
104+
{
105+
"data_store_id": (
106+
"projects/test/locations/europe-west1/collections/"
107+
"default_collection/dataStores/test_data_store"
108+
)
109+
},
110+
"europe-west1-discoveryengine.googleapis.com",
111+
),
112+
],
113+
)
114+
@mock.patch.object(discovery_engine_search_tool, "client_options")
115+
@mock.patch.object(discoveryengine, "SearchServiceClient")
116+
def test_init_with_regional_location_uses_regional_endpoint(
117+
self,
118+
mock_search_client,
119+
mock_client_options,
120+
tool_kwargs,
121+
expected_endpoint,
122+
):
123+
"""Test initialization uses the expected regional API endpoint."""
124+
DiscoveryEngineSearchTool(**tool_kwargs)
125+
126+
mock_client_options.ClientOptions.assert_called_once_with(
127+
api_endpoint=expected_endpoint
128+
)
129+
mock_search_client.assert_called_once_with(
130+
credentials="credentials",
131+
client_options=mock_client_options.ClientOptions.return_value,
132+
)
133+
134+
@mock.patch.object(discovery_engine_search_tool, "client_options")
135+
@mock.patch.object(discoveryengine, "SearchServiceClient")
136+
def test_init_with_explicit_location_override_uses_input_location(
137+
self, mock_search_client, mock_client_options
138+
):
139+
"""Test initialization uses explicit location when resource has none."""
140+
DiscoveryEngineSearchTool(
141+
data_store_id="test_data_store",
142+
location="eu",
143+
)
144+
145+
mock_client_options.ClientOptions.assert_called_once_with(
146+
api_endpoint="eu-discoveryengine.googleapis.com"
147+
)
148+
mock_search_client.assert_called_once_with(
149+
credentials="credentials",
150+
client_options=mock_client_options.ClientOptions.return_value,
151+
)
152+
153+
@mock.patch.object(discoveryengine, "SearchServiceClient")
154+
def test_init_with_mismatched_location_raises_error(self, mock_search_client):
155+
"""Test initialization rejects mismatched location overrides."""
156+
with pytest.raises(
157+
ValueError,
158+
match=(
159+
"location must match the location in data_store_id or "
160+
"search_engine_id."
161+
),
162+
):
163+
DiscoveryEngineSearchTool(
164+
data_store_id=(
165+
"projects/test/locations/us/collections/default_collection/"
166+
"dataStores/test_data_store"
167+
),
168+
location="eu",
169+
)
170+
171+
mock_search_client.assert_not_called()
172+
173+
@mock.patch.object(discoveryengine, "SearchServiceClient")
174+
def test_init_with_empty_location_raises_error(self, mock_search_client):
175+
"""Test initialization rejects an empty location override."""
176+
with pytest.raises(
177+
ValueError, match="location must not be empty if specified."
178+
):
179+
DiscoveryEngineSearchTool(
180+
data_store_id=(
181+
"projects/test/locations/us/collections/default_collection/"
182+
"dataStores/test_data_store"
183+
),
184+
location=" ",
185+
)
186+
187+
mock_search_client.assert_not_called()
188+
189+
@mock.patch.object(discoveryengine, "SearchServiceClient")
190+
def test_init_with_invalid_override_location_raises_error(
191+
self, mock_search_client
192+
):
193+
"""Test initialization rejects invalid override location characters."""
194+
with pytest.raises(
195+
ValueError,
196+
match="location must contain only letters, digits, and hyphens.",
197+
):
198+
DiscoveryEngineSearchTool(
199+
data_store_id="test_data_store",
200+
location="attacker.com#",
201+
)
202+
203+
mock_search_client.assert_not_called()
204+
205+
@mock.patch.object(discoveryengine, "SearchServiceClient")
206+
def test_init_with_invalid_resource_location_raises_error(
207+
self, mock_search_client
208+
):
209+
"""Test initialization rejects invalid resource location characters."""
210+
with pytest.raises(
211+
ValueError,
212+
match="Invalid location in data_store_id or search_engine_id.",
213+
):
214+
DiscoveryEngineSearchTool(
215+
data_store_id=(
216+
"projects/test/locations/attacker.com#/collections/"
217+
"default_collection/dataStores/test_data_store"
218+
)
219+
)
220+
221+
mock_search_client.assert_not_called()
222+
223+
@mock.patch.object(discovery_engine_search_tool, "client_options")
224+
@mock.patch.object(discoveryengine, "SearchServiceClient")
225+
def test_init_with_global_location_keeps_default_endpoint(
226+
self, mock_search_client, mock_client_options
227+
):
228+
"""Test initialization keeps default API endpoint for global location."""
229+
DiscoveryEngineSearchTool(
230+
data_store_id=(
231+
"projects/test/locations/global/collections/default_collection/"
232+
"dataStores/test_data_store"
233+
)
234+
)
235+
236+
mock_client_options.ClientOptions.assert_not_called()
237+
mock_search_client.assert_called_once_with(
238+
credentials="credentials", client_options=None
239+
)
240+
241+
@mock.patch.object(discovery_engine_search_tool, "client_options")
242+
@mock.patch.object(discoveryengine, "SearchServiceClient")
243+
def test_init_with_regional_location_and_quota_project_id(
244+
self, mock_search_client, mock_client_options
245+
):
246+
"""Test initialization uses endpoint and quota project id together."""
247+
mock_credentials = mock.MagicMock()
248+
mock_credentials.quota_project_id = "test-quota-project"
249+
250+
with mock.patch.object(
251+
auth, "default", return_value=(mock_credentials, "project")
252+
):
253+
DiscoveryEngineSearchTool(
254+
data_store_id=(
255+
"projects/test/locations/eu/collections/default_collection/"
256+
"dataStores/test_data_store"
257+
)
258+
)
259+
260+
mock_client_options.ClientOptions.assert_called_once_with(
261+
api_endpoint="eu-discoveryengine.googleapis.com",
262+
quota_project_id="test-quota-project",
263+
)
264+
mock_search_client.assert_called_once_with(
265+
credentials=mock_credentials,
266+
client_options=mock_client_options.ClientOptions.return_value,
267+
)
268+
82269
@mock.patch.object(discovery_engine_search_tool, "client_options")
83270
@mock.patch.object(
84271
discoveryengine,

0 commit comments

Comments
 (0)