Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@
"utils.imports": ["ImportDict", "ImportVar"],
"utils.misc": ["run_in_thread"],
"utils.serializers": ["serializer"],
"utils.token_manager": ["get_token_manager"],
"vars": ["Var", "field", "Field"],
}

Expand Down
8 changes: 7 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,7 +2080,11 @@ async def on_connect(self, sid: str, environ: dict):
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", ""))
token_list = query_params.get("token", [])
if token_list:
await self.link_token_to_sid(sid, token_list[0])
token = token_list[0]
await self.link_token_to_sid(sid, token)
# Notify lifecycle watchers that this token/sid has connected.
actual_token = self._token_manager.sid_to_token.get(sid, token)
self._token_manager._notify_connect(actual_token, sid)
else:
console.warn(f"No token provided in connection for session {sid}")

Expand All @@ -2102,6 +2106,8 @@ def on_disconnect(self, sid: str) -> asyncio.Task | None:
# Get token before cleaning up
disconnect_token = self.sid_to_token.get(sid)
if disconnect_token:
# Notify lifecycle watchers before cleanup removes the mappings.
self._token_manager._notify_disconnect(disconnect_token, sid)
# Use async cleanup through token manager
task = asyncio.create_task(
self._token_manager.disconnect_token(disconnect_token, sid),
Expand Down
165 changes: 165 additions & 0 deletions reflex/utils/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from redis.asyncio import Redis


class _TokenNotConnectedError(Exception):
"""Raised when a token is not connected."""


def _get_new_token() -> str:
"""Generate a new unique token.

Expand Down Expand Up @@ -56,6 +60,10 @@ def __init__(self):
self.token_to_socket: dict[str, SocketRecord] = {}
# Keep a mapping between socket ID and client token.
self.sid_to_token: dict[str, str] = {}
# Lifecycle events for connect/disconnect notifications.
self._token_disconnect_events: dict[str, list[asyncio.Event]] = {}
self._sid_disconnect_events: dict[str, list[asyncio.Event]] = {}
self._token_connect_events: dict[str, list[asyncio.Event]] = {}

@property
def token_to_sid(self) -> MappingProxyType[str, str]:
Expand Down Expand Up @@ -124,6 +132,145 @@ async def disconnect_all(self):
for token, sid in token_sid_pairs:
await self.disconnect_token(token, sid)

def _notify_connect(self, token: str, sid: str) -> None:
"""Notify lifecycle watchers that a token/sid has connected.

Args:
token: The client token.
sid: The Socket.IO session ID.
"""
for event in self._token_connect_events.pop(token, []):
event.set()

def _notify_disconnect(self, token: str, sid: str) -> None:
"""Notify lifecycle watchers that a token/sid has disconnected.

Args:
token: The client token.
sid: The Socket.IO session ID.
"""
for event in self._token_disconnect_events.pop(token, []):
event.set()
for event in self._sid_disconnect_events.pop(sid, []):
event.set()

async def session_is_connected(self, sid: str) -> AsyncIterator[str]:
"""Yield the client token, then block until the session disconnects.

Yields the client token once, then suspends until the session
disconnects. Use with ``async for`` or ``contextlib.aclosing``.

Args:
sid: The Socket.IO session ID.

Yields:
The client token associated with the session.

Raises:
_TokenNotConnectedError: If the session is not currently connected.
"""
token = self.sid_to_token.get(sid)
if token is None:
raise _TokenNotConnectedError(
f"Session {sid!r} is not currently connected."
)
disconnect_event = asyncio.Event()
self._sid_disconnect_events.setdefault(sid, []).append(disconnect_event)
try:
yield token
await disconnect_event.wait()
finally:
events = self._sid_disconnect_events.get(sid, [])
if disconnect_event in events:
events.remove(disconnect_event)

async def token_is_connected(self, client_token: str) -> AsyncIterator[str]:
"""Yield the session ID, then block until the token disconnects.

Yields the session ID once, then suspends until the token
disconnects. Use with ``async for`` or ``contextlib.aclosing``.

Args:
client_token: The client token.

Yields:
The session ID associated with the token.

Raises:
_TokenNotConnectedError: If the token is not currently connected.
"""
socket_record = self.token_to_socket.get(client_token)
if socket_record is None:
raise _TokenNotConnectedError(
f"Token {client_token!r} is not currently connected."
)
disconnect_event = asyncio.Event()
self._token_disconnect_events.setdefault(client_token, []).append(
disconnect_event
)
try:
yield socket_record.sid
await disconnect_event.wait()
finally:
events = self._token_disconnect_events.get(client_token, [])
if disconnect_event in events:
events.remove(disconnect_event)

def when_session_disconnects(self, sid: str) -> asyncio.Event:
"""Return an asyncio.Event that is set when the session disconnects.

Args:
sid: The Socket.IO session ID.

Returns:
An asyncio.Event that will be set on disconnect.
"""
event = asyncio.Event()
if sid not in self.sid_to_token:
# Already disconnected, set immediately.
event.set()
else:
self._sid_disconnect_events.setdefault(sid, []).append(event)
return event

def when_token_disconnects(self, client_token: str) -> asyncio.Event:
"""Return an asyncio.Event that is set when the token disconnects.

Args:
client_token: The client token.

Returns:
An asyncio.Event that will be set on disconnect.
"""
event = asyncio.Event()
if client_token not in self.token_to_socket:
# Already disconnected, set immediately.
event.set()
else:
self._token_disconnect_events.setdefault(client_token, []).append(
event
)
return event

def when_token_connects(self, client_token: str) -> asyncio.Event:
"""Return an asyncio.Event that is set when the token connects.

Args:
client_token: The client token.

Returns:
An asyncio.Event that will be set on connect.
"""
event = asyncio.Event()
if client_token in self.token_to_socket:
# Already connected, set immediately.
event.set()
else:
self._token_connect_events.setdefault(client_token, []).append(
event
)
return event


class LocalTokenManager(TokenManager):
"""Token manager using local in-memory dictionaries (single worker)."""
Expand Down Expand Up @@ -464,3 +611,21 @@ async def emit_lost_and_found(
else:
return True
return False


def get_token_manager() -> TokenManager:
"""Get the token manager for the currently running app.

Returns:
The active TokenManager instance.

Raises:
RuntimeError: If the app or event namespace is not initialized.
"""
app_mod = prerequisites.get_and_validate_app()
app = app_mod.app
event_namespace = app.event_namespace
if event_namespace is None:
msg = "Event namespace is not initialized. Is the app running?"
raise RuntimeError(msg)
return event_namespace._token_manager
Loading