1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import re
16- from typing import Any
15+ from datetime import datetime
1716from unittest import mock
1817
1918from google .adk .events .event import Event
7069)
7170
7271
73- RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
74- GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
75-
76-
77- class MockApiClient :
78- """Mocks the API Client."""
79-
80- def __init__ (self ) -> None :
81- """Initializes MockClient."""
82- self .async_request = mock .AsyncMock ()
83- self .async_request .side_effect = self ._mock_async_request
84-
85- async def _mock_async_request (
86- self , http_method : str , path : str , request_dict : dict [str , Any ]
87- ):
88- """Mocks the API Client request method."""
89- if http_method == 'POST' :
90- if re .match (GENERATE_MEMORIES_REGEX , path ):
91- return {}
92- elif re .match (RETRIEVE_MEMORIES_REGEX , path ):
93- if (
94- request_dict .get ('scope' , None )
95- and request_dict ['scope' ].get ('app_name' , None ) == MOCK_APP_NAME
96- ):
97- return {
98- 'retrievedMemories' : [
99- {
100- 'memory' : {
101- 'fact' : 'test_content' ,
102- },
103- 'updateTime' : '2024-12-12T12:12:12.123456Z' ,
104- },
105- ],
106- }
107- else :
108- return {'retrievedMemories' : []}
109- else :
110- raise ValueError (f'Unsupported path: { path } ' )
111- else :
112- raise ValueError (f'Unsupported http method: { http_method } ' )
113-
114-
11572def mock_vertex_ai_memory_bank_service ():
11673 """Creates a mock Vertex AI Memory Bank service for testing."""
11774 return VertexAiMemoryBankService (
@@ -122,67 +79,86 @@ def mock_vertex_ai_memory_bank_service():
12279
12380
12481@pytest .fixture
125- def mock_get_api_client ():
126- api_client = MockApiClient ()
82+ def mock_vertexai_client ():
12783 with mock .patch (
128- 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client' ,
129- return_value = api_client ,
130- ):
131- yield api_client
84+ 'google.adk.memory.vertex_ai_memory_bank_service.vertexai.Client'
85+ ) as mock_client_constructor :
86+ mock_client = mock .MagicMock ()
87+ mock_client .agent_engines .memories .generate = mock .MagicMock ()
88+ mock_client .agent_engines .memories .retrieve = mock .MagicMock ()
89+ mock_client_constructor .return_value = mock_client
90+ yield mock_client
13291
13392
13493@pytest .mark .asyncio
135- @pytest .mark .usefixtures ('mock_get_api_client' )
136- async def test_add_session_to_memory (mock_get_api_client ):
94+ async def test_add_session_to_memory (mock_vertexai_client ):
13795 memory_service = mock_vertex_ai_memory_bank_service ()
13896 await memory_service .add_session_to_memory (MOCK_SESSION )
13997
140- mock_get_api_client .async_request .assert_awaited_once_with (
141- http_method = 'POST' ,
142- path = 'reasoningEngines/123/memories:generate' ,
143- request_dict = {
144- 'direct_contents_source' : {
145- 'events' : [
146- {
147- 'content' : {
148- 'parts' : [
149- {'text' : 'test_content' },
150- ],
151- },
152- },
153- ],
154- },
155- 'scope' : {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
98+ mock_vertexai_client .agent_engines .memories .generate .assert_called_once_with (
99+ name = 'reasoningEngines/123' ,
100+ direct_contents_source = {
101+ 'events' : [
102+ {
103+ 'content' : {
104+ 'parts' : [{'text' : 'test_content' }],
105+ }
106+ }
107+ ]
156108 },
109+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
110+ config = {'wait_for_completion' : False },
157111 )
158112
159113
160114@pytest .mark .asyncio
161- @pytest .mark .usefixtures ('mock_get_api_client' )
162- async def test_add_empty_session_to_memory (mock_get_api_client ):
115+ async def test_add_empty_session_to_memory (mock_vertexai_client ):
163116 memory_service = mock_vertex_ai_memory_bank_service ()
164117 await memory_service .add_session_to_memory (MOCK_SESSION_WITH_EMPTY_EVENTS )
165118
166- mock_get_api_client . async_request .assert_not_called ()
119+ mock_vertexai_client . agent_engines . memories . generate .assert_not_called ()
167120
168121
169122@pytest .mark .asyncio
170- @pytest .mark .usefixtures ('mock_get_api_client' )
171- async def test_search_memory (mock_get_api_client ):
123+ async def test_search_memory (mock_vertexai_client ):
124+ retrieved_memory = mock .MagicMock ()
125+ retrieved_memory .memory .fact = 'test_content'
126+ retrieved_memory .memory .update_time = datetime (
127+ 2024 , 12 , 12 , 12 , 12 , 12 , 123456
128+ )
129+
130+ mock_vertexai_client .agent_engines .memories .retrieve .return_value = [
131+ retrieved_memory
132+ ]
172133 memory_service = mock_vertex_ai_memory_bank_service ()
173134
174135 result = await memory_service .search_memory (
175136 app_name = MOCK_APP_NAME , user_id = MOCK_USER_ID , query = 'query'
176137 )
177138
178- mock_get_api_client .async_request .assert_awaited_once_with (
179- http_method = 'POST' ,
180- path = 'reasoningEngines/123/memories:retrieve' ,
181- request_dict = {
182- 'scope' : {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
183- 'similarity_search_params' : {'search_query' : 'query' },
184- },
139+ mock_vertexai_client .agent_engines .memories .retrieve .assert_called_once_with (
140+ name = 'reasoningEngines/123' ,
141+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
142+ similarity_search_params = {'search_query' : 'query' },
185143 )
186144
187145 assert len (result .memories ) == 1
188146 assert result .memories [0 ].content .parts [0 ].text == 'test_content'
147+
148+
149+ @pytest .mark .asyncio
150+ async def test_search_memory_empty_results (mock_vertexai_client ):
151+ mock_vertexai_client .agent_engines .memories .retrieve .return_value = []
152+ memory_service = mock_vertex_ai_memory_bank_service ()
153+
154+ result = await memory_service .search_memory (
155+ app_name = MOCK_APP_NAME , user_id = MOCK_USER_ID , query = 'query'
156+ )
157+
158+ mock_vertexai_client .agent_engines .memories .retrieve .assert_called_once_with (
159+ name = 'reasoningEngines/123' ,
160+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
161+ similarity_search_params = {'search_query' : 'query' },
162+ )
163+
164+ assert len (result .memories ) == 0
0 commit comments