Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/aio/query/test_query_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ydb
from ydb import QueryExplainResultFormat
from ydb.aio.query.session import QuerySession
from ydb.connection import EndpointKey


def _check_session_not_ready(session: QuerySession):
Expand Down Expand Up @@ -161,3 +162,82 @@ async def callee(session: QuerySession):
assert "Lookup" in plan_lookup_string
finally:
await pool.execute_with_retries("DROP TABLE test_explain")


class TestAsyncQuerySessionPreferredEndpoint:
def test_endpoint_key_is_none_before_create(self, session: QuerySession):
assert session._endpoint_key is None

@pytest.mark.asyncio
async def test_endpoint_key_is_set_after_create(self, session: QuerySession):
await session.create()
assert session.node_id is not None
assert session._endpoint_key is not None
assert isinstance(session._endpoint_key, EndpointKey)
assert session._endpoint_key.node_id == session.node_id

@pytest.mark.asyncio
async def test_session_uses_preferred_endpoint_on_execute(self, session: QuerySession):
await session.create()
original_driver_call = session._driver

calls = []

async def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return await original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

async with await session.execute("select 1;") as results:
async for _ in results:
pass

assert len(calls) > 0
assert "preferred_endpoint" in calls[0]
assert calls[0]["preferred_endpoint"] is not None
assert calls[0]["preferred_endpoint"].node_id == session.node_id

@pytest.mark.asyncio
async def test_session_uses_preferred_endpoint_on_delete(self, session: QuerySession):
await session.create()
original_driver_call = session._driver

calls = []

async def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return await original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

await session.delete()

assert len(calls) > 0
assert "preferred_endpoint" in calls[0]
assert calls[0]["preferred_endpoint"] is not None
assert calls[0]["preferred_endpoint"].node_id == session.node_id

@pytest.mark.asyncio
async def test_transaction_uses_preferred_endpoint(self, session: QuerySession):
await session.create()
original_driver_call = session._driver

calls = []

async def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return await original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

async with session.transaction() as tx:
async with await tx.execute("select 1;") as results:
async for _ in results:
pass

execute_calls = [c for c in calls if "preferred_endpoint" in c]
assert len(execute_calls) > 0
for call in execute_calls:
assert call["preferred_endpoint"] is not None
assert call["preferred_endpoint"].node_id == session.node_id
76 changes: 76 additions & 0 deletions tests/query/test_query_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ydb import QuerySessionPool
from ydb.query.base import QueryStatsMode, QueryExplainResultFormat
from ydb.query.session import QuerySession
from ydb.connection import EndpointKey


def _check_session_not_ready(session: QuerySession):
Expand Down Expand Up @@ -226,3 +227,78 @@ def callee(session: QuerySession):
assert "Lookup" in plan_lookup_string
finally:
pool.execute_with_retries("DROP TABLE test_explain")


class TestQuerySessionPreferredEndpoint:
def test_endpoint_key_is_none_before_create(self, session: QuerySession):
assert session._endpoint_key is None

def test_endpoint_key_is_set_after_create(self, session: QuerySession):
session.create()
assert session.node_id is not None
assert session._endpoint_key is not None
assert isinstance(session._endpoint_key, EndpointKey)
assert session._endpoint_key.node_id == session.node_id

def test_session_uses_preferred_endpoint_on_execute(self, session: QuerySession):
session.create()
original_driver_call = session._driver

calls = []

def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

with session.execute("select 1;") as results:
for _ in results:
pass

assert len(calls) > 0
assert "preferred_endpoint" in calls[0]
assert calls[0]["preferred_endpoint"] is not None
assert calls[0]["preferred_endpoint"].node_id == session.node_id

def test_session_uses_preferred_endpoint_on_delete(self, session: QuerySession):
session.create()
original_driver_call = session._driver

calls = []

def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

session.delete()

assert len(calls) > 0
assert "preferred_endpoint" in calls[0]
assert calls[0]["preferred_endpoint"] is not None
assert calls[0]["preferred_endpoint"].node_id == session.node_id

def test_transaction_uses_preferred_endpoint(self, session: QuerySession):
session.create()
original_driver_call = session._driver

calls = []

def mock_driver_call(*args, **kwargs):
calls.append(kwargs)
return original_driver_call(*args, **kwargs)

session._driver = mock_driver_call

with session.transaction() as tx:
with tx.execute("select 1;") as results:
for _ in results:
pass

execute_calls = [c for c in calls if "preferred_endpoint" in c]
assert len(execute_calls) > 0
for call in execute_calls:
assert call["preferred_endpoint"] is not None
assert call["preferred_endpoint"].node_id == session.node_id
11 changes: 10 additions & 1 deletion ydb/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .. import _apis, issues, _utilities
from ..settings import BaseRequestSettings
from ..connection import _RpcState as RpcState
from ..connection import _RpcState as RpcState, EndpointKey
from .._grpc.grpcwrapper import common_utils
from .._grpc.grpcwrapper import ydb_query as _ydb_query
from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
Expand Down Expand Up @@ -85,6 +85,12 @@ def node_id(self) -> Optional[int]:
def is_active(self) -> bool:
return self._session_id is not None and not self._closed

@property
def _endpoint_key(self) -> Optional[EndpointKey]:
if self._node_id is None:
return None
return EndpointKey(endpoint=None, node_id=self._node_id)

@property
def is_closed(self) -> bool:
return self._closed
Expand Down Expand Up @@ -142,6 +148,7 @@ def _delete_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQ
wrap_result=wrapper_delete_session,
wrap_args=(self,),
settings=settings,
preferred_endpoint=self._endpoint_key,
)

def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]:
Expand All @@ -150,6 +157,7 @@ def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]:
_apis.QueryService.Stub,
_apis.QueryService.AttachSession,
settings=self._attach_settings,
preferred_endpoint=self._endpoint_key,
)

def _execute_call(
Expand Down Expand Up @@ -189,6 +197,7 @@ def _execute_call(
_apis.QueryService.Stub,
_apis.QueryService.ExecuteQuery,
settings=settings,
preferred_endpoint=self._endpoint_key,
)


Expand Down
4 changes: 4 additions & 0 deletions ydb/query/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxCo
wrap_tx_begin_response,
settings,
(self.session, self._tx_state, self),
preferred_endpoint=self.session._endpoint_key,
)

def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
Expand All @@ -272,6 +273,7 @@ def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxC
wrap_tx_commit_response,
settings,
(self.session, self._tx_state, self),
preferred_endpoint=self.session._endpoint_key,
)

def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
Expand All @@ -285,6 +287,7 @@ def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryT
wrap_tx_rollback_response,
settings,
(self.session, self._tx_state, self),
preferred_endpoint=self.session._endpoint_key,
)

def _execute_call(
Expand Down Expand Up @@ -327,6 +330,7 @@ def _execute_call(
_apis.QueryService.Stub,
_apis.QueryService.ExecuteQuery,
settings=settings,
preferred_endpoint=self.session._endpoint_key,
)

def _move_to_beginned(self, tx_id: str) -> None:
Expand Down
Loading