1+ """Modules for communicating with specific Roborock devices over MQTT."""
2+
3+ import asyncio
14import logging
25from collections .abc import Callable
6+ from json import JSONDecodeError
37
48from roborock .containers import RRiot
9+ from roborock .exceptions import RoborockException
510from roborock .mqtt .session import MqttParams , MqttSession
11+ from roborock .protocol import create_mqtt_decoder , create_mqtt_encoder
12+ from roborock .roborock_message import RoborockMessage
613
714_LOGGER = logging .getLogger (__name__ )
815
916
1017class MqttChannel :
11- """RPC-style channel for communicating with a specific device over MQTT.
18+ """Simple RPC-style channel for communicating with a device over MQTT.
1219
13- This currently only supports listening to messages and does not yet
14- support RPC functionality .
20+ Handles request/response correlation and timeouts, but leaves message
21+ format most parsing to higher-level components .
1522 """
1623
17- def __init__ (self , mqtt_session : MqttSession , duid : str , rriot : RRiot , mqtt_params : MqttParams ):
24+ def __init__ (self , mqtt_session : MqttSession , duid : str , local_key : str , rriot : RRiot , mqtt_params : MqttParams ):
1825 self ._mqtt_session = mqtt_session
1926 self ._duid = duid
27+ self ._local_key = local_key
2028 self ._rriot = rriot
2129 self ._mqtt_params = mqtt_params
2230
31+ # RPC support
32+ self ._waiting_queue : dict [int , asyncio .Future [RoborockMessage ]] = {}
33+ self ._decoder = create_mqtt_decoder (local_key )
34+ self ._encoder = create_mqtt_encoder (local_key )
35+ # Use a regular lock since we need to access from sync callback
36+ self ._queue_lock = asyncio .Lock ()
37+
2338 @property
2439 def _publish_topic (self ) -> str :
2540 """Topic to send commands to the device."""
@@ -30,11 +45,72 @@ def _subscribe_topic(self) -> str:
3045 """Topic to receive responses from the device."""
3146 return f"rr/m/o/{ self ._rriot .u } /{ self ._mqtt_params .username } /{ self ._duid } "
3247
33- async def subscribe (self , callback : Callable [[bytes ], None ]) -> Callable [[], None ]:
48+ async def subscribe (self , callback : Callable [[RoborockMessage ], None ]) -> Callable [[], None ]:
3449 """Subscribe to the device's response topic.
3550
3651 The callback will be called with the message payload when a message is received.
3752
53+ All messages received will be processed through the provided callback, even
54+ those sent in response to the `send_command` command.
55+
3856 Returns a callable that can be used to unsubscribe from the topic.
3957 """
40- return await self ._mqtt_session .subscribe (self ._subscribe_topic , callback )
58+
59+ def message_handler (payload : bytes ) -> None :
60+ if not (messages := self ._decoder (payload )):
61+ _LOGGER .warning ("Failed to decode MQTT message: %s" , payload )
62+ return
63+ for message in messages :
64+ asyncio .create_task (self ._resolve_future_with_lock (message ))
65+ try :
66+ callback (message )
67+ except Exception as e :
68+ _LOGGER .exception ("Uncaught error in message handler callback: %s" , e )
69+
70+ return await self ._mqtt_session .subscribe (self ._subscribe_topic , message_handler )
71+
72+ async def _resolve_future_with_lock (self , message : RoborockMessage ) -> None :
73+ """Resolve waiting future with proper locking."""
74+ if (request_id := message .get_request_id ()) is None :
75+ _LOGGER .debug ("Received message with no request_id" )
76+ return
77+ async with self ._queue_lock :
78+ if (future := self ._waiting_queue .pop (request_id , None )) is not None :
79+ if not future .done ():
80+ future .set_result (message )
81+ else :
82+ _LOGGER .warning ("Received message for completed future: request_id=%s" , request_id )
83+ else :
84+ _LOGGER .warning ("Received message with no waiting handler: request_id=%s" , request_id )
85+
86+ async def send_command (self , message : RoborockMessage , timeout : float = 10.0 ) -> RoborockMessage :
87+ """Send a command message and wait for the response message.
88+
89+ Returns the raw response message - caller is responsible for parsing.
90+ """
91+ try :
92+ if (request_id := message .get_request_id ()) is None :
93+ raise RoborockException ("Message must have a request_id for RPC calls" )
94+ except (ValueError , JSONDecodeError ) as err :
95+ _LOGGER .exception ("Error getting request_id from message: %s" , err )
96+ raise RoborockException (f"Invalid message format, Message must have a request_id: { err } " ) from err
97+
98+ future : asyncio .Future [RoborockMessage ] = asyncio .Future ()
99+ async with self ._queue_lock :
100+ self ._waiting_queue [request_id ] = future
101+
102+ try :
103+ encoded_msg = self ._encoder (message )
104+ await self ._mqtt_session .publish (self ._publish_topic , encoded_msg )
105+
106+ return await asyncio .wait_for (future , timeout = timeout )
107+
108+ except asyncio .TimeoutError as ex :
109+ async with self ._queue_lock :
110+ self ._waiting_queue .pop (request_id , None )
111+ raise RoborockException (f"Command timed out after { timeout } s" ) from ex
112+ except Exception :
113+ logging .exception ("Uncaught error sending command" )
114+ async with self ._queue_lock :
115+ self ._waiting_queue .pop (request_id , None )
116+ raise
0 commit comments