Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/google/adk/models/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ class AnthropicLlm(BaseLlm):
model: str = "claude-sonnet-4-20250514"
max_tokens: int = 8192

client: Optional[Union[AsyncAnthropic, AsyncAnthropicVertex]] = None
"""An optional pre-configured Anthropic client."""

@classmethod
@override
def supported_models(cls) -> list[str]:
Expand Down Expand Up @@ -495,6 +498,8 @@ async def _generate_content_streaming(

@cached_property
def _anthropic_client(self) -> AsyncAnthropic:
if self.client:
return self.client
return AsyncAnthropic()


Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/models/apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def api_client(self) -> Client:
Returns:
The api client.
"""
if self.client:
return self.client

from google.genai import Client

kwargs_for_http_options = {}
Expand Down
11 changes: 11 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING
from typing import Union

from google.genai import Client
from google.genai import types
from google.genai.errors import ClientError
from typing_extensions import override
Expand Down Expand Up @@ -91,6 +92,13 @@ class Gemini(BaseLlm):

model: str = 'gemini-2.5-flash'

client: Optional[Client] = None
"""An optional pre-configured google-genai Client.

When provided, this client will be used for all API calls instead of
constructing a new one from environment variables or other attributes.
"""

base_url: Optional[str] = None
"""The base URL for the AI platform service endpoint."""

Expand Down Expand Up @@ -302,6 +310,9 @@ def api_client(self) -> Client:
Returns:
The api client.
"""
if self.client:
return self.client

from google.genai import Client

return Client(
Expand Down
110 changes: 110 additions & 0 deletions tests/unittests/models/test_custom_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest import mock
from google.genai import Client
from anthropic import AsyncAnthropic
from google.adk.models.google_llm import Gemini
from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai.types import Content, Part

def test_gemini_custom_client():
"""Verify that Gemini uses the provided custom client."""
mock_client = mock.MagicMock(spec=Client)
gemini = Gemini(model="gemini-1.5-flash", client=mock_client)

assert gemini.api_client is mock_client
# Verify it persists (cached_property)
assert gemini.api_client is mock_client

def test_anthropic_custom_client():
"""Verify that AnthropicLlm uses the provided custom client."""
mock_client = mock.MagicMock(spec=AsyncAnthropic)
anthropic_llm = AnthropicLlm(model="claude-3-5-sonnet-20241022", client=mock_client)

assert anthropic_llm._anthropic_client is mock_client

@pytest.mark.asyncio
async def test_gemini_uses_custom_client_in_call():
"""Verify that Gemini calls use the provided custom client's methods."""
mock_client = mock.MagicMock(spec=Client)
# Mock the nested aio.models.generate_content
mock_aio_models = mock_client.aio.models

gemini = Gemini(model="gemini-1.5-flash", client=mock_client)

request = LlmRequest(
model="gemini-1.5-flash",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])]
)

# Mock the response
mock_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(role="model", parts=[Part.from_text(text="Hello")]),
finish_reason=types.FinishReason.STOP
)
]
)

async def mock_coro(*args, **kwargs):
return mock_response

mock_aio_models.generate_content.return_value = mock_coro()

# We use stream=False to simplify the mock
responses = [r async for r in gemini.generate_content_async(request, stream=False)]

assert len(responses) == 1
assert responses[0].content.parts[0].text == "Hello"
mock_aio_models.generate_content.assert_called()

@pytest.mark.asyncio
async def test_anthropic_uses_custom_client_in_call():
"""Verify that AnthropicLlm calls use the provided custom client's methods."""
mock_client = mock.MagicMock(spec=AsyncAnthropic)
mock_messages = mock_client.messages

anthropic_llm = AnthropicLlm(model="claude-3-5-sonnet-20241022", client=mock_client)

request = LlmRequest(
model="claude-3-5-sonnet-20241022",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])]
)

from anthropic import types as anthropic_types
mock_response = anthropic_types.Message(
id="msg_test",
content=[anthropic_types.TextBlock(text="Hello", type="text")],
model="claude-3-5-sonnet-20241022",
role="assistant",
stop_reason="end_turn",
type="message",
usage=anthropic_types.Usage(input_tokens=1, output_tokens=1)
)

async def mock_coro(*args, **kwargs):
return mock_response

mock_messages.create.return_value = mock_coro()

responses = [r async for r in anthropic_llm.generate_content_async(request, stream=False)]

assert len(responses) == 1
assert responses[0].content.parts[0].text == "Hello"
mock_messages.create.assert_called()