Skip to content

Commit ea1587b

Browse files
committed
feat: fix a01 and b01 response handling in new api
1 parent 698707f commit ea1587b

15 files changed

+303
-497
lines changed

roborock/devices/a01_channel.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""
22

3-
from __future__ import annotations
4-
3+
import asyncio
54
import logging
65
from typing import Any, overload
76

7+
from roborock.exceptions import RoborockException
88
from roborock.protocols.a01_protocol import (
99
decode_rpc_response,
1010
encode_mqtt_payload,
1111
)
12-
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
12+
from roborock.roborock_message import (
13+
RoborockDyadDataProtocol,
14+
RoborockMessage,
15+
RoborockZeoProtocol,
16+
)
1317

1418
from .mqtt_channel import MqttChannel
1519

1620
_LOGGER = logging.getLogger(__name__)
21+
_TIMEOUT = 10.0
1722

1823

1924
@overload
@@ -39,5 +44,46 @@ async def send_decoded_command(
3944
"""Send a command on the MQTT channel and get a decoded response."""
4045
_LOGGER.debug("Sending MQTT command: %s", params)
4146
roborock_message = encode_mqtt_payload(params)
42-
response = await mqtt_channel.send_message(roborock_message)
43-
return decode_rpc_response(response) # type: ignore[return-value]
47+
48+
# We only block on a response for queries
49+
param_values = {int(k): v for k, v in params.items()}
50+
if not (
51+
query_values := param_values.get(int(RoborockDyadDataProtocol.ID_QUERY))
52+
or param_values.get(int(RoborockZeoProtocol.ID_QUERY))
53+
):
54+
await mqtt_channel.publish(roborock_message)
55+
return {}
56+
57+
# This can be simplified if we can assume a all results are returned in
58+
# single response. Otherwise, this will construct a result by merging in
59+
# responses that contain the ids that were queried.
60+
finished = asyncio.Event()
61+
result: dict[int, Any] = {}
62+
63+
def find_response(response_message: RoborockMessage) -> None:
64+
"""Handle incoming messages and resolve the future."""
65+
try:
66+
decoded = decode_rpc_response(response_message)
67+
except RoborockException:
68+
return
69+
for key, value in decoded.items():
70+
if key in query_values:
71+
result[key] = value
72+
if len(result) != len(query_values):
73+
return
74+
_LOGGER.debug("Received query response: %s", result)
75+
if not finished.is_set():
76+
finished.set()
77+
78+
unsub = await mqtt_channel.subscribe(find_response)
79+
80+
try:
81+
await mqtt_channel.publish(roborock_message)
82+
try:
83+
await asyncio.wait_for(finished.wait(), timeout=_TIMEOUT)
84+
except TimeoutError as ex:
85+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
86+
finally:
87+
unsub()
88+
89+
return result # type: ignore[return-value]

roborock/devices/b01_channel.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import Any
76

87
from roborock.protocols.b01_protocol import (
98
CommandType,
109
ParamsType,
11-
decode_rpc_response,
1210
encode_mqtt_payload,
1311
)
1412

@@ -22,9 +20,8 @@ async def send_decoded_command(
2220
dps: int,
2321
command: CommandType,
2422
params: ParamsType,
25-
) -> dict[int, Any]:
23+
) -> None:
2624
"""Send a command on the MQTT channel and get a decoded response."""
2725
_LOGGER.debug("Sending MQTT command: %s", params)
2826
roborock_message = encode_mqtt_payload(dps, command, params)
29-
response = await mqtt_channel.send_message(roborock_message)
30-
return decode_rpc_response(response) # type: ignore[return-value]
27+
await mqtt_channel.publish(roborock_message)

roborock/devices/local_channel.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44
import logging
55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from json import JSONDecodeError
87

98
from roborock.exceptions import RoborockConnectionException, RoborockException
109
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
1110
from roborock.roborock_message import RoborockMessage
1211

1312
from .channel import Channel
14-
from .pending import PendingRpcs
1513

1614
_LOGGER = logging.getLogger(__name__)
1715
_PORT = 58867
@@ -47,8 +45,6 @@ def __init__(self, host: str, local_key: str):
4745
self._subscribers: list[Callable[[RoborockMessage], None]] = []
4846
self._is_connected = False
4947

50-
# RPC support
51-
self._pending_rpcs: PendingRpcs[int, RoborockMessage] = PendingRpcs()
5248
self._decoder: Decoder = create_local_decoder(local_key)
5349
self._encoder: Encoder = create_local_encoder(local_key)
5450

@@ -87,7 +83,6 @@ def _data_received(self, data: bytes) -> None:
8783
return
8884
for message in messages:
8985
_LOGGER.debug("Received message: %s", message)
90-
asyncio.create_task(self._resolve_future_with_lock(message))
9186
for callback in self._subscribers:
9287
try:
9388
callback(message)
@@ -109,37 +104,24 @@ def unsubscribe() -> None:
109104

110105
return unsubscribe
111106

112-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113-
"""Resolve waiting future with proper locking."""
114-
if (request_id := message.get_request_id()) is None:
115-
_LOGGER.debug("Received message with no request_id")
116-
return
117-
await self._pending_rpcs.resolve(request_id, message)
107+
async def publish(self, message: RoborockMessage) -> None:
108+
"""Send a command message.
118109
119-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
120-
"""Send a command message and wait for the response message."""
110+
The caller is responsible for associating the message with its response.
111+
"""
121112
if not self._transport or not self._is_connected:
122113
raise RoborockConnectionException("Not connected to device")
123114

124-
try:
125-
if (request_id := message.get_request_id()) is None:
126-
raise RoborockException("Message must have a request_id for RPC calls")
127-
except (ValueError, JSONDecodeError) as err:
128-
_LOGGER.exception("Error getting request_id from message: %s", err)
129-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
130-
131-
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
132115
try:
133116
encoded_msg = self._encoder(message)
117+
except Exception as err:
118+
_LOGGER.exception("Error encoding MQTT message: %s", err)
119+
raise RoborockException(f"Failed to encode MQTT message: {err}") from err
120+
try:
134121
self._transport.write(encoded_msg)
135-
return await asyncio.wait_for(future, timeout=timeout)
136-
except asyncio.TimeoutError as ex:
137-
await self._pending_rpcs.pop(request_id)
138-
raise RoborockException(f"Command timed out after {timeout}s") from ex
139-
except Exception:
122+
except Exception as err:
140123
logging.exception("Uncaught error sending command")
141-
await self._pending_rpcs.pop(request_id)
142-
raise
124+
raise RoborockException(f"Failed to send message: {message}") from err
143125

144126

145127
# This module provides a factory function to create LocalChannel instances.

roborock/devices/mqtt_channel.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""Modules for communicating with specific Roborock devices over MQTT."""
22

3-
import asyncio
43
import logging
54
from collections.abc import Callable
6-
from json import JSONDecodeError
75

86
from roborock.containers import HomeDataDevice, RRiot, UserData
97
from roborock.exceptions import RoborockException
10-
from roborock.mqtt.session import MqttParams, MqttSession
8+
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
119
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
1210
from roborock.roborock_message import RoborockMessage
1311

@@ -69,7 +67,6 @@ def message_handler(payload: bytes) -> None:
6967
return
7068
for message in messages:
7169
_LOGGER.debug("Received message: %s", message)
72-
asyncio.create_task(self._resolve_future_with_lock(message))
7370
try:
7471
callback(message)
7572
except Exception as e:
@@ -84,40 +81,22 @@ def unsub_wrapper() -> None:
8481

8582
return unsub_wrapper
8683

87-
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
88-
"""Resolve waiting future with proper locking."""
89-
if (request_id := message.get_request_id()) is None:
90-
_LOGGER.debug("Received message with no request_id")
91-
return
92-
await self._pending_rpcs.resolve(request_id, message)
84+
async def publish(self, message: RoborockMessage) -> None:
85+
"""Publish a command message.
9386
94-
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
95-
"""Send a command message and wait for the response message.
96-
97-
Returns the raw response message - caller is responsible for parsing.
87+
The caller is responsible for handling any responses and associating them
88+
with the incoming request.
9889
"""
99-
try:
100-
if (request_id := message.get_request_id()) is None:
101-
raise RoborockException("Message must have a request_id for RPC calls")
102-
except (ValueError, JSONDecodeError) as err:
103-
_LOGGER.exception("Error getting request_id from message: %s", err)
104-
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
105-
106-
future: asyncio.Future[RoborockMessage] = await self._pending_rpcs.start(request_id)
107-
10890
try:
10991
encoded_msg = self._encoder(message)
110-
await self._mqtt_session.publish(self._publish_topic, encoded_msg)
111-
112-
return await asyncio.wait_for(future, timeout=timeout)
113-
114-
except asyncio.TimeoutError as ex:
115-
await self._pending_rpcs.pop(request_id)
116-
raise RoborockException(f"Command timed out after {timeout}s") from ex
117-
except Exception:
118-
logging.exception("Uncaught error sending command")
119-
await self._pending_rpcs.pop(request_id)
120-
raise
92+
except Exception as e:
93+
_LOGGER.exception("Error encoding MQTT message: %s", e)
94+
raise RoborockException(f"Failed to encode MQTT message: {e}") from e
95+
try:
96+
return await self._mqtt_session.publish(self._publish_topic, encoded_msg)
97+
except MqttSessionException as e:
98+
_LOGGER.exception("Error publishing MQTT message: %s", e)
99+
raise RoborockException(f"Failed to publish MQTT message: {e}") from e
121100

122101

123102
def create_mqtt_channel(

roborock/devices/v1_rpc_channel.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
and local connections, preferring local when available.
77
"""
88

9+
import asyncio
910
import logging
1011
from collections.abc import Callable
1112
from typing import Any, Protocol, TypeVar, overload
1213

1314
from roborock.containers import RoborockBase
15+
from roborock.exceptions import RoborockException
1416
from roborock.protocols.v1_protocol import (
1517
CommandType,
1618
ParamsType,
@@ -24,6 +26,7 @@
2426
from .mqtt_channel import MqttChannel
2527

2628
_LOGGER = logging.getLogger(__name__)
29+
_TIMEOUT = 10.0
2730

2831

2932
_T = TypeVar("_T", bound=RoborockBase)
@@ -132,8 +135,26 @@ async def _send_raw_command(
132135
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
133136
request_message = RequestMessage(method, params=params)
134137
message = self._payload_encoder(request_message)
135-
response = await self._channel.send_message(message)
136-
return decode_rpc_response(response)
138+
139+
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
140+
141+
def find_response(response_message: RoborockMessage) -> None:
142+
try:
143+
decoded = decode_rpc_response(response_message)
144+
except RoborockException:
145+
return
146+
if decoded.request_id == request_message.request_id:
147+
future.set_result(decoded.data)
148+
149+
unsub = await self._channel.subscribe(find_response)
150+
try:
151+
await self._channel.publish(message)
152+
return await asyncio.wait_for(future, timeout=_TIMEOUT)
153+
except TimeoutError as ex:
154+
future.cancel()
155+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
156+
finally:
157+
unsub()
137158

138159

139160
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:

roborock/protocols/v1_protocol.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,18 @@ def _as_payload(self, security_data: SecurityData | None) -> bytes:
9595
)
9696

9797

98-
def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
98+
@dataclass(kw_only=True, frozen=True)
99+
class ResponseMessage:
100+
"""Data structure for v1 RoborockMessage responses."""
101+
102+
request_id: int | None
103+
"""The request ID of the response."""
104+
105+
data: dict[str, Any]
106+
"""The data of the response."""
107+
108+
109+
def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
99110
"""Decode a V1 RPC_RESPONSE message."""
100111
if not message.payload:
101112
raise RoborockException("Invalid V1 message format: missing payload")
@@ -109,14 +120,19 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
109120
if not isinstance(datapoints, dict):
110121
raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}")
111122

112-
if not (data_point := datapoints.get("102")):
113-
raise RoborockException("Invalid V1 message format: missing '102' data point")
123+
if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))):
124+
raise RoborockException(
125+
f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point"
126+
)
114127

115128
try:
116129
data_point_response = json.loads(data_point)
117130
except (json.JSONDecodeError, TypeError) as e:
118-
raise RoborockException(f"Invalid V1 message data point '102': {e} for {message.payload!r}") from e
131+
raise RoborockException(
132+
f"Invalid V1 message data point '{RoborockMessageProtocol.RPC_RESPONSE}': {e} for {message.payload!r}"
133+
) from e
119134

135+
request_id: int | None = data_point_response.get("id")
120136
if error := data_point_response.get("error"):
121137
raise RoborockException(f"Error in message: {error}")
122138

@@ -127,7 +143,7 @@ def decode_rpc_response(message: RoborockMessage) -> dict[str, Any]:
127143
result = result[0]
128144
if not isinstance(result, dict):
129145
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
130-
return result
146+
return ResponseMessage(request_id=request_id, data=result)
131147

132148

133149
@dataclass

roborock/roborock_message.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
import math
54
import time
65
from dataclasses import dataclass, field
@@ -247,12 +246,3 @@ class RoborockMessage:
247246
version: bytes = b"1.0"
248247
random: int = field(default_factory=lambda: get_next_int(10000, 99999))
249248
timestamp: int = field(default_factory=lambda: math.floor(time.time()))
250-
251-
def get_request_id(self) -> int | None:
252-
if self.payload:
253-
payload = json.loads(self.payload.decode())
254-
for data_point_number, data_point in payload.get("dps").items():
255-
if data_point_number in ["101", "102"]:
256-
data_point_response = json.loads(data_point)
257-
return data_point_response.get("id")
258-
return None

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ async def _send_command(
4949
RoborockMessageProtocol.RPC_REQUEST,
5050
security_data=self._security_data,
5151
)
52-
self._logger.debug("Building message id %s for method %s", roborock_message.get_request_id, method)
52+
self._logger.debug("Building message id %s for method %s", request_message.request_id, method)
5353

5454
await self.validate_connection()
5555
request_id = request_message.request_id

0 commit comments

Comments
 (0)