Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def _mark_dirty(
Raises:
ImmutableStateError: if the StateProxy is not mutable.
"""
if not self._self_state._is_mutable():
if not self._self_state._is_mutable(): # pyright: ignore[reportAttributeAccessIssue]
msg = (
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
Expand Down
10 changes: 10 additions & 0 deletions reflex/istate/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ async def _link_to(self, token: str) -> Self:
if not token:
msg = "Cannot link shared state to empty token."
raise ReflexRuntimeError(msg)
if not isinstance(self, SharedState):
msg = "Can only link SharedState instances."
raise RuntimeError(msg)
if self._linked_to == token:
return self # already linked to this token
if self._linked_to and self._linked_to != token:
Expand All @@ -215,6 +218,10 @@ async def _unlink(self):
"""
from reflex.istate.manager import get_state_manager

if not isinstance(self, SharedState):
msg = "Can only unlink SharedState instances."
raise ReflexRuntimeError(msg)

state_name = self.get_full_name()
if (
not self._reflex_internal_links
Expand Down Expand Up @@ -272,6 +279,9 @@ async def _internal_patch_linked_state(
_substate_key(token, type(self))
)
linked_state = await linked_root_state.get_state(type(self))
if not isinstance(linked_state, SharedState):
msg = f"Linked state for token {token} is not a SharedState."
raise ReflexRuntimeError(msg)
# Avoid unnecessary dirtiness of shared state when there are no changes.
if type(self) not in self._held_locks[token]:
self._held_locks[token][type(self)] = linked_state
Expand Down
5 changes: 4 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,7 @@ def _check_overwritten_dynamic_args(cls, args: list[str]):
for substate in cls.get_substates():
substate._check_overwritten_dynamic_args(args)

def __getattribute__(self, name: str) -> Any:
def _get_attribute(self, name: str) -> Any:
"""Get the state var.

If the var is inherited, get the var from the parent state.
Expand Down Expand Up @@ -1408,6 +1408,9 @@ def __getattribute__(self, name: str) -> Any:

return value

if not TYPE_CHECKING:
__getattribute__ = _get_attribute

def __setattr__(self, name: str, value: Any):
"""Set the attribute.

Expand Down
37 changes: 19 additions & 18 deletions tests/integration/test_client_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def set_sub_sub(var: str, value: str):
# Ensure the state is gone (not hydrated)
async def poll_for_not_hydrated():
state = await client_side.get_state(_substate_key(token or "", state_name))
assert isinstance(state, State)
return not state.is_hydrated

assert await AppHarness._poll_for_async(poll_for_not_hydrated)
Expand Down Expand Up @@ -723,30 +724,30 @@ async def get_sub_state():

async def poll_for_c1_set():
sub_state = await get_sub_state()
return sub_state.c1 == "c1 post expire"
return sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue]

assert await AppHarness._poll_for_async(poll_for_c1_set)
sub_state = await get_sub_state()
assert sub_state.c1 == "c1 post expire"
assert sub_state.c2 == "c2 value"
assert sub_state.c3 == ""
assert sub_state.c4 == "c4 value"
assert sub_state.c5 == "c5 value"
assert sub_state.c6 == "c6 value"
assert sub_state.c7 == "c7 value"
assert sub_state.l1 == "l1 value"
assert sub_state.l2 == "l2 value"
assert sub_state.l3 == "l3 value"
assert sub_state.l4 == "l4 value"
assert sub_state.s1 == "s1 value"
assert sub_state.s2 == "s2 value"
assert sub_state.s3 == "s3 value"
assert sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c2 == "c2 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c3 == "" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c4 == "c4 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c5 == "c5 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c6 == "c6 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.c7 == "c7 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.l1 == "l1 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.l2 == "l2 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.l3 == "l3 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.l4 == "l4 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.s1 == "s1 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.s2 == "s2 value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_state.s3 == "s3 value" # pyright: ignore[reportAttributeAccessIssue]
sub_sub_state = sub_state.substates[
client_side.get_state_name("_client_side_sub_sub_state")
]
assert sub_sub_state.c1s == "c1s value"
assert sub_sub_state.l1s == "l1s value"
assert sub_sub_state.s1s == "s1s value"
assert sub_sub_state.c1s == "c1s value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_sub_state.l1s == "l1s value" # pyright: ignore[reportAttributeAccessIssue]
assert sub_sub_state.s1s == "s1s value" # pyright: ignore[reportAttributeAccessIssue]

# clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies()
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_component_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ async def test_component_state_app(component_state_app: AppHarness):
a_state = root_state.substates[a_state_name]
b_state = root_state.substates[b_state_name]
assert a_state._backend_vars != a_state.backend_vars
assert a_state._be == a_state._backend_vars["_be"] == 3
assert b_state._be is None
assert a_state._be == a_state._backend_vars["_be"] == 3 # pyright: ignore[reportAttributeAccessIssue]
assert b_state._be is None # pyright: ignore[reportAttributeAccessIssue]
assert b_state._backend_vars["_be"] is None

assert count_b.text == "0"
Expand All @@ -183,7 +183,7 @@ async def test_component_state_app(component_state_app: AppHarness):
a_state = root_state.substates[a_state_name]
b_state = root_state.substates[b_state_name]
assert b_state._backend_vars != b_state.backend_vars
assert b_state._be == b_state._backend_vars["_be"] == 2
assert b_state._be == b_state._backend_vars["_be"] == 2 # pyright: ignore[reportAttributeAccessIssue]

# Check locally-defined substate style
count_c = driver.find_element(By.ID, "count-c")
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_computed_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ async def test_computed_vars(
token = f"{token}_{full_state_name}"
state = (await computed_vars.get_state(token)).substates[state_name]
assert state is not None
assert state.count1_backend == 0
assert state._count1_backend == 0
assert state.count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue]
assert state._count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue]

# test that backend var is not rendered
count1_backend = driver.find_element(By.ID, "count1_backend")
Expand Down Expand Up @@ -259,9 +259,9 @@ async def test_computed_vars(
)
state = (await computed_vars.get_state(token)).substates[state_name]
assert state is not None
assert state.count1_backend == 1
assert state.count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue]
assert count1_backend.text == ""
assert state._count1_backend == 1
assert state._count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue]
assert count1_backend_.text == ""

mark_dirty.click()
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/test_dynamic_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DynamicState(rx.State):

@rx.event
def on_load(self):
page_data = f"{self.router.page.path}-{self.page_id or 'no page id'}"
page_data = f"{self.router.page.path}-{self.page_id or 'no page id'}" # pyright: ignore[reportAttributeAccessIssue]
print(f"on_load: {page_data}")
self.order.append(page_data)

Expand All @@ -43,7 +43,7 @@ def on_load_static(self):
@rx.var
def next_page(self) -> str:
try:
return str(int(self.page_id) + 1)
return str(int(self.page_id) + 1) # pyright: ignore[reportAttributeAccessIssue]
except ValueError:
return "0"

Expand Down Expand Up @@ -81,7 +81,7 @@ class ArgState(rx.State):

@rx.var(cache=False)
def arg(self) -> int:
return int(self.arg_str or 0)
return int(self.arg_str or 0) # pyright: ignore[reportAttributeAccessIssue]

class ArgSubState(ArgState):
@rx.var
Expand All @@ -90,7 +90,7 @@ def cached_arg(self) -> int:

@rx.var
def cached_arg_str(self) -> str:
return self.arg_str
return self.arg_str # pyright: ignore[reportAttributeAccessIssue]

@rx.page(route="/arg/[arg_str]")
def arg() -> rx.Component:
Expand Down Expand Up @@ -238,11 +238,11 @@ async def _backend_state():
async def _check():
return (await _backend_state()).substates[
dynamic_state_name
].order == exp_order
].order == exp_order # pyright: ignore[reportAttributeAccessIssue]

await AppHarness._poll_for_async(_check, timeout=10)
assert (
list((await _backend_state()).substates[dynamic_state_name].order)
list((await _backend_state()).substates[dynamic_state_name].order) # pyright: ignore[reportAttributeAccessIssue]
== exp_order
)

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_event_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def poll_for_order(

async def _poll_for_order(exp_order: list[str]):
async def _check():
return (await _backend_state(event_action, token)).order == exp_order
return (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue]

await AppHarness._poll_for_async(_check)
assert (await _backend_state(event_action, token)).order == exp_order
assert (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue]

return _poll_for_order

Expand Down Expand Up @@ -358,12 +358,12 @@ async def test_event_actions_throttle_debounce(
# Wait until the debounce event shows up
async def _debounce_received():
state = await _backend_state(event_action, token)
return state.order and state.order[-1] == "on_click_debounce"
return state.order and state.order[-1] == "on_click_debounce" # pyright: ignore[reportAttributeAccessIssue]

await AppHarness._poll_for_async(_debounce_received)

# This test is inherently racy, so ensure the `on_click_throttle` event is fired approximately the expected number of times.
final_event_order = (await _backend_state(event_action, token)).order
final_event_order = (await _backend_state(event_action, token)).order # pyright: ignore[reportAttributeAccessIssue]
n_on_click_throttle_received = final_event_order.count("on_click_throttle")
print(
f"Expected ~{exp_events} on_click_throttle events, received {n_on_click_throttle_received}"
Expand Down
14 changes: 7 additions & 7 deletions tests/integration/test_event_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,11 @@ async def test_event_chain_click(

async def _has_all_events():
return len(
(await event_chain.get_state(token)).substates[state_name].event_order
(await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue]
) == len(exp_event_order)

await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).substates[state_name].event_order
event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue]
assert event_order == exp_event_order


Expand Down Expand Up @@ -515,13 +515,13 @@ async def test_event_chain_on_load(

async def _has_all_events():
return len(
(await event_chain.get_state(token)).substates[state_name].event_order
(await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue]
) == len(exp_event_order)

await AppHarness._poll_for_async(_has_all_events)
backend_state = (await event_chain.get_state(token)).substates[state_name]
assert backend_state.event_order == exp_event_order
assert backend_state.is_hydrated is True
assert backend_state.event_order == exp_event_order # pyright: ignore[reportAttributeAccessIssue]
assert backend_state.is_hydrated is True # pyright: ignore[reportAttributeAccessIssue]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -582,11 +582,11 @@ async def test_event_chain_on_mount(

async def _has_all_events():
return len(
(await event_chain.get_state(token)).substates[state_name].event_order
(await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue]
) == len(exp_event_order)

await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).substates[state_name].event_order
event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue]
assert list(event_order) == exp_event_order


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_form_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ async def get_form_data():
return (
(await form_submit.get_state(f"{token}_{full_state_name}"))
.substates[state_name]
.form_data
.form_data # pyright: ignore[reportAttributeAccessIssue]
)

# wait for the form data to arrive at the backend
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):

async def get_state_text():
state = await fully_controlled_input.get_state(f"{token}_{full_state_name}")
return state.substates[state_name].text
return state.substates[state_name].text # pyright: ignore[reportAttributeAccessIssue]

# ensure defaults are set correctly
assert (
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_linked_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ async def unlink(self):

@rx.event
async def on_load_link_default(self):
linked_state = await self._link_to(self.room or "default")
if self.room:
assert linked_state._linked_to == self.room
linked_state = await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue]
if self.room: # pyright: ignore[reportAttributeAccessIssue]
assert linked_state._linked_to == self.room # pyright: ignore[reportAttributeAccessIssue]
else:
assert linked_state._linked_to == "default"

Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_login_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def LoginSample():
class State(rx.State):
auth_token: str = rx.LocalStorage("")

@rx.event
def set_auth_token(self, token: str):
self.auth_token = token

@rx.event
def logout(self):
self.set_auth_token("")
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,15 @@ async def test_upload_file(

state = await upload_file.get_state(substate_token)
# only the secondary form tracks progress and chain events
assert state.substates[state_name].event_order.count("upload_progress") == 1
assert state.substates[state_name].event_order.count("chain_event") == 1
assert state.substates[state_name].event_order.count("upload_progress") == 1 # pyright: ignore[reportAttributeAccessIssue]
assert state.substates[state_name].event_order.count("chain_event") == 1 # pyright: ignore[reportAttributeAccessIssue]

# look up the backend state and assert on uploaded contents
async def get_file_data():
return (
(await upload_file.get_state(substate_token))
.substates[state_name]
._file_data
._file_data # pyright: ignore[reportAttributeAccessIssue]
)

file_data = await AppHarness._poll_for_async(get_file_data)
Expand Down Expand Up @@ -358,7 +358,7 @@ async def get_file_data():
return (
(await upload_file.get_state(substate_token))
.substates[state_name]
._file_data
._file_data # pyright: ignore[reportAttributeAccessIssue]
)

file_data = await AppHarness._poll_for_async(get_file_data)
Expand Down Expand Up @@ -469,7 +469,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# Get interim progress dicts saved in the on_upload_progress handler.
async def _progress_dicts():
state = await upload_file.get_state(substate_token)
return state.substates[state_name].progress_dicts
return state.substates[state_name].progress_dicts # pyright: ignore[reportAttributeAccessIssue]

# We should have _some_ progress
assert await AppHarness._poll_for_async(_progress_dicts)
Expand All @@ -479,7 +479,7 @@ async def _progress_dicts():
assert p["progress"] != 1

state = await upload_file.get_state(substate_token)
file_data = state.substates[state_name]._file_data
file_data = state.substates[state_name]._file_data # pyright: ignore[reportAttributeAccessIssue]
assert isinstance(file_data, dict)
normalized_file_data = {Path(k).name: v for k, v in file_data.items()}
assert Path(exp_name).name not in normalized_file_data
Expand Down Expand Up @@ -575,11 +575,11 @@ async def test_on_drop(

async def exp_name_in_quaternary():
state = await upload_file.get_state(substate_token)
return exp_name in state.substates[state_name].quaternary_names
return exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue]

# Poll until the file names appear in the display
await AppHarness._poll_for_async(exp_name_in_quaternary)

# Verify through state that the file names were captured correctly
state = await upload_file.get_state(substate_token)
assert exp_name in state.substates[state_name].quaternary_names
assert exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue]
Loading
Loading