Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions src/frequenz/client/base/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from typing import Generic, Literal, TypeAlias, TypeVar, overload
from typing import AsyncIterable, Generic, Literal, TypeAlias, TypeVar, overload

import grpc.aio

Expand All @@ -19,10 +19,6 @@
_logger = logging.getLogger(__name__)


RequestT = TypeVar("RequestT")
"""The request type of the stream."""


InputT = TypeVar("InputT")
"""The input type of the stream."""

Expand Down Expand Up @@ -80,20 +76,31 @@ class GrpcStreamBroadcaster(Generic[InputT, OutputT]):

Example:
```python
from typing import Any
from frequenz.client.base import (
GrpcStreamBroadcaster,
StreamFatalError,
StreamRetrying,
StreamStarted,
)
from frequenz.channels import Receiver
from frequenz.channels import Receiver # Assuming Receiver is available

# Dummy async iterable for demonstration
async def async_range(fail_after: int = -1) -> AsyncIterable[int]:
for i in range(10):
if fail_after != -1 and i >= fail_after:
raise grpc.aio.AioRpcError(
code=grpc.StatusCode.UNAVAILABLE,
initial_metadata=grpc.aio.Metadata(),
trailing_metadata=grpc.aio.Metadata(),
details="Simulated error"
)
yield i
await asyncio.sleep(0.1)

async def main():
stub: Any = ... # The gRPC stub
streamer = GrpcStreamBroadcaster(
stream_name="example_stream",
stream_method=stub.MyStreamingMethod,
stream_method=lambda: async_range(fail_after=3),
transform=lambda msg: msg * 2, # transform messages
retry_on_exhausted_stream=False,
)
Expand Down Expand Up @@ -149,7 +156,7 @@ async def consume_data():
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
stream_name: str,
stream_method: Callable[[], grpc.aio.UnaryStreamCall[RequestT, InputT]],
stream_method: Callable[[], AsyncIterable[InputT]],
transform: Callable[[InputT], OutputT],
retry_strategy: retry.Strategy | None = None,
retry_on_exhausted_stream: bool = False,
Expand Down Expand Up @@ -275,22 +282,14 @@ async def _run(self) -> None:

while True:
error: Exception | None = None
first_message_received = False
_logger.info("%s: starting to stream", self._stream_name)
try:
call = self._stream_method()

# We await for the initial metadata before sending a
# StreamStarted event. This is the best indication we have of a
# successful connection without delaying it until the first
# message is received, which might happen a long time after the
# "connection" was established.
await call.initial_metadata()
if self._event_sender:
await self._event_sender.send(StreamStarted())

async for msg in call:
first_message_received = True
try:
transformed = self._transform(msg)
except Exception: # pylint: disable=broad-exception-caught
Expand All @@ -306,9 +305,6 @@ async def _run(self) -> None:
except grpc.aio.AioRpcError as err:
error = err

if first_message_received:
self._retry_strategy.reset()

if error is None and not self._retry_on_exhausted_stream:
_logger.info(
"%s: connection closed, stream exhausted", self._stream_name
Expand Down
120 changes: 11 additions & 109 deletions tests/streaming/test_grpc_stream_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import asyncio
import logging
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack
from datetime import timedelta
from unittest import mock
Expand Down Expand Up @@ -56,18 +56,6 @@ def make_error() -> grpc.aio.AioRpcError:
)


def unary_stream_call_mock(
name: str, side_effect: Callable[[], AsyncIterator[object]]
) -> mock.MagicMock:
"""Create a new mocked unary stream call."""
# Sadly we can't use spec here because grpc.aio.UnaryStreamCall seems to be
# dynamic and mock doesn't find `__aiter__` in it when creating the spec.
call_mock = mock.MagicMock(name=name)
call_mock.__aiter__.side_effect = side_effect
call_mock.initial_metadata = mock.AsyncMock()
return call_mock


@pytest.fixture
async def ok_helper(
no_retry: mock.MagicMock, # pylint: disable=redefined-outer-name
Expand All @@ -83,15 +71,9 @@ async def asynciter() -> AsyncIterator[int]:
yield i
await asyncio.sleep(0) # Yield control to the event loop

rpc_mock = mock.MagicMock(
name="ok_helper_method",
side_effect=lambda: unary_stream_call_mock(
"ok_helper_unary_stream_call", asynciter
),
)
helper = streaming.GrpcStreamBroadcaster(
stream_name="test_helper",
stream_method=rpc_mock,
stream_method=asynciter,
transform=_transformer,
retry_strategy=no_retry,
retry_on_exhausted_stream=retry_on_exhausted_stream,
Expand Down Expand Up @@ -140,31 +122,6 @@ async def __anext__(self) -> int:
raise self._error
return self._current

async def initial_metadata(self) -> None:
"""Mock initial metadata method."""
if self._current >= self._num_successes:
raise self._error


def erroring_rpc_mock(
error: Exception,
ready_event: asyncio.Event,
*,
num_successes: int = 0,
should_error_on_initial_metadata_too: bool = False,
) -> mock.MagicMock:
"""Fixture for mocked erroring rpc."""
# In this case we want to keep the state of the erroring call
erroring_iter = _ErroringAsyncIter(error, ready_event, num_successes=num_successes)
call_mock = unary_stream_call_mock(
"erroring_unary_stream_call", lambda: erroring_iter
)
if should_error_on_initial_metadata_too:
call_mock.initial_metadata.side_effect = erroring_iter.initial_metadata
rpc_mock = mock.MagicMock(name="erroring_rpc", return_value=call_mock)

return rpc_mock


@pytest.mark.parametrize("retry_on_exhausted_stream", [True])
async def test_streaming_success_retry_on_exhausted(
Expand Down Expand Up @@ -256,7 +213,7 @@ async def test_streaming_error( # pylint: disable=too-many-arguments

helper = streaming.GrpcStreamBroadcaster(
stream_name="test_helper",
stream_method=erroring_rpc_mock(
stream_method=lambda: _ErroringAsyncIter(
error, receiver_ready_event, num_successes=successes
),
transform=_transformer,
Expand Down Expand Up @@ -316,9 +273,7 @@ async def asynciter() -> AsyncIterator[int]:

rpc_mock = mock.MagicMock(
name="ok_helper_method",
side_effect=lambda: unary_stream_call_mock(
"ok_helper_unary_stream_call", asynciter
),
side_effect=asynciter,
)

helper = streaming.GrpcStreamBroadcaster(
Expand Down Expand Up @@ -388,7 +343,7 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
mock_retry.get_progress.return_value = "mock progress"
helper = streaming.GrpcStreamBroadcaster(
stream_name="test_helper",
stream_method=erroring_rpc_mock(error, receiver_ready_event),
stream_method=lambda: _ErroringAsyncIter(error, receiver_ready_event),
transform=_transformer,
retry_strategy=mock_retry,
)
Expand Down Expand Up @@ -422,18 +377,10 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
]


@pytest.mark.parametrize(
"include_events", [True, False], ids=["with_events", "without_events"]
)
@pytest.mark.parametrize(
"error_in_metadata",
[True, False],
ids=["with_initial_metadata_error", "iterator_error_only"],
)
@pytest.mark.parametrize("include_events", [True, False])
async def test_messages_on_retry(
receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name
include_events: bool,
error_in_metadata: bool,
) -> None:
"""Test that messages are sent on retry."""
# We need to use a specific instance for all the test here because 2 errors created
Expand All @@ -443,11 +390,8 @@ async def test_messages_on_retry(

helper = streaming.GrpcStreamBroadcaster(
stream_name="test_helper",
stream_method=erroring_rpc_mock(
error,
receiver_ready_event,
num_successes=2,
should_error_on_initial_metadata_too=error_in_metadata,
stream_method=lambda: _ErroringAsyncIter(
error, receiver_ready_event, num_successes=2
),
transform=_transformer,
retry_strategy=retry.LinearBackoff(limit=1, interval=0.0, jitter=0.0),
Expand All @@ -466,57 +410,15 @@ async def test_messages_on_retry(
assert items == [
"transformed_0",
"transformed_1",
"transformed_0",
"transformed_1",
]
if include_events:
extra_events: list[StreamEvent] = []
if not error_in_metadata:
extra_events.append(StreamStarted())
assert events == [
StreamStarted(),
StreamRetrying(timedelta(seconds=0.0), error),
*extra_events,
StreamStarted(),
StreamFatalError(error),
]
else:
assert events == []


@mock.patch(
"frequenz.client.base.streaming.asyncio.sleep", autospec=True, wraps=asyncio.sleep
)
async def test_retry_reset(
mock_sleep: mock.MagicMock,
receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name
) -> None:
"""Test that retry strategy resets after a successful start."""
# Use a mock retry strategy so we can assert reset() was called.
mock_retry = mock.MagicMock(spec=retry.Strategy)
# Simulate one retry interval then exhaustion.
mock_retry.next_interval.side_effect = [0.01, 0.01, None]
mock_retry.copy.return_value = mock_retry
mock_retry.get_progress.return_value = "mock progress"

# The rpc will yield one message then raise, so the strategy should be reset
# after the successful start (i.e. after first message received).
helper = streaming.GrpcStreamBroadcaster(
stream_name="test_helper",
stream_method=erroring_rpc_mock(
make_error(), receiver_ready_event, num_successes=1
),
transform=_transformer,
retry_strategy=mock_retry,
retry_on_exhausted_stream=True,
)

async with AsyncExitStack() as stack:
stack.push_async_callback(helper.stop)

receiver = helper.new_receiver()
receiver_ready_event.set()
_ = await _split_message(receiver)

# reset() should have been called once after the successful start.
mock_retry.reset.assert_called_once()

# One sleep for the single retry interval.
mock_sleep.assert_has_calls([mock.call(0.01), mock.call(0.01)])
Loading