Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def _establish_new_connection(

try:
uri_and_metadata = await self._uri_and_metadata_factory()
ws = await websockets.connect(uri_and_metadata["uri"])
ws = await websockets.connect(uri_and_metadata["uri"], max_size=None)
session_id = (
self.generate_nanoid()
if not old_session
Expand Down
5 changes: 4 additions & 1 deletion src/replit_river/v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,10 @@ async def _do_ensure_connected[HandshakeMetadata](
ws: ClientConnection | None = None
try:
uri_and_metadata = await uri_and_metadata_factory()
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
ws = await websockets.asyncio.client.connect(
uri_and_metadata["uri"],
max_size=None,
)
transition_connecting(ws)

try:
Expand Down
108 changes: 104 additions & 4 deletions tests/v2/test_v2_session_lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
import asyncio
import logging
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict

import msgpack
import nanoid
import pytest
from websockets import ConnectionClosedOK
from websockets import ConnectionClosed, ConnectionClosedOK
from websockets.asyncio.server import ServerConnection, serve
from websockets.typing import Data

from replit_river.common_session import SessionState
from replit_river.messages import parse_transport_msg
from replit_river.rate_limiter import RateLimiter
from replit_river.rpc import TransportMessage
from replit_river.rpc import (
ControlMessageHandshakeRequest,
ControlMessageHandshakeResponse,
HandShakeStatus,
TransportMessage,
)
from replit_river.transport_options import TransportOptions, UriAndMetadata
from replit_river.v2.session import Session
from replit_river.v2.client import Client
from replit_river.v2.session import STREAM_CLOSED_BIT, Session


class _PermissiveRateLimiter(RateLimiter):
Expand Down Expand Up @@ -54,6 +63,8 @@ async def handle(websocket: ServerConnection) -> None:
await recv.put(datagram)
except ConnectionClosedOK:
pass
except ConnectionClosed:
pass

port: int | None = None
if state["ipv4_laddr"]:
Expand All @@ -65,7 +76,10 @@ async def handle(websocket: ServerConnection) -> None:
state["ipv4_laddr"] = pair
serve_forever = asyncio.create_task(server.serve_forever())
yield None
serve_forever.cancel()
server.close()
await server.wait_closed()
# "serve_forever" should always be done after wait_closed finishes
assert serve_forever.done()


@pytest.fixture
Expand Down Expand Up @@ -145,3 +159,89 @@ def close_session_callback(_session: Session) -> None:
await connecting
assert session._state == SessionState.CLOSED
assert callcount == 1


async def test_big_packet(ws_server: WsServerFixture) -> None:
(urimeta, recv, conn) = ws_server

client = Client(
client_id="CLIENT1",
server_id="SERVER",
transport_options=TransportOptions(),
uri_and_metadata_factory=urimeta,
)

connecting = asyncio.create_task(client.ensure_connected())
request_msg = parse_transport_msg(await recv.get())

assert not isinstance(request_msg, str)
assert (serverconn := conn())
handshake_request: ControlMessageHandshakeRequest[None] = (
ControlMessageHandshakeRequest(**request_msg.payload)
)

handshake_resp = ControlMessageHandshakeResponse(
status=HandShakeStatus(
ok=True,
),
)
handshake_request.sessionId

msg = TransportMessage(
from_=request_msg.from_,
to=request_msg.to,
streamId=request_msg.streamId,
controlFlags=0,
id=nanoid.generate(),
seq=0,
ack=0,
payload=handshake_resp.model_dump(),
)
packed = msgpack.packb(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
await serverconn.send(packed)

async def handle_server_messages() -> None:
request_msg = parse_transport_msg(await recv.get())
assert not isinstance(request_msg, str)
msg = TransportMessage(
from_=request_msg.to,
to=request_msg.from_,
streamId=request_msg.streamId,
controlFlags=STREAM_CLOSED_BIT,
id=nanoid.generate(),
seq=0,
ack=0,
payload={
"ok": True,
"payload": {
"big": "a" * (2**20 + 1), # One more than the default max_size
},
},
)

packed = msgpack.packb(
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
)
await serverconn.send(packed)

stream_close_msg = msgpack.unpackb(await recv.get())
assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT

stream_handler = asyncio.create_task(handle_server_messages())

try:
async for datagram in client.send_subscription(
"test", "bigstream", {}, lambda x: x, lambda x: x, lambda x: x
):
print(datagram)
except Exception:
logging.exception("Interrupted")

await client.close()
await connecting

# Ensure we're listening to close messages as well
stream_handler.cancel()
await stream_handler
Loading