11import datetime
2- from collections . abc import Generator
2+ import json
33from typing import Any
4- from unittest .mock import AsyncMock , call , patch
54
65import pytest
6+ from Crypto .Cipher import AES
7+ from Crypto .Util .Padding import unpad
78
8- from roborock .devices .mqtt_channel import MqttChannel
99from roborock .devices .traits .a01 import DyadApi , ZeoApi
10- from roborock .roborock_message import RoborockDyadDataProtocol , RoborockZeoProtocol
10+ from roborock .roborock_message import RoborockDyadDataProtocol , RoborockMessageProtocol , RoborockZeoProtocol
11+ from tests .conftest import FakeChannel
12+ from tests .protocols .common import build_a01_message
1113
1214
13- @pytest .fixture (name = "mock_channel " )
14- def mock_channel_fixture () -> AsyncMock :
15- return AsyncMock ( spec = MqttChannel )
15+ @pytest .fixture (name = "fake_channel " )
16+ def fake_channel_fixture () -> FakeChannel :
17+ return FakeChannel ( )
1618
1719
18- @pytest .fixture (name = "mock_send" )
19- def mock_send_fixture (mock_channel ) -> Generator [AsyncMock , None , None ]:
20- with patch ("roborock.devices.traits.a01.send_decoded_command" ) as mock_send :
21- yield mock_send
20+ @pytest .fixture (name = "dyad_api" )
21+ def dyad_api_fixture (fake_channel : FakeChannel ) -> DyadApi :
22+ return DyadApi (fake_channel ) # type: ignore[arg-type]
2223
2324
24- async def test_dyad_api_query_values (mock_channel : AsyncMock , mock_send : AsyncMock ):
25+ @pytest .fixture (name = "zeo_api" )
26+ def zeo_api_fixture (fake_channel : FakeChannel ) -> ZeoApi :
27+ return ZeoApi (fake_channel ) # type: ignore[arg-type]
28+
29+
30+ async def test_dyad_api_query_values (dyad_api : DyadApi , fake_channel : FakeChannel ):
2531 """Test that DyadApi currently returns raw values without conversion."""
26- api = DyadApi (mock_channel )
27-
28- mock_send .return_value = {
29- 209 : 1 , # POWER
30- 201 : 6 , # STATUS
31- 207 : 3 , # WATER_LEVEL
32- 214 : 120 , # MESH_LEFT
33- 215 : 90 , # BRUSH_LEFT
34- 227 : 85 , # SILENT_MODE_START_TIME
35- 229 : "3,4,5" , # RECENT_RUN_TIME
36- 230 : 123456 , # TOTAL_RUN_TIME
37- 222 : 1 , # STAND_LOCK_AUTO_RUN
38- 224 : 0 , # AUTO_DRY_MODE
39- }
40- result = await api .query_values (
32+ fake_channel .response_queue .append (
33+ build_a01_message (
34+ {
35+ 209 : 1 , # POWER
36+ 201 : 6 , # STATUS
37+ 207 : 3 , # WATER_LEVEL
38+ 214 : 120 , # MESH_LEFT
39+ 215 : 90 , # BRUSH_LEFT
40+ 227 : 85 , # SILENT_MODE_START_TIME
41+ 229 : "3,4,5" , # RECENT_RUN_TIME
42+ 230 : 123456 , # TOTAL_RUN_TIME
43+ 222 : 1 , # STAND_LOCK_AUTO_RUN
44+ 224 : 0 , # AUTO_DRY_MODE
45+ }
46+ )
47+ )
48+ result = await dyad_api .query_values (
4149 [
4250 RoborockDyadDataProtocol .POWER ,
4351 RoborockDyadDataProtocol .STATUS ,
@@ -64,15 +72,12 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
6472 RoborockDyadDataProtocol .AUTO_DRY_MODE : False ,
6573 }
6674
67- # Note: Bug here, this is the wrong encoding for the query
68- assert mock_send .call_args_list == [
69- call (
70- mock_channel ,
71- {
72- RoborockDyadDataProtocol .ID_QUERY : "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]" ,
73- },
74- ),
75- ]
75+ assert len (fake_channel .published_messages ) == 1
76+ message = fake_channel .published_messages [0 ]
77+ assert message .protocol == RoborockMessageProtocol .RPC_REQUEST
78+ assert message .version == b"A01"
79+ payload_data = json .loads (unpad (message .payload , AES .block_size ))
80+ assert payload_data == {"dps" : {"10000" : "[209, 201, 207, 214, 215, 227, 229, 230, 222, 224]" }}
7681
7782
7883@pytest .mark .parametrize (
@@ -117,33 +122,34 @@ async def test_dyad_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMo
117122 ],
118123)
119124async def test_dyad_invalid_response_value (
120- mock_channel : AsyncMock ,
121- mock_send : AsyncMock ,
122125 query : list [RoborockDyadDataProtocol ],
123126 response : dict [int , Any ],
124127 expected_result : dict [RoborockDyadDataProtocol , Any ],
128+ dyad_api : DyadApi ,
129+ fake_channel : FakeChannel ,
125130):
126131 """Test that DyadApi currently returns raw values without conversion."""
127- api = DyadApi ( mock_channel )
132+ fake_channel . response_queue . append ( build_a01_message ( response ) )
128133
129- mock_send .return_value = response
130- result = await api .query_values (query )
134+ result = await dyad_api .query_values (query )
131135 assert result == expected_result
132136
133137
134- async def test_zeo_api_query_values (mock_channel : AsyncMock , mock_send : AsyncMock ):
138+ async def test_zeo_api_query_values (zeo_api : ZeoApi , fake_channel : FakeChannel ):
135139 """Test that ZeoApi currently returns raw values without conversion."""
136- api = ZeoApi (mock_channel )
137-
138- mock_send .return_value = {
139- 203 : 6 , # spinning
140- 207 : 3 , # medium
141- 226 : 1 ,
142- 227 : 0 ,
143- 224 : 1 , # Times after clean. Testing int value
144- 218 : 0 , # Washing left. Testing zero int value
145- }
146- result = await api .query_values (
140+ fake_channel .response_queue .append (
141+ build_a01_message (
142+ {
143+ 203 : 6 , # spinning
144+ 207 : 3 , # medium
145+ 226 : 1 ,
146+ 227 : 0 ,
147+ 224 : 1 , # Times after clean. Testing int value
148+ 218 : 0 , # Washing left. Testing zero int value
149+ }
150+ )
151+ )
152+ result = await zeo_api .query_values (
147153 [
148154 RoborockZeoProtocol .STATE ,
149155 RoborockZeoProtocol .TEMP ,
@@ -162,15 +168,13 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
162168 RoborockZeoProtocol .TIMES_AFTER_CLEAN : 1 ,
163169 RoborockZeoProtocol .WASHING_LEFT : 0 ,
164170 }
165- # Note: Bug here, this is the wrong encoding for the query
166- assert mock_send .call_args_list == [
167- call (
168- mock_channel ,
169- {
170- RoborockZeoProtocol .ID_QUERY : "[203, 207, 226, 227, 224, 218]" ,
171- },
172- ),
173- ]
171+
172+ assert len (fake_channel .published_messages ) == 1
173+ message = fake_channel .published_messages [0 ]
174+ assert message .protocol == RoborockMessageProtocol .RPC_REQUEST
175+ assert message .version == b"A01"
176+ payload_data = json .loads (unpad (message .payload , AES .block_size ))
177+ assert payload_data == {"dps" : {"10000" : "[203, 207, 226, 227, 224, 218]" }}
174178
175179
176180@pytest .mark .parametrize (
@@ -215,15 +219,46 @@ async def test_zeo_api_query_values(mock_channel: AsyncMock, mock_send: AsyncMoc
215219 ],
216220)
217221async def test_zeo_invalid_response_value (
218- mock_channel : AsyncMock ,
219- mock_send : AsyncMock ,
220222 query : list [RoborockZeoProtocol ],
221223 response : dict [int , Any ],
222224 expected_result : dict [RoborockZeoProtocol , Any ],
225+ zeo_api : ZeoApi ,
226+ fake_channel : FakeChannel ,
223227):
224228 """Test that ZeoApi currently returns raw values without conversion."""
225- api = ZeoApi ( mock_channel )
229+ fake_channel . response_queue . append ( build_a01_message ( response ) )
226230
227- mock_send .return_value = response
228- result = await api .query_values (query )
231+ result = await zeo_api .query_values (query )
229232 assert result == expected_result
233+
234+
235+ async def test_dyad_api_set_value (dyad_api : DyadApi , fake_channel : FakeChannel ):
236+ """Test DyadApi set_value sends correct command."""
237+ await dyad_api .set_value (RoborockDyadDataProtocol .POWER , 1 )
238+
239+ assert len (fake_channel .published_messages ) == 1
240+ message = fake_channel .published_messages [0 ]
241+
242+ assert message .protocol == RoborockMessageProtocol .RPC_REQUEST
243+ assert message .version == b"A01"
244+
245+ # decode the payload to verify contents
246+ payload_data = json .loads (unpad (message .payload , AES .block_size ))
247+ # A01 protocol expects values to be strings in the dps dict
248+ assert payload_data == {"dps" : {"209" : 1 }}
249+
250+
251+ async def test_zeo_api_set_value (zeo_api : ZeoApi , fake_channel : FakeChannel ):
252+ """Test ZeoApi set_value sends correct command."""
253+ await zeo_api .set_value (RoborockZeoProtocol .MODE , "standard" )
254+
255+ assert len (fake_channel .published_messages ) == 1
256+ message = fake_channel .published_messages [0 ]
257+
258+ assert message .protocol == RoborockMessageProtocol .RPC_REQUEST
259+ assert message .version == b"A01"
260+
261+ # decode the payload to verify contents
262+ payload_data = json .loads (unpad (message .payload , AES .block_size ))
263+ # A01 protocol expects values to be strings in the dps dict
264+ assert payload_data == {"dps" : {"204" : "standard" }}
0 commit comments