Skip to content

Commit 06b9178

Browse files
committed
feat: Add support for sending/recieving messages
1 parent 341d96d commit 06b9178

File tree

4 files changed

+306
-24
lines changed

4 files changed

+306
-24
lines changed

roborock/devices/device.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from functools import cached_property
1111

1212
from roborock.containers import HomeDataDevice, HomeDataProduct, UserData
13+
from roborock.roborock_message import RoborockMessage
1314

1415
from .mqtt_channel import MqttChannel
1516

@@ -99,9 +100,9 @@ async def close(self) -> None:
99100
self._unsub()
100101
self._unsub = None
101102

102-
def _on_mqtt_message(self, message: bytes) -> None:
103+
def _on_mqtt_message(self, message: RoborockMessage) -> None:
103104
"""Handle incoming MQTT messages from the device.
104105
105106
This method should be overridden in subclasses to handle specific device messages.
106107
"""
107-
_LOGGER.debug("Received message from device %s: %s", self.duid, message[:50]) # Log first 50 bytes for brevity
108+
_LOGGER.debug("Received message from device %s: %s", self.duid, message)

roborock/devices/device_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
113113
mqtt_session = await create_mqtt_session(mqtt_params)
114114

115115
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
116-
mqtt_channel = MqttChannel(mqtt_session, device.duid, user_data.rriot, mqtt_params)
116+
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
117117
return RoborockDevice(user_data, device, product, mqtt_channel)
118118

119119
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)

roborock/devices/mqtt_channel.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
1+
"""Modules for communicating with specific Roborock devices over MQTT."""
2+
3+
import asyncio
14
import logging
25
from collections.abc import Callable
6+
from json import JSONDecodeError
37

48
from roborock.containers import RRiot
9+
from roborock.exceptions import RoborockException
510
from roborock.mqtt.session import MqttParams, MqttSession
11+
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
12+
from roborock.roborock_message import RoborockMessage
613

714
_LOGGER = logging.getLogger(__name__)
815

916

1017
class MqttChannel:
11-
"""RPC-style channel for communicating with a specific device over MQTT.
18+
"""Simple RPC-style channel for communicating with a device over MQTT.
1219
13-
This currently only supports listening to messages and does not yet
14-
support RPC functionality.
20+
Handles request/response correlation and timeouts, but leaves message
21+
format most parsing to higher-level components.
1522
"""
1623

17-
def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_params: MqttParams):
24+
def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams):
1825
self._mqtt_session = mqtt_session
1926
self._duid = duid
27+
self._local_key = local_key
2028
self._rriot = rriot
2129
self._mqtt_params = mqtt_params
2230

31+
# RPC support
32+
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
33+
self._decoder = create_mqtt_decoder(local_key)
34+
self._encoder = create_mqtt_encoder(local_key)
35+
# Use a regular lock since we need to access from sync callback
36+
self._queue_lock = asyncio.Lock()
37+
2338
@property
2439
def _publish_topic(self) -> str:
2540
"""Topic to send commands to the device."""
@@ -30,11 +45,72 @@ def _subscribe_topic(self) -> str:
3045
"""Topic to receive responses from the device."""
3146
return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"
3247

33-
async def subscribe(self, callback: Callable[[bytes], None]) -> Callable[[], None]:
48+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
3449
"""Subscribe to the device's response topic.
3550
3651
The callback will be called with the message payload when a message is received.
3752
53+
All messages received will be processed through the provided callback, even
54+
those sent in response to the `send_command` command.
55+
3856
Returns a callable that can be used to unsubscribe from the topic.
3957
"""
40-
return await self._mqtt_session.subscribe(self._subscribe_topic, callback)
58+
59+
def message_handler(payload: bytes) -> None:
60+
if not (messages := self._decoder(payload)):
61+
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
62+
return
63+
for message in messages:
64+
asyncio.create_task(self._resolve_future_with_lock(message))
65+
try:
66+
callback(message)
67+
except Exception as e:
68+
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
69+
70+
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
71+
72+
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
73+
"""Resolve waiting future with proper locking."""
74+
if (request_id := message.get_request_id()) is None:
75+
_LOGGER.debug("Received message with no request_id")
76+
return
77+
async with self._queue_lock:
78+
if (future := self._waiting_queue.pop(request_id, None)) is not None:
79+
if not future.done():
80+
future.set_result(message)
81+
else:
82+
_LOGGER.warning("Received message for completed future: request_id=%s", request_id)
83+
else:
84+
_LOGGER.warning("Received message with no waiting handler: request_id=%s", request_id)
85+
86+
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
87+
"""Send a command message and wait for the response message.
88+
89+
Returns the raw response message - caller is responsible for parsing.
90+
"""
91+
try:
92+
if (request_id := message.get_request_id()) is None:
93+
raise RoborockException("Message must have a request_id for RPC calls")
94+
except (ValueError, JSONDecodeError) as err:
95+
_LOGGER.exception("Error getting request_id from message: %s", err)
96+
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
97+
98+
future: asyncio.Future[RoborockMessage] = asyncio.Future()
99+
async with self._queue_lock:
100+
self._waiting_queue[request_id] = future
101+
102+
try:
103+
encoded_msg = self._encoder(message)
104+
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
105+
106+
return await asyncio.wait_for(future, timeout=timeout)
107+
108+
except asyncio.TimeoutError as ex:
109+
async with self._queue_lock:
110+
self._waiting_queue.pop(request_id, None)
111+
raise RoborockException(f"Command timed out after {timeout}s") from ex
112+
except Exception:
113+
logging.exception("Uncaught error sending command")
114+
async with self._queue_lock:
115+
self._waiting_queue.pop(request_id, None)
116+
raise

0 commit comments

Comments
 (0)