Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions google/genai/_interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter

__all__ = [
"types",
Expand Down Expand Up @@ -96,6 +97,8 @@
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
"AsyncGeminiNextGenAPIClientAdapter",
"GeminiNextGenAPIClientAdapter"
]

if not _t.TYPE_CHECKING:
Expand Down
64 changes: 61 additions & 3 deletions google/genai/_interactions/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from ._utils import is_given, get_async_library
from ._compat import cached_property
from ._models import FinalRequestOptions
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import APIStatusError
Expand All @@ -45,6 +46,7 @@
SyncAPIClient,
AsyncAPIClient,
)
from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter

if TYPE_CHECKING:
from .resources import interactions
Expand All @@ -66,6 +68,7 @@ class GeminiNextGenAPIClient(SyncAPIClient):
# client options
api_key: str | None
api_version: str
client_adapter: GeminiNextGenAPIClientAdapter | None

def __init__(
self,
Expand All @@ -81,6 +84,7 @@ def __init__(
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.Client | None = None,
client_adapter: GeminiNextGenAPIClientAdapter | None = None,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand Down Expand Up @@ -108,6 +112,8 @@ def __init__(
if base_url is None:
base_url = f"https://generativelanguage.googleapis.com"

self.client_adapter = client_adapter

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -159,13 +165,35 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
return
if self.api_key and headers.get("x-goog-api-key"):
return
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
return

raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)


@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
return options

headers = options.headers or {}
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
if has_auth:
return options

adapted_headers = self.client_adapter.get_auth_headers()
if adapted_headers:
options.headers = {
**adapted_headers,
**headers
}
return options

def copy(
self,
*,
Expand All @@ -179,6 +207,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
client_adapter: GeminiNextGenAPIClientAdapter | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -212,6 +241,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
client_adapter=self.client_adapter or client_adapter,
**_extra_kwargs,
)

Expand Down Expand Up @@ -260,6 +290,7 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient):
# client options
api_key: str | None
api_version: str
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None

def __init__(
self,
Expand All @@ -275,6 +306,7 @@ def __init__(
# We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
http_client: httpx.AsyncClient | None = None,
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand Down Expand Up @@ -302,6 +334,8 @@ def __init__(
if base_url is None:
base_url = f"https://generativelanguage.googleapis.com"

self.client_adapter = client_adapter

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -353,12 +387,34 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
return
if self.api_key and headers.get("x-goog-api-key"):
return
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
return

raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
return options

headers = options.headers or {}
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
if has_auth:
return options

adapted_headers = await self.client_adapter.async_get_auth_headers()
if adapted_headers:
options.headers = {
**adapted_headers,
**headers
}
return options

def copy(
self,
Expand All @@ -373,6 +429,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -406,6 +463,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
client_adapter=self.client_adapter or client_adapter,
**_extra_kwargs,
)

Expand Down
48 changes: 48 additions & 0 deletions google/genai/_interactions/_client_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from abc import ABC, abstractmethod

__all__ = [
"GeminiNextGenAPIClientAdapter",
"AsyncGeminiNextGenAPIClientAdapter"
]

class BaseGeminiNextGenAPIClientAdapter(ABC):
@abstractmethod
def is_vertex_ai(self) -> bool:
...

@abstractmethod
def get_project(self) -> str | None:
...

@abstractmethod
def get_location(self) -> str | None:
...


class AsyncGeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
@abstractmethod
async def async_get_auth_headers(self) -> dict[str, str] | None:
...


class GeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
@abstractmethod
def get_auth_headers(self) -> dict[str, str] | None:
...
Loading
Loading