Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from roborock.roborock_message import RoborockMessage

from .channel import Channel
from .pending import PendingRpcs

_LOGGER = logging.getLogger(__name__)
_PORT = 58867
Expand Down Expand Up @@ -47,10 +48,9 @@ def __init__(self, host: str, local_key: str):
self._is_connected = False

# RPC support
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
self._decoder: Decoder = create_local_decoder(local_key)
self._encoder: Encoder = create_local_encoder(local_key)
self._queue_lock = asyncio.Lock()

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -114,11 +114,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
if (request_id := message.get_request_id()) is None:
_LOGGER.debug("Received message with no request_id")
return
async with self._queue_lock:
if (future := self._waiting_queue.pop(request_id, None)) is not None:
future.set_result(message)
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
await self._pending_rpcs.resolve(request_id, message)

async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message."""
Expand All @@ -132,24 +128,17 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
_LOGGER.exception("Error getting request_id from message: %s", err)
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err

future: asyncio.Future[RoborockMessage] = asyncio.Future()
async with self._queue_lock:
if request_id in self._waiting_queue:
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
self._waiting_queue[request_id] = future

future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
try:
encoded_msg = self._encoder(message)
self._transport.write(encoded_msg)
return await asyncio.wait_for(future, timeout=timeout)
except asyncio.TimeoutError as ex:
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
await self._pending_rpcs.pop(request_id)
raise RoborockException(f"Command timed out after {timeout}s") from ex
except Exception:
logging.exception("Uncaught error sending command")
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
await self._pending_rpcs.pop(request_id)
raise


Expand Down
22 changes: 6 additions & 16 deletions roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from roborock.roborock_message import RoborockMessage

from .channel import Channel
from .pending import PendingRpcs

_LOGGER = logging.getLogger(__name__)

Expand All @@ -31,10 +32,9 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
self._mqtt_params = mqtt_params

# RPC support
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
self._decoder = create_mqtt_decoder(local_key)
self._encoder = create_mqtt_encoder(local_key)
self._queue_lock = asyncio.Lock()
self._mqtt_unsub: Callable[[], None] | None = None

@property
Expand Down Expand Up @@ -89,11 +89,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
if (request_id := message.get_request_id()) is None:
_LOGGER.debug("Received message with no request_id")
return
async with self._queue_lock:
if (future := self._waiting_queue.pop(request_id, None)) is not None:
future.set_result(message)
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
await self._pending_rpcs.resolve(request_id, message)

async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message.
Expand All @@ -107,11 +103,7 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
_LOGGER.exception("Error getting request_id from message: %s", err)
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err

future: asyncio.Future[RoborockMessage] = asyncio.Future()
async with self._queue_lock:
if request_id in self._waiting_queue:
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
self._waiting_queue[request_id] = future
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)

try:
encoded_msg = self._encoder(message)
Expand All @@ -120,13 +112,11 @@ async def send_message(self, message: RoborockMessage, timeout: float = 10.0) ->
return await asyncio.wait_for(future, timeout=timeout)

except asyncio.TimeoutError as ex:
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
await self._pending_rpcs.pop(request_id)
raise RoborockException(f"Command timed out after {timeout}s") from ex
except Exception:
logging.exception("Uncaught error sending command")
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
await self._pending_rpcs.pop(request_id)
raise


Expand Down
45 changes: 45 additions & 0 deletions roborock/devices/pending.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Module for managing pending RPCs."""

import asyncio
import logging
from typing import Generic, TypeVar

from roborock.exceptions import RoborockException

_LOGGER = logging.getLogger(__name__)


K = TypeVar("K")
V = TypeVar("V")


class PendingRpcs(Generic[K, V]):
"""Manage pending RPCs."""

def __init__(self) -> None:
"""Initialize the pending RPCs."""
self._queue_lock = asyncio.Lock()
self._waiting_queue: dict[K, asyncio.Future[V]] = {}

async def start(self, key: K) -> asyncio.Future[V]:
"""Start the pending RPCs."""
future: asyncio.Future[V] = asyncio.Future()
async with self._queue_lock:
if key in self._waiting_queue:
raise RoborockException(f"Request ID {key} already pending, cannot send command")
self._waiting_queue[key] = future
return future

async def pop(self, key: K) -> None:
"""Pop a pending RPC."""
async with self._queue_lock:
if (future := self._waiting_queue.pop(key, None)) is not None:
future.cancel()

async def resolve(self, key: K, value: V) -> None:
"""Resolve waiting future with proper locking."""
async with self._queue_lock:
if (future := self._waiting_queue.pop(key, None)) is not None:
future.set_result(value)
else:
_LOGGER.debug("Received unsolicited message: %s", key)
Comment on lines +44 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for B01 - this is going to debug for every message we receive. Probably fine, but may be a bit of log spam for users when they turn on debug mode. Probably not that impactful though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, we can it use an mqtt channel that does not use this. I think i want to move this out into the v1 rpc channel.

75 changes: 75 additions & 0 deletions tests/devices/test_pending.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Tests for the PendingRpcs class."""

import asyncio

import pytest

from roborock.devices.pending import PendingRpcs
from roborock.exceptions import RoborockException


@pytest.fixture(name="pending_rpcs")
def setup_pending_rpcs() -> PendingRpcs[int, str]:
"""Fixture to set up the PendingRpcs for tests."""
return PendingRpcs[int, str]()


async def test_start_duplicate_rpc_raises_exception(pending_rpcs: PendingRpcs[int, str]) -> None:
"""Test that starting a duplicate RPC raises an exception."""
key = 1
await pending_rpcs.start(key)
with pytest.raises(RoborockException, match=f"Request ID {key} already pending, cannot send command"):
await pending_rpcs.start(key)


async def test_resolve_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
"""Test resolving a pending RPC."""
key = 1
value = "test_result"
future = await pending_rpcs.start(key)
await pending_rpcs.resolve(key, value)
result = await future
assert result == value


async def test_resolve_unsolicited_message(
pending_rpcs: PendingRpcs[int, str], caplog: pytest.LogCaptureFixture
) -> None:
"""Test resolving an unsolicited message does not raise."""
key = 1
value = "test_result"
await pending_rpcs.resolve(key, value)


async def test_pop_pending_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
"""Test popping a pending RPC, which should cancel the future."""
key = 1
future = await pending_rpcs.start(key)
await pending_rpcs.pop(key)
with pytest.raises(asyncio.CancelledError):
await future


async def test_pop_non_existent_rpc(pending_rpcs: PendingRpcs[int, str]) -> None:
"""Test that popping a non-existent RPC does not raise an exception."""
key = 1
await pending_rpcs.pop(key)


async def test_concurrent_rpcs(pending_rpcs: PendingRpcs[int, str]) -> None:
"""Test handling multiple concurrent RPCs."""

async def start_and_resolve(key: int, value: str) -> str:
future = await pending_rpcs.start(key)
await asyncio.sleep(0.01) # yield
await pending_rpcs.resolve(key, value)
return await future

tasks = [
asyncio.create_task(start_and_resolve(1, "result1")),
asyncio.create_task(start_and_resolve(2, "result2")),
asyncio.create_task(start_and_resolve(3, "result3")),
]

results = await asyncio.gather(*tasks)
assert results == ["result1", "result2", "result3"]