Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ async def send_subscription(
) from e
except Exception as e:
raise e
finally:
output.close()

async def send_stream(
self,
Expand Down Expand Up @@ -335,6 +337,8 @@ async def _encode_stream() -> None:
) from e
except Exception as e:
raise e
finally:
output.close()

async def send_close_stream(
self,
Expand Down
6 changes: 5 additions & 1 deletion replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ async def _add_msg_to_stream(
return
try:
await stream.put(msg.payload)
except (RuntimeError, ChannelClosed) as e:
except ChannelClosed:
# The client is no longer interested in this stream,
# just drop the message.
pass
except RuntimeError as e:
raise InvalidMessageException(e) from e

async def _remove_acked_messages_in_buffer(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/common_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def upload_handler(

basic_upload: HandlerMapping = {
("test_service", "upload_method"): (
"upload",
"upload-stream",
upload_method_handler(upload_handler, deserialize_request, serialize_response),
),
}
Expand All @@ -54,7 +54,7 @@ async def subscription_handler(

basic_subscription: HandlerMapping = {
("test_service", "subscription_method"): (
"subscription",
"subscription-stream",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping
from typing import Any, Literal, Mapping

import nanoid
import pytest
Expand All @@ -16,7 +16,8 @@
# Modular fixtures
pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"]

HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]]


def transport_message(
Expand Down
57 changes: 56 additions & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import AsyncGenerator

import pytest
from grpc.aio import grpc

from replit_river.client import Client
from replit_river.error_schema import RiverError
from replit_river.rpc import subscription_method_handler
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from tests.common_handlers import (
basic_rpc_method,
Expand All @@ -14,9 +16,12 @@
basic_upload,
)
from tests.conftest import (
HandlerMapping,
deserialize_error,
deserialize_request,
deserialize_response,
serialize_request,
serialize_response,
)


Expand Down Expand Up @@ -101,6 +106,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
async def test_subscription_method(client: Client) -> None:
messages = []
async for response in client.send_subscription(
"test_service",
"subscription_method",
Expand All @@ -110,7 +116,8 @@ async def test_subscription_method(client: Client) -> None:
deserialize_error,
):
assert isinstance(response, str)
assert "Subscription message" in response
messages.append(response)
assert messages == [f"Subscription message {i} for Bob" for i in range(5)]


@pytest.mark.asyncio
Expand Down Expand Up @@ -213,3 +220,51 @@ async def stream_data() -> AsyncGenerator[str, None]:
"Stream response for Stream Data 1",
"Stream response for Stream Data 2",
]


async def flood_subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(1024):
yield f"Subscription message {i} for {request}"


flood_subscription: HandlerMapping = {
("test_service", "flood_subscription_method"): (
"subscription-stream",
subscription_method_handler(
flood_subscription_handler, deserialize_request, serialize_response
),
),
}


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_rpc_method, **flood_subscription}])
async def test_ignore_flood_subscription(client: Client) -> None:
sub = client.send_subscription(
"test_service",
"flood_subscription_method",
"Initial Subscription Data",
serialize_request,
deserialize_response,
deserialize_error,
)

# read one entry to start the subscription
await sub.__anext__()
# close the subscription so we can signal that we're not
# interested in the rest of the subscription.
await sub.aclose()

# ensure that subsequent RPCs still work
response = await client.send_rpc(
"test_service",
"rpc_method",
"Alice",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == "Hello, Alice!"
Loading