-
Notifications
You must be signed in to change notification settings - Fork 60
Extract common module for managing pending RPCs #451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| """Module for managing pending RPCs.""" | ||
|
|
||
| import asyncio | ||
| import logging | ||
| from typing import Generic, TypeVar | ||
|
|
||
| from roborock.exceptions import RoborockException | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| K = TypeVar("K") | ||
| V = TypeVar("V") | ||
|
|
||
|
|
||
| class PendingRpcs(Generic[K, V]): | ||
| """Manage pending RPCs.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| """Initialize the pending RPCs.""" | ||
| self._queue_lock = asyncio.Lock() | ||
| self._waiting_queue: dict[K, asyncio.Future[V]] = {} | ||
|
|
||
| async def start(self, key: K) -> asyncio.Future[V]: | ||
| """Start the pending RPCs.""" | ||
| future: asyncio.Future[V] = asyncio.Future() | ||
| async with self._queue_lock: | ||
| if key in self._waiting_queue: | ||
| raise RoborockException(f"Request ID {key} already pending, cannot send command") | ||
| self._waiting_queue[key] = future | ||
| return future | ||
|
|
||
| async def pop(self, key: K) -> None: | ||
| """Pop a pending RPC.""" | ||
| async with self._queue_lock: | ||
| if (future := self._waiting_queue.pop(key, None)) is not None: | ||
| future.cancel() | ||
|
|
||
| async def resolve(self, key: K, value: V) -> None: | ||
| """Resolve waiting future with proper locking.""" | ||
| async with self._queue_lock: | ||
| if (future := self._waiting_queue.pop(key, None)) is not None: | ||
| future.set_result(value) | ||
| else: | ||
| _LOGGER.debug("Received unsolicited message: %s", key) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| """Tests for the PendingRpcs class.""" | ||
|
|
||
| import asyncio | ||
|
|
||
| import pytest | ||
|
|
||
| from roborock.devices.pending import PendingRpcs | ||
| from roborock.exceptions import RoborockException | ||
|
|
||
|
|
||
| @pytest.fixture(name="pending_rpcs") | ||
| def setup_pending_rpcs() -> PendingRpcs[int, str]: | ||
| """Fixture to set up the PendingRpcs for tests.""" | ||
| return PendingRpcs[int, str]() | ||
|
|
||
|
|
||
| async def test_start_duplicate_rpc_raises_exception(pending_rpcs: PendingRpcs[int, str]) -> None: | ||
| """Test that starting a duplicate RPC raises an exception.""" | ||
| key = 1 | ||
| await pending_rpcs.start(key) | ||
| with pytest.raises(RoborockException, match=f"Request ID {key} already pending, cannot send command"): | ||
| await pending_rpcs.start(key) | ||
|
|
||
|
|
||
| async def test_resolve_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: | ||
| """Test resolving a pending RPC.""" | ||
| key = 1 | ||
| value = "test_result" | ||
| future = await pending_rpcs.start(key) | ||
| await pending_rpcs.resolve(key, value) | ||
| result = await future | ||
| assert result == value | ||
|
|
||
|
|
||
| async def test_resolve_unsolicited_message( | ||
| pending_rpcs: PendingRpcs[int, str], caplog: pytest.LogCaptureFixture | ||
| ) -> None: | ||
| """Test resolving an unsolicited message does not raise.""" | ||
| key = 1 | ||
| value = "test_result" | ||
| await pending_rpcs.resolve(key, value) | ||
|
|
||
|
|
||
| async def test_pop_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: | ||
| """Test popping a pending RPC, which should cancel the future.""" | ||
| key = 1 | ||
| future = await pending_rpcs.start(key) | ||
| await pending_rpcs.pop(key) | ||
| with pytest.raises(asyncio.CancelledError): | ||
| await future | ||
|
|
||
|
|
||
| async def test_pop_non_existent_rpc(pending_rpcs: PendingRpcs[int, str]) -> None: | ||
| """Test that popping a non-existent RPC does not raise an exception.""" | ||
| key = 1 | ||
| await pending_rpcs.pop(key) | ||
|
|
||
|
|
||
| async def test_concurrent_rpcs(pending_rpcs: PendingRpcs[int, str]) -> None: | ||
| """Test handling multiple concurrent RPCs.""" | ||
|
|
||
| async def start_and_resolve(key: int, value: str) -> str: | ||
| future = await pending_rpcs.start(key) | ||
| await asyncio.sleep(0.01) # yield | ||
| await pending_rpcs.resolve(key, value) | ||
| return await future | ||
|
|
||
| tasks = [ | ||
| asyncio.create_task(start_and_resolve(1, "result1")), | ||
| asyncio.create_task(start_and_resolve(2, "result2")), | ||
| asyncio.create_task(start_and_resolve(3, "result3")), | ||
| ] | ||
|
|
||
| results = await asyncio.gather(*tasks) | ||
| assert results == ["result1", "result2", "result3"] |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So for B01 - this is going to debug for every message we receive. Probably fine, but may be a bit of log spam for users when they turn on debug mode. Probably not that impactful though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, we can it use an mqtt channel that does not use this. I think i want to move this out into the v1 rpc channel.