Skip to content

Commit 0004721

Browse files
authored
feat: fix a01 and b01 response handling in new api (#453)
* feat: fix a01 and b01 response handling in new api * chore: remove pending rpcs object * fix: remove query_values response * chore: update logging and comments * fix: Update mqtt channel to correctly handle multiple subscribers * chore: remove unnecessary whitespace
1 parent 241b166 commit 0004721

19 files changed

+462
-647
lines changed

roborock/devices/a01_channel.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""
22

3-
from __future__ import annotations
4-
3+
import asyncio
54
import logging
65
from typing import Any, overload
76

7+
from roborock.exceptions import RoborockException
88
from roborock.protocols.a01_protocol import (
99
decode_rpc_response,
1010
encode_mqtt_payload,
1111
)
12-
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
12+
from roborock.roborock_message import (
13+
RoborockDyadDataProtocol,
14+
RoborockMessage,
15+
RoborockZeoProtocol,
16+
)
1317

1418
from .mqtt_channel import MqttChannel
1519

1620
_LOGGER = logging.getLogger(__name__)
21+
_TIMEOUT = 10.0
22+
23+
# Both RoborockDyadDataProtocol and RoborockZeoProtocol have the same
24+
# value for ID_QUERY
25+
_ID_QUERY = int(RoborockDyadDataProtocol.ID_QUERY)
1726

1827

1928
@overload
@@ -39,5 +48,46 @@ async def send_decoded_command(
3948
"""Send a command on the MQTT channel and get a decoded response."""
4049
_LOGGER.debug("Sending MQTT command: %s", params)
4150
roborock_message = encode_mqtt_payload(params)
42-
response = await mqtt_channel.send_message(roborock_message)
43-
return decode_rpc_response(response) # type: ignore[return-value]
51+
52+
# For commands that set values: send the command and do not
53+
# block waiting for a response. Queries are handled below.
54+
param_values = {int(k): v for k, v in params.items()}
55+
if not (query_values := param_values.get(_ID_QUERY)):
56+
await mqtt_channel.publish(roborock_message)
57+
return {}
58+
59+
# Merge any results together than contain the requested data. This
60+
# does not use a future since it needs to merge results across responses.
61+
# This could be simplified if we can assume there is a single response.
62+
finished = asyncio.Event()
63+
result: dict[int, Any] = {}
64+
65+
def find_response(response_message: RoborockMessage) -> None:
66+
"""Handle incoming messages and resolve the future."""
67+
try:
68+
decoded = decode_rpc_response(response_message)
69+
except RoborockException as ex:
70+
_LOGGER.info("Failed to decode a01 message: %s: %s", response_message, ex)
71+
return
72+
for key, value in decoded.items():
73+
if key in query_values:
74+
result[key] = value
75+
if len(result) != len(query_values):
76+
_LOGGER.debug("Incomplete query response: %s != %s", result, query_values)
77+
return
78+
_LOGGER.debug("Received query response: %s", result)
79+
if not finished.is_set():
80+
finished.set()
81+
82+
unsub = await mqtt_channel.subscribe(find_response)
83+
84+
try:
85+
await mqtt_channel.publish(roborock_message)
86+
try:
87+
await asyncio.wait_for(finished.wait(), timeout=_TIMEOUT)
88+
except TimeoutError as ex:
89+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
90+
finally:
91+
unsub()
92+
93+
return result # type: ignore[return-value]

roborock/devices/b01_channel.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import Any
76

87
from roborock.protocols.b01_protocol import (
98
CommandType,
109
ParamsType,
11-
decode_rpc_response,
1210
encode_mqtt_payload,
1311
)
1412

@@ -22,9 +20,8 @@ async def send_decoded_command(
2220
dps: int,
2321
command: CommandType,
2422
params: ParamsType,
25-
) -> dict[int, Any]:
23+
) -> None:
2624
"""Send a command on the MQTT channel and get a decoded response."""
2725
_LOGGER.debug("Sending MQTT command: %s", params)
2826
roborock_message = encode_mqtt_payload(dps, command, params)
29-
response = await mqtt_channel.send_message(roborock_message)
30-
return decode_rpc_response(response) # type: ignore[return-value]
27+
await mqtt_channel.publish(roborock_message)

roborock/devices/local_channel.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44
import logging
55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from json import JSONDecodeError
87

98
from roborock.exceptions import RoborockConnectionException, RoborockException
109
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1110
from roborock.roborock_message import RoborockMessage
1211

1312
from .channel import Channel
14-
from .pending import PendingRpcs
1513

1614
_LOGGER = logging.getLogger(__name__)
1715
_PORT = 58867
@@ -47,8 +45,6 @@ def __init__(self, host: str, local_key: str):
4745
self._subscribers: list[Callable[[RoborockMessage], None]] = []
4846
self._is_connected = False
4947

50-
# RPC support
51-
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
5248
self._decoder: Decoder = create_local_decoder(local_key)
5349
self._encoder: Encoder = create_local_encoder(local_key)
5450

@@ -87,7 +83,6 @@ def _data_received(self, data: bytes) -> None:
8783
return
8884
for message in messages:
8985
_LOGGER.debug("Received message: %s", message)
90-
asyncio.create_task(self._resolve_future_with_lock(message))
9186
for callback in self._subscribers:
9287
try:
9388
callback(message)
@@ -109,37 +104,24 @@ def unsubscribe() -> None:
109104

110105
return unsubscribe
111106

112-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113-
"""Resolve waiting future with proper locking."""
114-
if (request_id := message.get_request_id()) is None:
115-
_LOGGER.debug("Received message with no request_id")
116-
return
117-
await self._pending_rpcs.resolve(request_id, message)
107+
async def publish(self, message: RoborockMessage) -> None:
108+
"""Send a command message.
118109
119-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
120-
"""Send a command message and wait for the response message."""
110+
The caller is responsible for associating the message with its response.
111+
"""
121112
if not self._transport or not self._is_connected:
122113
raise RoborockConnectionException("Not connected to device")
123114

124-
try:
125-
if (request_id := message.get_request_id()) is None:
126-
raise RoborockException("Message must have a request_id for RPC calls")
127-
except (ValueError, JSONDecodeError) as err:
128-
_LOGGER.exception("Error getting request_id from message: %s", err)
129-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
130-
131-
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
132115
try:
133116
encoded_msg = self._encoder(message)
117+
except Exception as err:
118+
_LOGGER.exception("Error encoding MQTT message: %s", err)
119+
raise RoborockException(f"Failed to encode MQTT message: {err}") from err
120+
try:
134121
self._transport.write(encoded_msg)
135-
return await asyncio.wait_for(future, timeout=timeout)
136-
except asyncio.TimeoutError as ex:
137-
await self._pending_rpcs.pop(request_id)
138-
raise RoborockException(f"Command timed out after {timeout}s") from ex
139-
except Exception:
122+
except Exception as err:
140123
logging.exception("Uncaught error sending command")
141-
await self._pending_rpcs.pop(request_id)
142-
raise
124+
raise RoborockException(f"Failed to send message: {message}") from err
143125

144126

145127
# This module provides a factory function to create LocalChannel instances.

roborock/devices/mqtt_channel.py

Lines changed: 19 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
"""Modules for communicating with specific Roborock devices over MQTT."""
22

3-
import asyncio
43
import logging
54
from collections.abc import Callable
6-
from json import JSONDecodeError
75

86
from roborock.containers import HomeDataDevice, RRiot, UserData
97
from roborock.exceptions import RoborockException
10-
from roborock.mqtt.session import MqttParams, MqttSession
8+
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
119
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
1210
from roborock.roborock_message import RoborockMessage
1311

1412
from .channel import Channel
15-
from .pending import PendingRpcs
1613

1714
_LOGGER = logging.getLogger(__name__)
1815

@@ -31,16 +28,16 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
3128
self._rriot = rriot
3229
self._mqtt_params = mqtt_params
3330

34-
# RPC support
35-
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
3631
self._decoder = create_mqtt_decoder(local_key)
3732
self._encoder = create_mqtt_encoder(local_key)
38-
self._mqtt_unsub: Callable[[], None] | None = None
3933

4034
@property
4135
def is_connected(self) -> bool:
42-
"""Return true if the channel is connected."""
43-
return (self._mqtt_unsub is not None) and self._mqtt_session.connected
36+
"""Return true if the channel is connected.
37+
38+
This passes through the underlying MQTT session's connected state.
39+
"""
40+
return self._mqtt_session.connected
4441

4542
@property
4643
def _publish_topic(self) -> str:
@@ -57,9 +54,6 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
5754
5855
The callback will be called with the message payload when a message is received.
5956
60-
All messages received will be processed through the provided callback, even
61-
those sent in response to the `send_command` command.
62-
6357
Returns a callable that can be used to unsubscribe from the topic.
6458
"""
6559

@@ -69,55 +63,29 @@ def message_handler(payload: bytes) -> None:
6963
return
7064
for message in messages:
7165
_LOGGER.debug("Received message: %s", message)
72-
asyncio.create_task(self._resolve_future_with_lock(message))
7366
try:
7467
callback(message)
7568
except Exception as e:
7669
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
7770

78-
self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
79-
80-
def unsub_wrapper() -> None:
81-
if self._mqtt_unsub is not None:
82-
self._mqtt_unsub()
83-
self._mqtt_unsub = None
84-
85-
return unsub_wrapper
86-
87-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
88-
"""Resolve waiting future with proper locking."""
89-
if (request_id := message.get_request_id()) is None:
90-
_LOGGER.debug("Received message with no request_id")
91-
return
92-
await self._pending_rpcs.resolve(request_id, message)
71+
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
9372

94-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
95-
"""Send a command message and wait for the response message.
73+
async def publish(self, message: RoborockMessage) -> None:
74+
"""Publish a command message.
9675
97-
Returns the raw response message - caller is responsible for parsing.
76+
The caller is responsible for handling any responses and associating them
77+
with the incoming request.
9878
"""
99-
try:
100-
if (request_id := message.get_request_id()) is None:
101-
raise RoborockException("Message must have a request_id for RPC calls")
102-
except (ValueError, JSONDecodeError) as err:
103-
_LOGGER.exception("Error getting request_id from message: %s", err)
104-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
105-
106-
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
107-
10879
try:
10980
encoded_msg = self._encoder(message)
110-
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
111-
112-
return await asyncio.wait_for(future, timeout=timeout)
113-
114-
except asyncio.TimeoutError as ex:
115-
await self._pending_rpcs.pop(request_id)
116-
raise RoborockException(f"Command timed out after {timeout}s") from ex
117-
except Exception:
118-
logging.exception("Uncaught error sending command")
119-
await self._pending_rpcs.pop(request_id)
120-
raise
81+
except Exception as e:
82+
_LOGGER.exception("Error encoding MQTT message: %s", e)
83+
raise RoborockException(f"Failed to encode MQTT message: {e}") from e
84+
try:
85+
return await self._mqtt_session.publish(self._publish_topic, encoded_msg)
86+
except MqttSessionException as e:
87+
_LOGGER.exception("Error publishing MQTT message: %s", e)
88+
raise RoborockException(f"Failed to publish MQTT message: {e}") from e
12189

12290

12391
def create_mqtt_channel(

roborock/devices/pending.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

roborock/devices/traits/b01/props.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Any
54

65
from roborock import RoborockB01Methods
76
from roborock.roborock_message import RoborockB01Props
@@ -26,6 +25,6 @@ def __init__(self, channel: MqttChannel) -> None:
2625
"""Initialize the B01Props API."""
2726
self._channel = channel
2827

29-
async def query_values(self, props: list[RoborockB01Props]) -> dict[int, Any]:
28+
async def query_values(self, props: list[RoborockB01Props]) -> None:
3029
"""Query the device for the values of the given Dyad protocols."""
31-
return await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)
30+
await send_decoded_command(self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params=props)

roborock/devices/v1_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def is_local_connected(self) -> bool:
7979
@property
8080
def is_mqtt_connected(self) -> bool:
8181
"""Return whether MQTT connection is available."""
82-
return self._mqtt_unsub is not None
82+
return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
8383

8484
@property
8585
def rpc_channel(self) -> V1RpcChannel:

0 commit comments

Comments
 (0)