Skip to content

Commit ab78435

Browse files
committed
chore: add some tests and adress comments
1 parent cd97ddb commit ab78435

File tree

2 files changed

+99
-10
lines changed

2 files changed

+99
-10
lines changed

roborock/devices/local_channel.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@ def __init__(self, host: str, local_key: str):
6060
self._local_key = local_key
6161
self._local_protocol_version: LocalProtocolVersion | None = None
6262
self._connect_nonce = get_next_int(10000, 32767)
63-
self._ack_nonce: int | None = None
63+
self._params: LocalChannelParams | None = None
6464
self._update_encoder_decoder()
6565

6666
def _update_encoder_decoder(self, params: LocalChannelParams | None = None):
6767
if params is None:
68-
params = LocalChannelParams(
69-
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=self._ack_nonce
70-
)
68+
params = LocalChannelParams(local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=None)
69+
self._params = params
7170
self._encoder = create_local_encoder(
7271
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
7372
)
@@ -91,7 +90,7 @@ async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> Local
9190
seq=1,
9291
)
9392
try:
94-
response = await self.send_message(
93+
response = await self._send_message(
9594
roborock_message=request,
9695
request_id=request.seq,
9796
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
@@ -123,7 +122,6 @@ async def _hello(self):
123122
for version in attempt_versions:
124123
params = await self._do_hello(version)
125124
if params is not None:
126-
self._ack_nonce = params.ack_nonce
127125
self._local_protocol_version = version
128126
self._update_encoder_decoder(params)
129127
return
@@ -200,7 +198,7 @@ async def publish(self, message: RoborockMessage) -> None:
200198
logging.exception("Uncaught error sending command")
201199
raise RoborockException(f"Failed to send message: {message}") from err
202200

203-
async def send_message(
201+
async def _send_message(
204202
self,
205203
roborock_message: RoborockMessage,
206204
request_id: int,

tests/devices/test_local_channel.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
from roborock.devices.local_channel import LocalChannel, LocalChannelParams
11-
from roborock.exceptions import RoborockConnectionException
11+
from roborock.exceptions import RoborockConnectionException, RoborockException
1212
from roborock.protocol import create_local_decoder, create_local_encoder
1313
from roborock.protocols.v1_protocol import LocalProtocolVersion
1414
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
@@ -61,7 +61,7 @@ async def setup_local_channel_with_hello_mock() -> LocalChannel:
6161
"""Fixture to set up the local channel with automatic hello mocking."""
6262
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)
6363

64-
async def mock_do_hello(local_protocol_version):
64+
async def mock_do_hello(_: LocalProtocolVersion):
6565
"""Mock _do_hello to return successful params without sending actual request."""
6666
return LocalChannelParams(local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321)
6767

@@ -269,5 +269,96 @@ async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalCh
269269

270270
# Verify that the channel is using L01 protocol
271271
assert channel._local_protocol_version == LocalProtocolVersion.L01
272-
assert channel._ack_nonce == 54321
272+
assert channel._params is not None
273+
assert channel._params.ack_nonce == 54321
274+
assert channel._is_connected is True
275+
276+
277+
async def test_hello_success_with_v1_protocol_first(mock_loop: Mock, mock_transport: Mock) -> None:
278+
"""Test that when V1 protocol succeeds on first attempt, we use V1."""
279+
280+
# Create a channel without the automatic hello mocking
281+
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)
282+
283+
# Mock _do_hello to succeed for V1 on first attempt
284+
async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
285+
if local_protocol_version == LocalProtocolVersion.V1:
286+
# V1 succeeds on first attempt
287+
return LocalChannelParams(
288+
local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=67890
289+
)
290+
elif local_protocol_version == LocalProtocolVersion.L01:
291+
# L01 would succeed but we shouldn't reach it
292+
return LocalChannelParams(
293+
local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=99999
294+
)
295+
return None
296+
297+
# Replace the _do_hello method
298+
setattr(channel, "_do_hello", mock_do_hello)
299+
300+
# Connect and verify V1 protocol is used
301+
await channel.connect()
302+
303+
# Verify that the channel is using V1 protocol
304+
assert channel._local_protocol_version == LocalProtocolVersion.V1
305+
assert channel._params is not None
306+
assert channel._params.ack_nonce == 67890
307+
assert channel._is_connected is True
308+
309+
310+
async def test_hello_both_protocols_fail(mock_loop: Mock, mock_transport: Mock) -> None:
311+
"""Test that when both V1 and L01 protocols fail, connection fails."""
312+
313+
# Create a channel without the automatic hello mocking
314+
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)
315+
316+
# Mock _do_hello to fail for both protocols
317+
async def mock_do_hello(_: LocalProtocolVersion) -> LocalChannelParams | None:
318+
# Both protocols fail
319+
return None
320+
321+
# Replace the _do_hello method
322+
setattr(channel, "_do_hello", mock_do_hello)
323+
324+
# Connect should raise an exception
325+
with pytest.raises(RoborockException, match="Failed to connect to device with any known protocol"):
326+
await channel.connect()
327+
328+
# Verify that the channel is not connected and cleaned up
329+
assert channel._is_connected is False
330+
assert channel._transport is None
331+
332+
333+
async def test_hello_preferred_protocol_version_ordering(mock_loop: Mock, mock_transport: Mock) -> None:
334+
"""Test that preferred protocol version is tried first."""
335+
336+
# Create a channel with preferred L01 protocol
337+
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)
338+
channel._local_protocol_version = LocalProtocolVersion.L01
339+
340+
# Track which protocols were attempted and in what order
341+
attempted_protocols: list[LocalProtocolVersion] = []
342+
343+
# Mock _do_hello to track attempts and succeed on L01
344+
async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
345+
attempted_protocols.append(local_protocol_version)
346+
if local_protocol_version == LocalProtocolVersion.L01:
347+
# L01 succeeds
348+
return LocalChannelParams(
349+
local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=11111
350+
)
351+
return None
352+
353+
# Replace the _do_hello method
354+
setattr(channel, "_do_hello", mock_do_hello)
355+
356+
# Connect and verify L01 is tried first
357+
await channel.connect()
358+
359+
# Verify that L01 was tried first (preferred version)
360+
assert attempted_protocols == [LocalProtocolVersion.L01]
361+
assert channel._local_protocol_version == LocalProtocolVersion.L01
362+
assert channel._params is not None
363+
assert channel._params.ack_nonce == 11111
273364
assert channel._is_connected is True

0 commit comments

Comments
 (0)