Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
11fcf17
Remove noqa again, it was not necessary
blast-hardcheese Mar 18, 2025
6aa87be
Bubble state out of heartbeat
blast-hardcheese Mar 18, 2025
b0f989b
Break out heartbeat lifecycle
blast-hardcheese Mar 18, 2025
cfc25ac
Break out "ServerSession" type
blast-hardcheese Mar 18, 2025
bd36123
Flattening Transport into ClientTransport and ServerTransport
blast-hardcheese Mar 18, 2025
1c9e76d
Disambiguate between builtin type
blast-hardcheese Mar 19, 2025
034005f
Split serve() functionality between client and server
blast-hardcheese Mar 19, 2025
6e1b781
Remove handlers from Client*
blast-hardcheese Mar 19, 2025
42ead41
Strip is_server from Session __init__
blast-hardcheese Mar 19, 2025
c8ace77
Adding __init__ to ClientSession
blast-hardcheese Mar 19, 2025
7f0c323
Remove is_server
blast-hardcheese Mar 19, 2025
2db0827
Moving more fields from session to the specialized classes
blast-hardcheese Mar 19, 2025
a68b8ff
Resolving circular import
blast-hardcheese Mar 19, 2025
403a447
Bubble state out of check_to_close_session
blast-hardcheese Mar 19, 2025
b3952d2
Moving add_msg_to_stream out
blast-hardcheese Mar 19, 2025
8d9161e
Moving send_responses_from_output_stream to server_session
blast-hardcheese Mar 19, 2025
36aee31
Turns out _send_buffered_messages was only used in one place
blast-hardcheese Mar 19, 2025
628a8ae
Unused
blast-hardcheese Mar 19, 2025
a426353
Inline
blast-hardcheese Mar 19, 2025
f9a33f3
Inline update_bookkeeping
blast-hardcheese Mar 19, 2025
46afc33
Unnest ExpectedSessionState constructor
blast-hardcheese Mar 19, 2025
3bcfefc
Inlining no-longer-invariant _sessions access
blast-hardcheese Mar 19, 2025
68224c7
Thank you for your service
blast-hardcheese Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/replit_river/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .client import Client
from .error_schema import RiverError
from .rpc import (
GenericRpcHandler,
GenericRpcHandlerBuilder,
GrpcContext,
rpc_method_handler,
stream_method_handler,
Expand All @@ -15,7 +15,7 @@
"Server",
"GrpcContext",
"RiverError",
"GenericRpcHandler",
"GenericRpcHandlerBuilder",
"rpc_method_handler",
"subscription_method_handler",
"upload_method_handler",
Expand Down
140 changes: 138 additions & 2 deletions src/replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import logging
from collections.abc import AsyncIterable
from datetime import timedelta
from typing import Any, AsyncGenerator, Callable
from typing import Any, AsyncGenerator, Callable, Coroutine

import nanoid # type: ignore
import websockets
from aiochannel import Channel
from aiochannel.errors import ChannelClosed
from opentelemetry.trace import Span
from websockets.exceptions import ConnectionClosed

from replit_river.common_session import add_msg_to_stream
from replit_river.error_schema import (
ERROR_CODE_CANCEL,
ERROR_CODE_STREAM_CLOSED,
Expand All @@ -17,10 +20,20 @@
StreamClosedRiverServiceException,
exception_from_message,
)
from replit_river.messages import (
FailedSendingMessageException,
parse_transport_msg,
)
from replit_river.seq_manager import (
IgnoreMessageException,
InvalidMessageException,
OutOfOrderMessageException,
)
from replit_river.session import Session
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions

from .rpc import (
ACK_BIT,
STREAM_CLOSED_BIT,
STREAM_OPEN_BIT,
ErrorType,
Expand All @@ -33,6 +46,129 @@


class ClientSession(Session):
def __init__(
self,
transport_id: str,
to_id: str,
session_id: str,
websocket: websockets.WebSocketCommonProtocol,
transport_options: TransportOptions,
close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]],
retry_connection_callback: (
Callable[
[],
Coroutine[Any, Any, Any],
]
| None
) = None,
) -> None:
super().__init__(
transport_id=transport_id,
to_id=to_id,
session_id=session_id,
websocket=websocket,
transport_options=transport_options,
close_session_callback=close_session_callback,
retry_connection_callback=retry_connection_callback,
)

async def do_close_websocket() -> None:
await self.close_websocket(
self._ws_wrapper,
should_retry=True,
)
await self._begin_close_session_countdown()

self._setup_heartbeats_task(do_close_websocket)

async def start_serve_responses(self) -> None:
self._task_manager.create_task(self.serve())

async def serve(self) -> None:
"""Serve messages from the websocket."""
self._reset_session_close_countdown()
try:
try:
await self._handle_messages_from_ws()
except ConnectionClosed:
if self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())

await self._begin_close_session_countdown()
logger.debug("ConnectionClosed while serving", exc_info=True)
except FailedSendingMessageException:
# Expected error if the connection is closed.
logger.debug(
"FailedSendingMessageException while serving", exc_info=True
)
except Exception:
logger.exception("caught exception at message iterator")
except ExceptionGroup as eg:
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
if unhandled:
raise ExceptionGroup(
"Unhandled exceptions on River server", unhandled.exceptions
)

async def _handle_messages_from_ws(self) -> None:
logger.debug(
"%s start handling messages from ws %s",
"client",
self._ws_wrapper.id,
)
try:
ws_wrapper = self._ws_wrapper
async for message in ws_wrapper.ws:
try:
if not await ws_wrapper.is_open():
# We should not process messages if the websocket is closed.
break
msg = parse_transport_msg(message, self._transport_options)

logger.debug(f"{self._transport_id} got a message %r", msg)

# Update bookkeeping
await self._seq_manager.check_seq_and_update(msg)
await self._buffer.remove_old_messages(
self._seq_manager.receiver_ack,
)
self._reset_session_close_countdown()

if msg.controlFlags & ACK_BIT != 0:
continue
async with self._stream_lock:
stream = self._streams.get(msg.streamId, None)
if msg.controlFlags & STREAM_OPEN_BIT == 0:
if not stream:
logger.warning("no stream for %s", msg.streamId)
raise IgnoreMessageException(
"no stream for message, ignoring"
)
await add_msg_to_stream(msg, stream)
else:
raise InvalidMessageException(
"Client should not receive stream open bit"
)

if msg.controlFlags & STREAM_CLOSED_BIT != 0:
if stream:
stream.close()
async with self._stream_lock:
del self._streams[msg.streamId]
except IgnoreMessageException:
logger.debug("Ignoring transport message", exc_info=True)
continue
except OutOfOrderMessageException:
logger.exception("Out of order message, closing connection")
await ws_wrapper.close()
return
except InvalidMessageException:
logger.exception("Got invalid transport message, closing session")
await self.close()
return
except ConnectionClosed as e:
raise e

async def send_rpc(
self,
service_name: str,
Expand Down
73 changes: 50 additions & 23 deletions src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Generic
from typing import Generic, assert_never

import nanoid
import websockets
from pydantic import ValidationError
from websockets import (
Expand Down Expand Up @@ -36,7 +37,7 @@
IgnoreMessageException,
InvalidMessageException,
)
from replit_river.transport import Transport
from replit_river.session import Session
from replit_river.transport_options import (
HandshakeMetadataType,
TransportOptions,
Expand All @@ -46,19 +47,21 @@
logger = logging.getLogger(__name__)


class ClientTransport(Transport, Generic[HandshakeMetadataType]):
class ClientTransport(Generic[HandshakeMetadataType]):
_sessions: dict[str, ClientSession]

def __init__(
self,
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
client_id: str,
server_id: str,
transport_options: TransportOptions,
):
super().__init__(
transport_id=client_id,
transport_options=transport_options,
is_server=False,
)
self._sessions = {}
self._transport_id = client_id
self._transport_options = transport_options
self._session_lock = asyncio.Lock()

self._uri_and_metadata_factory = uri_and_metadata_factory
self._client_id = client_id
self._server_id = server_id
Expand All @@ -68,6 +71,24 @@ def __init__(
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

async def _close_all_sessions(self) -> None:
sessions = self._sessions.values()
logger.info(
f"start closing sessions {self._transport_id}, number sessions : "
f"{len(sessions)}"
)
sessions_to_close = list(sessions)

# closing sessions requires access to the session lock, so we need to close
# them one by one to be safe
for session in sessions_to_close:
await session.close()

logger.info(f"Transport closed {self._transport_id}")

def generate_nanoid(self) -> str:
return str(nanoid.generate())

async def close(self) -> None:
self._rate_limiter.close()
await self._close_all_sessions()
Expand Down Expand Up @@ -201,13 +222,11 @@ async def _create_new_session(
session_id=hs_request.sessionId,
websocket=new_ws,
transport_options=self._transport_options,
is_server=False,
close_session_callback=self._delete_session,
retry_connection_callback=self._retry_connection,
handlers={},
)

self._set_session(new_session)
self._sessions[new_session._to_id] = new_session
await new_session.start_serve_responses()
return new_session

Expand Down Expand Up @@ -297,24 +316,27 @@ async def _establish_handshake(
ControlMessageHandshakeResponse,
]:
try:
expectedSessionState: ExpectedSessionState
match old_session:
case None:
expectedSessionState = ExpectedSessionState(
nextExpectedSeq=0,
nextSentSeq=0,
)
case ClientSession():
expectedSessionState = ExpectedSessionState(
nextExpectedSeq=await old_session.get_next_expected_seq(),
nextSentSeq=await old_session.get_next_sent_seq(),
)
case other:
assert_never(other)
handshake_request = await self._send_handshake_request(
transport_id=transport_id,
to_id=to_id,
session_id=session_id,
handshake_metadata=handshake_metadata,
websocket=websocket,
expected_session_state=ExpectedSessionState(
nextExpectedSeq=(
await old_session.get_next_expected_seq()
if old_session is not None
else 0
),
nextSentSeq=(
await old_session.get_next_sent_seq()
if old_session is not None
else 0
),
),
expected_session_state=expectedSessionState,
)
except FailedSendingMessageException as e:
raise RiverException(
Expand Down Expand Up @@ -352,3 +374,8 @@ async def _establish_handshake(
+ f"{handshake_response.status.reason}",
)
return handshake_request, handshake_response

async def _delete_session(self, session: Session) -> None:
async with self._session_lock:
if session._to_id in self._sessions:
del self._sessions[session._to_id]
1 change: 0 additions & 1 deletion src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@

FILE_HEADER = dedent(
"""\
# ruff: noqa
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
Expand Down
2 changes: 1 addition & 1 deletion src/replit_river/codegen/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def add_{service.name}Servicer_to_server(
) -> None:
rpc_method_handlers: Mapping[
tuple[str, str],
tuple[str, river.GenericRpcHandler]
tuple[str, river.GenericRpcHandlerBuilder]
] = {{
"""
),
Expand Down
Loading
Loading