|
5 | 5 | close to an "end to end" test without needing an actual MQTT broker server. |
6 | 6 | """ |
7 | 7 |
|
8 | | -import asyncio |
9 | | -from collections.abc import AsyncGenerator, Callable, Generator |
| 8 | +from collections.abc import AsyncGenerator, Callable |
10 | 9 | from queue import Queue |
11 | | -from typing import Any |
12 | | -from unittest.mock import patch |
13 | 10 |
|
14 | | -import paho.mqtt.client as mqtt |
15 | 11 | import pytest |
16 | 12 |
|
17 | 13 | from roborock.mqtt.roborock_session import create_mqtt_session |
18 | | -from roborock.mqtt.session import MqttParams, MqttSession |
| 14 | +from roborock.mqtt.session import MqttSession |
19 | 15 | from roborock.protocol import MessageParser |
20 | 16 | from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol |
21 | 17 | from tests import mqtt_packet |
22 | | -from tests.conftest import FakeSocketHandler |
23 | 18 | from tests.mock_data import LOCAL_KEY |
24 | | - |
25 | | -FAKE_PARAMS = MqttParams( |
26 | | - host="localhost", |
27 | | - port=1883, |
28 | | - tls=False, |
29 | | - username="username", |
30 | | - password="password", |
31 | | - timeout=10.0, |
| 19 | +from tests.mqtt.common import ( |
| 20 | + FAKE_PARAMS, |
| 21 | + Subscriber, |
32 | 22 | ) |
33 | 23 |
|
34 | 24 |
|
35 | 25 | @pytest.fixture(autouse=True) |
36 | | -async def mock_client_fixture() -> AsyncGenerator[None, None]: |
37 | | - """Fixture to patch the MQTT underlying sync client. |
38 | | -
|
39 | | - The tests use fake sockets, so this ensures that the async mqtt client does not |
40 | | - attempt to listen on them directly. We instead just poll the socket for |
41 | | - data ourselves. |
42 | | - """ |
43 | | - |
44 | | - event_loop = asyncio.get_running_loop() |
45 | | - |
46 | | - orig_class = mqtt.Client |
47 | | - |
48 | | - async def poll_sockets(client: mqtt.Client) -> None: |
49 | | - """Poll the mqtt client sockets in a loop to pick up new data.""" |
50 | | - while True: |
51 | | - event_loop.call_soon_threadsafe(client.loop_read) |
52 | | - event_loop.call_soon_threadsafe(client.loop_write) |
53 | | - await asyncio.sleep(0.01) |
54 | | - |
55 | | - task: asyncio.Task[None] | None = None |
56 | | - |
57 | | - def new_client(*args: Any, **kwargs: Any) -> mqtt.Client: |
58 | | - """Create a new mqtt client and start the socket polling task.""" |
59 | | - nonlocal task |
60 | | - client = orig_class(*args, **kwargs) |
61 | | - task = event_loop.create_task(poll_sockets(client)) |
62 | | - return client |
63 | | - |
64 | | - with ( |
65 | | - patch("aiomqtt.client.Client._on_socket_open"), |
66 | | - patch("aiomqtt.client.Client._on_socket_close"), |
67 | | - patch("aiomqtt.client.Client._on_socket_register_write"), |
68 | | - patch("aiomqtt.client.Client._on_socket_unregister_write"), |
69 | | - patch("aiomqtt.client.mqtt.Client", side_effect=new_client), |
70 | | - ): |
71 | | - yield |
72 | | - if task: |
73 | | - task.cancel() |
| 26 | +def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None: |
| 27 | + """Automatically use the mock mqtt client fixture.""" |
74 | 28 |
|
75 | 29 |
|
76 | 30 | @pytest.fixture(autouse=True) |
77 | | -def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: |
78 | | - """Fixture to mock the MQTT connection.""" |
| 31 | +def auto_fast_backoff(fast_backoff_fixture: None) -> None: |
| 32 | + """Automatically use the fast backoff fixture.""" |
79 | 33 |
|
80 | 34 |
|
81 | 35 | @pytest.fixture(autouse=True) |
82 | | -def fast_backoff_fixture() -> Generator[None, None, None]: |
83 | | - """Fixture to speed up backoff.""" |
84 | | - with patch("roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL", 0.01): |
85 | | - yield |
86 | | - |
87 | | - |
88 | | -@pytest.fixture |
89 | | -def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: |
90 | | - """Fixture to push a response to the client.""" |
91 | | - |
92 | | - def _push(data: bytes) -> None: |
93 | | - response_queue.put(data) |
94 | | - fake_socket_handler.push_response() |
| 36 | +def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: |
| 37 | + """Fixture to mock the MQTT connection. |
95 | 38 |
|
96 | | - return _push |
| 39 | + This is here to pull in the mock socket pixtures into all tests used here. |
| 40 | + """ |
97 | 41 |
|
98 | 42 |
|
99 | 43 | @pytest.fixture(name="session") |
100 | | -async def session_fixture(push_response: Callable[[bytes], None]) -> AsyncGenerator[MqttSession, None]: |
| 44 | +async def session_fixture( |
| 45 | + push_response: Callable[[bytes], None], |
| 46 | +) -> AsyncGenerator[MqttSession, None]: |
101 | 47 | """Fixture to create a new connected MQTT session.""" |
| 48 | + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) |
| 49 | + session = await create_mqtt_session(FAKE_PARAMS) |
| 50 | + assert session.connected |
102 | 51 | try: |
103 | | - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) |
104 | | - session = await create_mqtt_session(FAKE_PARAMS) |
105 | | - assert session.connected |
106 | 52 | yield session |
107 | 53 | finally: |
108 | 54 | await session.close() |
109 | 55 |
|
110 | 56 |
|
111 | | -class Subscriber: |
112 | | - """Mock subscriber class. |
113 | | -
|
114 | | - We use this to hold on to received messages for verification. |
115 | | - """ |
116 | | - |
117 | | - def __init__(self) -> None: |
118 | | - self.messages: list[bytes] = [] |
119 | | - self._event = asyncio.Event() |
120 | | - |
121 | | - def append(self, message: bytes) -> None: |
122 | | - self.messages.append(message) |
123 | | - self._event.set() |
124 | | - |
125 | | - async def wait(self) -> None: |
126 | | - await asyncio.wait_for(self._event.wait(), timeout=1.0) |
127 | | - self._event.clear() |
128 | | - |
129 | | - |
130 | 57 | async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None: |
131 | 58 | """Test receiving a real Roborock message through the session.""" |
132 | 59 | assert session.connected |
@@ -163,11 +90,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None |
163 | 90 |
|
164 | 91 |
|
165 | 92 | async def test_session_e2e_publish_message( |
166 | | - push_response: Callable[[bytes], None], received_requests: Queue, session: MqttSession |
| 93 | + push_response: Callable[[bytes], None], |
| 94 | + received_requests: Queue, |
| 95 | + session: MqttSession, |
167 | 96 | ) -> None: |
168 | 97 | """Test publishing a real Roborock message.""" |
169 | 98 |
|
170 | | - # Publish a message to the brokwer |
| 99 | + # Publish a message to the broker |
171 | 100 | msg = RoborockMessage( |
172 | 101 | protocol=RoborockMessageProtocol.RPC_REQUEST, |
173 | 102 | payload=b'{"method":"get_status"}', |
|
0 commit comments