Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/replit_river/v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ class Session[HandshakeMetadata]:

# Terminating
_terminating_task: asyncio.Task[None] | None
_closing_waiter: asyncio.Event | None

def __init__(
self,
Expand Down Expand Up @@ -229,7 +228,6 @@ def __init__(

# Terminating
self._terminating_task = None
self._closing_waiter = None

self._start_recv_from_ws()
self._start_buffered_message_sender()
Expand Down Expand Up @@ -393,11 +391,11 @@ async def close(
reason: Exception | None = None,
) -> None:
"""Close the session and all associated streams."""
if self._closing_waiter:
if self._terminating_task:
try:
logger.debug("Session already closing, waiting...")
async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
await self._closing_waiter.wait()
await self._terminating_task
except asyncio.TimeoutError:
logger.warning(
f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
Expand Down Expand Up @@ -436,7 +434,6 @@ async def do_close() -> None:
f"ws: {self._ws}"
)
self._state = SessionState.CLOSING
self._closing_waiter = asyncio.Event()

# We're closing, so we need to wake up...
# ... tasks waiting for connection to be established
Expand Down Expand Up @@ -502,14 +499,11 @@ async def do_close() -> None:
# This will get us GC'd, so this should be the last thing.
self._close_session_callback(self)

# Release waiters, then release the event
self._closing_waiter.set()
self._closing_waiter = None

if self._terminating_task:
return self._terminating_task

return asyncio.create_task(do_close())
self._terminating_task = asyncio.create_task(do_close())
return self._terminating_task

def _start_buffered_message_sender(
self,
Expand Down
16 changes: 14 additions & 2 deletions tests/v2/test_v2_session_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from websockets.asyncio.server import ServerConnection, serve
from websockets.typing import Data

from replit_river.common_session import SessionState
from replit_river.messages import parse_transport_msg
from replit_river.rate_limiter import RateLimiter
from replit_river.rpc import TransportMessage
Expand Down Expand Up @@ -114,14 +115,20 @@ async def test_connect(ws_server: WsServerFixture) -> None:
await connecting


async def test_reconnect(ws_server: WsServerFixture) -> None:
async def test_close_race(ws_server: WsServerFixture) -> None:
(urimeta, recv, conn) = ws_server

callcount = 0

def close_session_callback(_session: Session) -> None:
nonlocal callcount
callcount += 1

session = Session(
server_id="SERVER",
session_id="SESSION1",
transport_options=TransportOptions(),
close_session_callback=lambda _: None,
close_session_callback=close_session_callback,
client_id="CLIENT1",
rate_limiter=_PermissiveRateLimiter(),
uri_and_metadata_factory=urimeta,
Expand All @@ -132,4 +139,9 @@ async def test_reconnect(ws_server: WsServerFixture) -> None:
assert isinstance(msg, TransportMessage)
assert msg.payload["type"] == "HANDSHAKE_REQ"
await session.close()
await session.close()
await session.close()
await session.close()
await connecting
assert session._state == SessionState.CLOSED
assert callcount == 1
Loading