23 changes: 11 additions & 12 deletions reflex/utils/token_manager.py
@@ -4,12 +4,12 @@

 import asyncio
 import dataclasses
-import json
+import pickle
 import uuid
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Callable, Coroutine
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, ClassVar

 from reflex.istate.manager.redis import StateManagerRedis
 from reflex.state import BaseState, StateUpdate
@@ -42,7 +42,7 @@ class LostAndFoundRecord:
     """Record for a StateUpdate for a token with its socket on another instance."""

     token: str
-    update: dict[str, Any]
+    update: StateUpdate


 class TokenManager(ABC):
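Note: LostAndFoundRecord now carries the StateUpdate object itself instead of a dict, which lines up with the switch from JSON to pickle elsewhere in this diff. A minimal sketch of the difference, using stand-in dataclasses rather than the real StateUpdate/LostAndFoundRecord:

import dataclasses
import json
import pickle


@dataclasses.dataclass
class FakeUpdate:  # stand-in for StateUpdate
    delta: dict


@dataclasses.dataclass
class FakeRecord:  # stand-in for LostAndFoundRecord
    token: str
    update: FakeUpdate


record = FakeRecord(token="t1", update=FakeUpdate(delta={"count": 1}))

# JSON path: dataclasses.asdict() recursively flattens nested dataclasses,
# so the update comes back as a plain dict and its type is lost.
via_json = FakeRecord(**json.loads(json.dumps(dataclasses.asdict(record))))
assert isinstance(via_json.update, dict)

# Pickle path: the whole record round-trips with its nested object intact.
via_pickle = pickle.loads(pickle.dumps(record))
assert isinstance(via_pickle.update, FakeUpdate)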
@@ -328,7 +328,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
         try:
             await self.redis.set(
                 redis_key,
-                json.dumps(dataclasses.asdict(socket_record)),
+                pickle.dumps(socket_record),
                 ex=self.token_expiration,
             )
         except Exception as e:
@@ -386,8 +386,8 @@ async def _subscribe_lost_and_found_updates(
         )
         async for message in pubsub.listen():
             if message["type"] == "pmessage":
-                record = LostAndFoundRecord(**json.loads(message["data"].decode()))
-                await emit_update(StateUpdate(**record.update), record.token)
+                record = pickle.loads(message["data"])
+                await emit_update(record.update, record.token)

     def ensure_lost_and_found_task(
         self,
@@ -424,10 +424,9 @@ async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:

         redis_key = self._get_redis_key(token)
         try:
-            record_json = await self.redis.get(redis_key)
-            if record_json:
-                record_data = json.loads(record_json)
-                socket_record = SocketRecord(**record_data)
+            record_pkl = await self.redis.get(redis_key)
+            if record_pkl:
+                socket_record = pickle.loads(record_pkl)
                 self.token_to_socket[token] = socket_record
                 self.sid_to_token[socket_record.sid] = token
                 return socket_record.instance_id
@@ -454,11 +453,11 @@ async def emit_lost_and_found(
         owner_instance_id = await self._get_token_owner(token)
         if owner_instance_id is None:
             return False
-        record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update))
+        record = LostAndFoundRecord(token=token, update=update)
         try:
             await self.redis.publish(
                 f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
-                json.dumps(dataclasses.asdict(record)),
+                pickle.dumps(record),
             )
         except Exception as e:
             console.error(f"Redis error publishing lost and found delta: {e}")
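Note: emit_lost_and_found publishes the pickled record, and _subscribe_lost_and_found_updates (above) unpickles it from the pub/sub message, whose "data" field arrives as raw bytes. A rough self-contained sketch of that round trip, assuming a local Redis and the redis-py asyncio client; the channel name and record type here are invented for the demo, and connection cleanup is omitted for brevity:

import asyncio
import dataclasses
import pickle

import redis.asyncio as aioredis


@dataclasses.dataclass
class DemoRecord:  # stand-in for LostAndFoundRecord
    token: str
    payload: dict


async def main() -> None:
    client = aioredis.Redis()
    pubsub = client.pubsub()
    await pubsub.psubscribe("channel:lost_and_found_demo*")

    # Publisher side: ship the whole record as pickled bytes.
    await client.publish(
        "channel:lost_and_found_demo_instance1",
        pickle.dumps(DemoRecord(token="t1", payload={"count": 1})),
    )

    # Subscriber side: pattern messages carry bytes that pickle.loads accepts directly.
    async for message in pubsub.listen():
        if message["type"] == "pmessage":
            record = pickle.loads(message["data"])
            print(record.token, record.payload)
            break


asyncio.run(main())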
21 changes: 10 additions & 11 deletions tests/integration/test_connection_banner.py
@@ -1,5 +1,6 @@
 """Test case for displaying the connection banner when the websocket drops."""

+import pickle
 from collections.abc import Generator

 import pytest
@@ -10,7 +11,7 @@
 from reflex.environment import environment
 from reflex.istate.manager.redis import StateManagerRedis
 from reflex.testing import AppHarness, WebDriver
-from reflex.utils.token_manager import RedisTokenManager
+from reflex.utils.token_manager import RedisTokenManager, SocketRecord

 from .utils import SessionStorage

@@ -166,11 +167,10 @@ async def test_connection_banner(connection_banner: AppHarness):
     sid_before = app_token_manager.token_to_sid[token]
     if isinstance(connection_banner.state_manager, StateManagerRedis):
         assert isinstance(app_token_manager, RedisTokenManager)
-        assert (
-            await connection_banner.state_manager.redis.get(
-                app_token_manager._get_redis_key(token)
-            )
-            == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode()
+        assert await connection_banner.state_manager.redis.get(
+            app_token_manager._get_redis_key(token)
+        ) == pickle.dumps(
+            SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before)
         )

     delay_button = driver.find_element(By.ID, "delay")
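Note: these integration assertions compare the raw bytes stored in Redis against pickle.dumps of a freshly built SocketRecord. That only works because pickling equal records of this simple shape is byte-for-byte deterministic; a minimal self-contained check, using a stand-in dataclass rather than the real SocketRecord:

import dataclasses
import pickle


@dataclasses.dataclass
class FakeSocketRecord:  # stand-in with the same shape as SocketRecord
    instance_id: str
    sid: str


a = pickle.dumps(FakeSocketRecord(instance_id="abc", sid="sid1"))
b = pickle.dumps(FakeSocketRecord(instance_id="abc", sid="sid1"))

# Equal field values pickle to identical bytes, so byte-level comparison is stable,
# and the stored value round-trips back to an equal record.
assert a == b
assert pickle.loads(a) == FakeSocketRecord(instance_id="abc", sid="sid1")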
@@ -226,11 +226,10 @@ async def test_connection_banner(connection_banner: AppHarness):
     assert sid_before != sid_after
     if isinstance(connection_banner.state_manager, StateManagerRedis):
         assert isinstance(app_token_manager, RedisTokenManager)
-        assert (
-            await connection_banner.state_manager.redis.get(
-                app_token_manager._get_redis_key(token)
-            )
-            == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode()
+        assert await connection_banner.state_manager.redis.get(
+            app_token_manager._get_redis_key(token)
+        ) == pickle.dumps(
+            SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after)
         )

     # Count should have incremented after coming back up
41 changes: 38 additions & 3 deletions tests/units/utils/test_token_manager.py
@@ -1,7 +1,7 @@
 """Unit tests for TokenManager implementations."""

 import asyncio
-import json
+import pickle
 import time
 from collections.abc import Callable, Generator
 from contextlib import asynccontextmanager
@@ -11,6 +11,7 @@

 from reflex import config
 from reflex.app import EventNamespace
+from reflex.istate.data import RouterData
 from reflex.state import StateUpdate
 from reflex.utils.token_manager import (
     LocalTokenManager,
@@ -300,7 +301,7 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis):
         )
         mock_redis.set.assert_called_once_with(
             f"token_manager_socket_record_{token}",
-            json.dumps({"instance_id": manager.instance_id, "sid": sid}),
+            pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)),
             ex=3600,
         )
         assert manager.token_to_socket[token].sid == sid
@@ -347,7 +348,7 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis):
         )
         mock_redis.set.assert_called_once_with(
             f"token_manager_socket_record_{result}",
-            json.dumps({"instance_id": manager.instance_id, "sid": sid}),
+            pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)),
             ex=3600,
         )
         assert manager.token_to_sid[result] == sid
@@ -670,3 +671,37 @@ async def test_redis_token_manager_lost_and_found(
     emit2_mock.assert_not_called()
     emit1_mock.assert_called_once()
     emit1_mock.reset_mock()
+
+
+@pytest.mark.usefixtures("redis_url")
+@pytest.mark.asyncio
+async def test_redis_token_manager_lost_and_found_router_data(
+    event_namespace_factory: Callable[[], EventNamespace],
+):
+    """Updates emitted for lost and found tokens should serialize properly.
+
+    Args:
+        event_namespace_factory: Factory fixture for EventNamespace instances.
+    """
+    event_namespace1 = event_namespace_factory()
+    emit1_mock: Mock = event_namespace1.emit  # pyright: ignore[reportAssignmentType]
+    event_namespace2 = event_namespace_factory()
+    emit2_mock: Mock = event_namespace2.emit  # pyright: ignore[reportAssignmentType]
+
+    await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
+
+    router = RouterData.from_router_data(
+        {"headers": {"x-test": "value"}},
+    )
+
+    await event_namespace2.emit_update(
+        StateUpdate(delta={"state": {"router": router}}), token="token1"
+    )
+    await _wait_for_call_count_positive(emit1_mock)
+    emit2_mock.assert_not_called()
+    emit1_mock.assert_called_once()
+    assert isinstance(emit1_mock.call_args[0][1], StateUpdate)
+    assert isinstance(emit1_mock.call_args[0][1].delta["state"]["router"], RouterData)
+    assert emit1_mock.call_args[0][1].delta["state"]["router"] == router
+    emit1_mock.reset_mock()
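Note on the final assertions in this new test: Mock.call_args[0] is the tuple of positional arguments from the most recent call, so call_args[0][1] is the second positional argument, which the test expects to be the emitted StateUpdate with its RouterData still typed. A tiny self-contained illustration of that indexing (the call shape here is invented for the demo, not EventNamespace.emit's real signature):

from unittest.mock import Mock

emit = Mock()
emit("event", {"delta": {"count": 1}}, to="sid1")

# call_args[0] is the positional-args tuple of the last call; [0][1] is its second element.
assert emit.call_args[0][1] == {"delta": {"count": 1}}
# call_args.kwargs holds the keyword arguments of the same call.
assert emit.call_args.kwargs == {"to": "sid1"}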