|
8 | 8 | import pytest |
9 | 9 |
|
10 | 10 | from roborock.devices.local_channel import LocalChannel, LocalChannelParams |
11 | | -from roborock.exceptions import RoborockConnectionException |
| 11 | +from roborock.exceptions import RoborockConnectionException, RoborockException |
12 | 12 | from roborock.protocol import create_local_decoder, create_local_encoder |
13 | 13 | from roborock.protocols.v1_protocol import LocalProtocolVersion |
14 | 14 | from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol |
@@ -61,7 +61,7 @@ async def setup_local_channel_with_hello_mock() -> LocalChannel: |
61 | 61 | """Fixture to set up the local channel with automatic hello mocking.""" |
62 | 62 | channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
63 | 63 |
|
64 | | - async def mock_do_hello(local_protocol_version): |
| 64 | + async def mock_do_hello(_: LocalProtocolVersion): |
65 | 65 | """Mock _do_hello to return successful params without sending actual request.""" |
66 | 66 | return LocalChannelParams(local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321) |
67 | 67 |
|
@@ -269,5 +269,96 @@ async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalCh |
269 | 269 |
|
270 | 270 | # Verify that the channel is using L01 protocol |
271 | 271 | assert channel._local_protocol_version == LocalProtocolVersion.L01 |
272 | | - assert channel._ack_nonce == 54321 |
| 272 | + assert channel._params is not None |
| 273 | + assert channel._params.ack_nonce == 54321 |
| 274 | + assert channel._is_connected is True |
| 275 | + |
| 276 | + |
| 277 | +async def test_hello_success_with_v1_protocol_first(mock_loop: Mock, mock_transport: Mock) -> None: |
| 278 | + """Test that when V1 protocol succeeds on first attempt, we use V1.""" |
| 279 | + |
| 280 | + # Create a channel without the automatic hello mocking |
| 281 | + channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 282 | + |
| 283 | + # Mock _do_hello to succeed for V1 on first attempt |
| 284 | + async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None: |
| 285 | + if local_protocol_version == LocalProtocolVersion.V1: |
| 286 | + # V1 succeeds on first attempt |
| 287 | + return LocalChannelParams( |
| 288 | + local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=67890 |
| 289 | + ) |
| 290 | + elif local_protocol_version == LocalProtocolVersion.L01: |
| 291 | + # L01 would succeed but we shouldn't reach it |
| 292 | + return LocalChannelParams( |
| 293 | + local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=99999 |
| 294 | + ) |
| 295 | + return None |
| 296 | + |
| 297 | + # Replace the _do_hello method |
| 298 | + setattr(channel, "_do_hello", mock_do_hello) |
| 299 | + |
| 300 | + # Connect and verify V1 protocol is used |
| 301 | + await channel.connect() |
| 302 | + |
| 303 | + # Verify that the channel is using V1 protocol |
| 304 | + assert channel._local_protocol_version == LocalProtocolVersion.V1 |
| 305 | + assert channel._params is not None |
| 306 | + assert channel._params.ack_nonce == 67890 |
| 307 | + assert channel._is_connected is True |
| 308 | + |
| 309 | + |
| 310 | +async def test_hello_both_protocols_fail(mock_loop: Mock, mock_transport: Mock) -> None: |
| 311 | + """Test that when both V1 and L01 protocols fail, connection fails.""" |
| 312 | + |
| 313 | + # Create a channel without the automatic hello mocking |
| 314 | + channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 315 | + |
| 316 | + # Mock _do_hello to fail for both protocols |
| 317 | + async def mock_do_hello(_: LocalProtocolVersion) -> LocalChannelParams | None: |
| 318 | + # Both protocols fail |
| 319 | + return None |
| 320 | + |
| 321 | + # Replace the _do_hello method |
| 322 | + setattr(channel, "_do_hello", mock_do_hello) |
| 323 | + |
| 324 | + # Connect should raise an exception |
| 325 | + with pytest.raises(RoborockException, match="Failed to connect to device with any known protocol"): |
| 326 | + await channel.connect() |
| 327 | + |
| 328 | + # Verify that the channel is not connected and cleaned up |
| 329 | + assert channel._is_connected is False |
| 330 | + assert channel._transport is None |
| 331 | + |
| 332 | + |
| 333 | +async def test_hello_preferred_protocol_version_ordering(mock_loop: Mock, mock_transport: Mock) -> None: |
| 334 | + """Test that preferred protocol version is tried first.""" |
| 335 | + |
| 336 | + # Create a channel with preferred L01 protocol |
| 337 | + channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY) |
| 338 | + channel._local_protocol_version = LocalProtocolVersion.L01 |
| 339 | + |
| 340 | + # Track which protocols were attempted and in what order |
| 341 | + attempted_protocols: list[LocalProtocolVersion] = [] |
| 342 | + |
| 343 | + # Mock _do_hello to track attempts and succeed on L01 |
| 344 | + async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None: |
| 345 | + attempted_protocols.append(local_protocol_version) |
| 346 | + if local_protocol_version == LocalProtocolVersion.L01: |
| 347 | + # L01 succeeds |
| 348 | + return LocalChannelParams( |
| 349 | + local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=11111 |
| 350 | + ) |
| 351 | + return None |
| 352 | + |
| 353 | + # Replace the _do_hello method |
| 354 | + setattr(channel, "_do_hello", mock_do_hello) |
| 355 | + |
| 356 | + # Connect and verify L01 is tried first |
| 357 | + await channel.connect() |
| 358 | + |
| 359 | + # Verify that L01 was tried first (preferred version) |
| 360 | + assert attempted_protocols == [LocalProtocolVersion.L01] |
| 361 | + assert channel._local_protocol_version == LocalProtocolVersion.L01 |
| 362 | + assert channel._params is not None |
| 363 | + assert channel._params.ack_nonce == 11111 |
273 | 364 | assert channel._is_connected is True |
0 commit comments