Skip to content
Open
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .types.spanner import BatchWriteRequest
from .types.spanner import BatchWriteResponse
from .types.spanner import BeginTransactionRequest
from .types.spanner import ClientContext
from .types.spanner import CommitRequest
from .types.spanner import CreateSessionRequest
from .types.spanner import DeleteSessionRequest
Expand Down Expand Up @@ -110,6 +111,7 @@
"BatchWriteRequest",
"BatchWriteResponse",
"BeginTransactionRequest",
"ClientContext",
"CommitRequest",
"CommitResponse",
"CreateSessionRequest",
Expand Down
95 changes: 93 additions & 2 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1.types import ExecuteSqlRequest
from google.cloud.spanner_v1.types import TransactionOptions
from google.cloud.spanner_v1.types import ClientContext
from google.cloud.spanner_v1.types import RequestOptions
from google.cloud.spanner_v1.data_types import JsonObject, Interval
from google.cloud.spanner_v1.request_id_header import (
with_request_id,
Expand Down Expand Up @@ -172,15 +174,15 @@ def _merge_query_options(base, merge):
If the resultant object only has empty fields, returns None.
"""
combined = base or ExecuteSqlRequest.QueryOptions()
if type(combined) is dict:
if isinstance(combined, dict):
combined = ExecuteSqlRequest.QueryOptions(
optimizer_version=combined.get("optimizer_version", ""),
optimizer_statistics_package=combined.get(
"optimizer_statistics_package", ""
),
)
merge = merge or ExecuteSqlRequest.QueryOptions()
if type(merge) is dict:
if isinstance(merge, dict):
merge = ExecuteSqlRequest.QueryOptions(
optimizer_version=merge.get("optimizer_version", ""),
optimizer_statistics_package=merge.get("optimizer_statistics_package", ""),
Expand All @@ -191,6 +193,95 @@ def _merge_query_options(base, merge):
return combined


def _merge_client_context(base, merge):
"""Merge higher precedence ClientContext with current ClientContext.

:type base: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param base: The current ClientContext that is intended for use.

:type merge: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param merge:
The ClientContext that has a higher priority than base. These options
should overwrite the fields in base.

:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
or None
:returns:
ClientContext object formed by merging the two given ClientContexts.
"""
if base is None and merge is None:
return None

# Avoid in-place modification of base
combined_pb = ClientContext()._pb
if base:
base_pb = ClientContext(base)._pb if isinstance(base, dict) else base._pb
combined_pb.MergeFrom(base_pb)
if merge:
merge_pb = ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb
combined_pb.MergeFrom(merge_pb)

combined = ClientContext(combined_pb)

if not combined.secure_context:
return None
return combined


def _validate_client_context(client_context):
"""Validate and convert client_context.

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use.

:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
:returns: Validated ClientContext object or None.
:raises TypeError: if client_context is not a ClientContext or a dict.
"""
if client_context is not None:
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):
raise TypeError("client_context must be a ClientContext or a dict")
return client_context


def _merge_request_options(request_options, client_context):
"""Merge RequestOptions and ClientContext.

:type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions`
or :class:`dict` or None
:param request_options: The current RequestOptions that is intended for use.

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict` or None
:param client_context:
The ClientContext to merge into request_options.

:rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions`
or None
:returns:
RequestOptions object formed by merging the given ClientContext.
"""
if request_options is None and client_context is None:
return None

if request_options is None:
request_options = RequestOptions()
elif isinstance(request_options, dict):
request_options = RequestOptions(request_options)

if client_context:
request_options.client_context = _merge_client_context(
client_context, request_options.client_context
)

return request_options


def _assert_numeric_precision_and_scale(value):
"""
Asserts that input numeric field is within Spanner supported range.
Expand Down
41 changes: 35 additions & 6 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_merge_Transaction_Options,
_merge_client_context,
_merge_request_options,
_validate_client_context,
AtomicCounter,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
Expand All @@ -37,6 +40,7 @@
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
from google.cloud.spanner_v1.types import ClientContext
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30
Expand All @@ -47,9 +51,14 @@ class _BatchBase(_SessionWrapper):

:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made
by this batch.
"""

def __init__(self, session):
def __init__(self, session, client_context=None):
super(_BatchBase, self).__init__(session)

self._mutations: List[Mutation] = []
Expand All @@ -58,6 +67,7 @@ def __init__(self, session):
self.committed = None
"""Timestamp at which the batch was successfully committed."""
self.commit_stats: Optional[CommitResponse.CommitStats] = None
self._client_context = _validate_client_context(client_context)

def insert(self, table, columns, values):
"""Insert one or more new table rows.
Expand Down Expand Up @@ -227,10 +237,14 @@ def commit(
txn_options,
)

client_context = _merge_client_context(
database._instance._client._client_context, self._client_context
)
request_options = _merge_request_options(request_options, client_context)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
Expand Down Expand Up @@ -317,13 +331,25 @@ class MutationGroups(_SessionWrapper):

:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit

:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made
by this mutation group.
"""

def __init__(self, session):
def __init__(self, session, client_context=None):
super(MutationGroups, self).__init__(session)
self._mutation_groups: List[MutationGroup] = []
self.committed: bool = False

if client_context is not None:
if isinstance(client_context, dict):
client_context = ClientContext(client_context)
elif not isinstance(client_context, ClientContext):
raise TypeError("client_context must be a ClientContext or a dict")
self._client_context = client_context

def group(self):
"""Returns a new `MutationGroup` to which mutations can be added."""
mutation_group = BatchWriteRequest.MutationGroup()
Expand Down Expand Up @@ -365,10 +391,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

client_context = _merge_client_context(
database._instance._client._client_context, self._client_context
)
request_options = _merge_request_options(request_options, client_context)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

with trace_call(
name="CloudSpanner.batch_write",
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
from google.cloud.spanner_v1 import __version__
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import DefaultTransactionOptions
from google.cloud.spanner_v1.types import ClientContext
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import _validate_client_context
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.metrics.constants import (
METRIC_EXPORT_INTERVAL_MS,
Expand Down Expand Up @@ -225,6 +227,10 @@ class Client(ClientWithProject):
:param disable_builtin_metrics: (Optional) Default False. Set to True to disable
the Spanner built-in metrics collection and exporting.

:type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext`
or :class:`dict`
:param client_context: (Optional) Client context to use for all requests made by this client.

:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`
"""
Expand All @@ -251,6 +257,7 @@ def __init__(
default_transaction_options: Optional[DefaultTransactionOptions] = None,
experimental_host=None,
disable_builtin_metrics=False,
client_context=None,
):
self._emulator_host = _get_spanner_emulator_host()
self._experimental_host = experimental_host
Expand Down Expand Up @@ -287,6 +294,7 @@ def __init__(

# Environment flag config has higher precedence than application config.
self._query_options = _merge_query_options(query_options, env_query_options)
self._client_context = _validate_client_context(client_context)

if self._emulator_host is not None and (
"http://" in self._emulator_host or "https://" in self._emulator_host
Expand Down
Loading
Loading