Skip to content

Commit a02b3bd

Browse files
committed
fix: Fix exception when sending dyad/zeo requests
The bug was introduced in #645. Add tests that exercise the actual request encoding. This changes the ID QUERY value encoding by passing in a function, which is another variation on the first version of #645 where the json encoding happened inside the decode function.
1 parent 5193ef4 commit a02b3bd

File tree

8 files changed

+200
-111
lines changed

8 files changed

+200
-111
lines changed

roborock/devices/a01_channel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
from collections.abc import Callable
56
from typing import Any, overload
67

78
from roborock.exceptions import RoborockException
@@ -29,23 +30,26 @@
2930
async def send_decoded_command(
3031
mqtt_channel: MqttChannel,
3132
params: dict[RoborockDyadDataProtocol, Any],
33+
value_encoder: Callable[[Any], Any] | None = None,
3234
) -> dict[RoborockDyadDataProtocol, Any]: ...
3335

3436

3537
@overload
3638
async def send_decoded_command(
3739
mqtt_channel: MqttChannel,
3840
params: dict[RoborockZeoProtocol, Any],
41+
value_encoder: Callable[[Any], Any] | None = None,
3942
) -> dict[RoborockZeoProtocol, Any]: ...
4043

4144

4245
async def send_decoded_command(
4346
mqtt_channel: MqttChannel,
4447
params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any],
48+
value_encoder: Callable[[Any], Any] | None = None,
4549
) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]:
4650
"""Send a command on the MQTT channel and get a decoded response."""
4751
_LOGGER.debug("Sending MQTT command: %s", params)
48-
roborock_message = encode_mqtt_payload(params)
52+
roborock_message = encode_mqtt_payload(params, value_encoder)
4953

5054
# For commands that set values: send the command and do not
5155
# block waiting for a response. Queries are handled below.

roborock/devices/traits/a01/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from collections.abc import Callable
23
from datetime import time
34
from typing import Any
@@ -121,8 +122,11 @@ def __init__(self, channel: MqttChannel) -> None:
121122

122123
async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]:
123124
"""Query the device for the values of the given Dyad protocols."""
124-
params = {RoborockDyadDataProtocol.ID_QUERY: str([int(p) for p in protocols])}
125-
response = await send_decoded_command(self._channel, params)
125+
response = await send_decoded_command(
126+
self._channel,
127+
{RoborockDyadDataProtocol.ID_QUERY: protocols},
128+
value_encoder=json.dumps,
129+
)
126130
return {protocol: convert_dyad_value(protocol, response.get(protocol)) for protocol in protocols}
127131

128132
async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]:
@@ -142,14 +146,17 @@ def __init__(self, channel: MqttChannel) -> None:
142146

143147
async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]:
144148
"""Query the device for the values of the given protocols."""
145-
params = {RoborockZeoProtocol.ID_QUERY: str([int(p) for p in protocols])}
146-
response = await send_decoded_command(self._channel, params)
149+
response = await send_decoded_command(
150+
self._channel,
151+
{RoborockZeoProtocol.ID_QUERY: protocols},
152+
value_encoder=json.dumps,
153+
)
147154
return {protocol: convert_zeo_value(protocol, response.get(protocol)) for protocol in protocols}
148155

149156
async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]:
150157
"""Set a value for a specific protocol on the device."""
151158
params = {protocol: value}
152-
return await send_decoded_command(self._channel, params)
159+
return await send_decoded_command(self._channel, params, value_encoder=lambda x: x)
153160

154161

155162
def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi:

roborock/protocols/a01_protocol.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
from collections.abc import Callable
56
from typing import Any
67

78
from Crypto.Cipher import AES
@@ -20,13 +21,28 @@
2021
A01_VERSION = b"A01"
2122

2223

24+
def _no_encode(value: Any) -> Any:
25+
return value
26+
27+
2328
def encode_mqtt_payload(
2429
data: dict[RoborockDyadDataProtocol, Any]
2530
| dict[RoborockZeoProtocol, Any]
2631
| dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any],
32+
value_encoder: Callable[[Any], Any] | None = None,
2733
) -> RoborockMessage:
28-
"""Encode payload for A01 commands over MQTT."""
29-
dps_data = {"dps": data}
34+
"""Encode payload for A01 commands over MQTT.
35+
36+
Args:
37+
data: The data to encode.
38+
value_encoder: A function to encode the values of the dictionary.
39+
40+
Returns:
41+
RoborockMessage: The encoded message.
42+
"""
43+
if value_encoder is None:
44+
value_encoder = _no_encode
45+
dps_data = {"dps": {key: value_encoder(value) for key, value in data.items()}}
3046
payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
3147
return RoborockMessage(
3248
protocol=RoborockMessageProtocol.RPC_REQUEST,

tests/devices/test_a01_channel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ async def test_id_query(mock_mqtt_channel: FakeChannel):
3434
{
3535
RoborockDyadDataProtocol.WARM_LEVEL: 101,
3636
RoborockDyadDataProtocol.POWER: 75,
37-
}
37+
},
38+
value_encoder=lambda x: x,
3839
)
3940
response_message = RoborockMessage(
4041
protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=encoded.payload, version=encoded.version

tests/devices/traits/a01/test_init.py

Lines changed: 102 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,51 @@
11
import datetime
2-
from collections.abc import Generator
2+
import json
33
from typing import Any
4-
from unittest.mock import AsyncMock, call, patch
54

65
import pytest
6+
from Crypto.Cipher import AES
7+
from Crypto.Util.Padding import unpad
78

8-
from roborock.devices.mqtt_channel import MqttChannel
99
from roborock.devices.traits.a01 import DyadApi, ZeoApi
10-
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
10+
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol
11+
from tests.conftest import FakeChannel
12+
from tests.protocols.common import build_a01_message
1113

1214

13-
@pytest.fixture(name="mock_channel")
14-
def mock_channel_fixture() -> AsyncMock:
15-
return AsyncMock(spec=MqttChannel)
15+
@pytest.fixture(name="fake_channel")
16+
def fake_channel_fixture() -> FakeChannel:
17+
return FakeChannel()
1618

1719

18-
@pytest.fixture(name="mock_send")
19-
def mock_send_fixture(mock_channel) -> Generator[AsyncMock, None, None]:
20-
with patch("roborock.devices.traits.a01.send_decoded_command") as mock_send:
21-
yield mock_send
20+
@pytest.fixture(name="dyad_api")
21+
def dyad_api_fixture(fake_channel: FakeChannel) -> DyadApi:
22+
return DyadApi(fake_channel) # type: ignore[arg-type]
2223

2324

24-
async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock):
25+
@pytest.fixture(name="zeo_api")
26+
def zeo_api_fixture(fake_channel: FakeChannel) -> ZeoApi:
27+
return ZeoApi(fake_channel) # type: ignore[arg-type]
28+
29+
30+
async def test_dyad_api_query_values(dyad_api: DyadApi, fake_channel: FakeChannel):
2531
"""Test that DyadApi currently returns raw values without conversion."""
26-
api = DyadApi(mock_channel)
27-
28-
mock_send.return_value = {
29-
209: 1, # POWER
30-
201: 6, # STATUS
31-
207: 3, # WATER_LEVEL
32-
214: 120, # MESH_LEFT
33-
215: 90, # BRUSH_LEFT
34-
227: 85, # SILENT_MODE_START_TIME
35-
229: "3,4,5", # RECENT_RUN_TIME
36-
230: 123456, # TOTAL_RUN_TIME
37-
222: 1, # STAND_LOCK_AUTO_RUN
38-
224: 0, # AUTO_DRY_MODE
39-
}
40-
result = await api.query_values(
32+
fake_channel.response_queue.append(
33+
build_a01_message(
34+
{
35+
209: 1, # POWER
36+
201: 6, # STATUS
37+
207: 3, # WATER_LEVEL
38+
214: 120, # MESH_LEFT
39+
215: 90, # BRUSH_LEFT
40+
227: 85, # SILENT_MODE_START_TIME
41+
229: "3,4,5", # RECENT_RUN_TIME
42+
230: 123456, # TOTAL_RUN_TIME
43+
222: 1, # STAND_LOCK_AUTO_RUN
44+
224: 0, # AUTO_DRY_MODE
45+
}
46+
)
47+
)
48+
result = await dyad_api.query_values(
4149
[
4250
RoborockDyadDataProtocol.POWER,
4351
RoborockDyadDataProtocol.STATUS,
@@ -64,15 +72,12 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
6472
RoborockDyadDataProtocol.AUTO_DRY_MODE: False,
6573
}
6674

67-
# Note: Bug here, this is the wrong encoding for the query
68-
assert mock_send.call_args_list == [
69-
call(
70-
mock_channel,
71-
{
72-
RoborockDyadDataProtocol.ID_QUERY: "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]",
73-
},
74-
),
75-
]
75+
assert len(fake_channel.published_messages) == 1
76+
message = fake_channel.published_messages[0]
77+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
78+
assert message.version == b"A01"
79+
payload_data = json.loads(unpad(message.payload, AES.block_size))
80+
assert payload_data == {"dps": {"10000": "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]"}}
7681

7782

7883
@pytest.mark.parametrize(
@@ -117,33 +122,34 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
117122
],
118123
)
119124
async def test_dyad_invalid_response_value(
120-
mock_channel: AsyncMock,
121-
mock_send: AsyncMock,
122125
query: list[RoborockDyadDataProtocol],
123126
response: dict[int, Any],
124127
expected_result: dict[RoborockDyadDataProtocol, Any],
128+
dyad_api: DyadApi,
129+
fake_channel: FakeChannel,
125130
):
126131
"""Test that DyadApi currently returns raw values without conversion."""
127-
api = DyadApi(mock_channel)
132+
fake_channel.response_queue.append(build_a01_message(response))
128133

129-
mock_send.return_value = response
130-
result = await api.query_values(query)
134+
result = await dyad_api.query_values(query)
131135
assert result == expected_result
132136

133137

134-
async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMock):
138+
async def test_zeo_api_query_values(zeo_api: ZeoApi, fake_channel: FakeChannel):
135139
"""Test that ZeoApi currently returns raw values without conversion."""
136-
api = ZeoApi(mock_channel)
137-
138-
mock_send.return_value = {
139-
203: 6, # spinning
140-
207: 3, # medium
141-
226: 1,
142-
227: 0,
143-
224: 1, # Times after clean. Testing int value
144-
218: 0, # Washing left. Testing zero int value
145-
}
146-
result = await api.query_values(
140+
fake_channel.response_queue.append(
141+
build_a01_message(
142+
{
143+
203: 6, # spinning
144+
207: 3, # medium
145+
226: 1,
146+
227: 0,
147+
224: 1, # Times after clean. Testing int value
148+
218: 0, # Washing left. Testing zero int value
149+
}
150+
)
151+
)
152+
result = await zeo_api.query_values(
147153
[
148154
RoborockZeoProtocol.STATE,
149155
RoborockZeoProtocol.TEMP,
@@ -162,15 +168,13 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
162168
RoborockZeoProtocol.TIMES_AFTER_CLEAN: 1,
163169
RoborockZeoProtocol.WASHING_LEFT: 0,
164170
}
165-
# Note: Bug here, this is the wrong encoding for the query
166-
assert mock_send.call_args_list == [
167-
call(
168-
mock_channel,
169-
{
170-
RoborockZeoProtocol.ID_QUERY: "[203, 207, 226, 227, 224, 218]",
171-
},
172-
),
173-
]
171+
172+
assert len(fake_channel.published_messages) == 1
173+
message = fake_channel.published_messages[0]
174+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
175+
assert message.version == b"A01"
176+
payload_data = json.loads(unpad(message.payload, AES.block_size))
177+
assert payload_data == {"dps": {"10000": "[203, 207, 226, 227, 224, 218]"}}
174178

175179

176180
@pytest.mark.parametrize(
@@ -215,15 +219,46 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
215219
],
216220
)
217221
async def test_zeo_invalid_response_value(
218-
mock_channel: AsyncMock,
219-
mock_send: AsyncMock,
220222
query: list[RoborockZeoProtocol],
221223
response: dict[int, Any],
222224
expected_result: dict[RoborockZeoProtocol, Any],
225+
zeo_api: ZeoApi,
226+
fake_channel: FakeChannel,
223227
):
224228
"""Test that ZeoApi currently returns raw values without conversion."""
225-
api = ZeoApi(mock_channel)
229+
fake_channel.response_queue.append(build_a01_message(response))
226230

227-
mock_send.return_value = response
228-
result = await api.query_values(query)
231+
result = await zeo_api.query_values(query)
229232
assert result == expected_result
233+
234+
235+
async def test_dyad_api_set_value(dyad_api: DyadApi, fake_channel: FakeChannel):
236+
"""Test DyadApi set_value sends correct command."""
237+
await dyad_api.set_value(RoborockDyadDataProtocol.POWER, 1)
238+
239+
assert len(fake_channel.published_messages) == 1
240+
message = fake_channel.published_messages[0]
241+
242+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
243+
assert message.version == b"A01"
244+
245+
# decode the payload to verify contents
246+
payload_data = json.loads(unpad(message.payload, AES.block_size))
247+
# A01 protocol expects values to be strings in the dps dict
248+
assert payload_data == {"dps": {"209": 1}}
249+
250+
251+
async def test_zeo_api_set_value(zeo_api: ZeoApi, fake_channel: FakeChannel):
252+
"""Test ZeoApi set_value sends correct command."""
253+
await zeo_api.set_value(RoborockZeoProtocol.MODE, "standard")
254+
255+
assert len(fake_channel.published_messages) == 1
256+
message = fake_channel.published_messages[0]
257+
258+
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
259+
assert message.version == b"A01"
260+
261+
# decode the payload to verify contents
262+
payload_data = json.loads(unpad(message.payload, AES.block_size))
263+
# A01 protocol expects values to be strings in the dps dict
264+
assert payload_data == {"dps": {"204": "standard"}}

tests/protocols/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Common test utils for the protocols package."""
2+
3+
import json
4+
from typing import Any
5+
6+
from Crypto.Cipher import AES
7+
from Crypto.Util.Padding import pad
8+
9+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
10+
11+
12+
def build_a01_message(message: dict[Any, Any], seq: int = 2020) -> RoborockMessage:
13+
"""Build an encoded A01 RPC response message."""
14+
return RoborockMessage(
15+
protocol=RoborockMessageProtocol.RPC_RESPONSE,
16+
payload=pad(
17+
json.dumps(
18+
{
19+
"dps": message, # {10000: json.dumps(message)},
20+
}
21+
).encode(),
22+
AES.block_size,
23+
),
24+
version=b"A01",
25+
seq=seq,
26+
)

0 commit comments

Comments
 (0)