Skip to content

Commit 66e828b

Browse files
committed
chore: share test code between e2e tests and mqtt tests
1 parent 833ea39 commit 66e828b

File tree

3 files changed

+135
-184
lines changed

3 files changed

+135
-184
lines changed

tests/e2e/test_mqtt_session.py

Lines changed: 23 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,128 +5,55 @@
55
close to an "end to end" test without needing an actual MQTT broker server.
66
"""
77

8-
import asyncio
9-
from collections.abc import AsyncGenerator, Callable, Generator
8+
from collections.abc import AsyncGenerator, Callable
109
from queue import Queue
11-
from typing import Any
12-
from unittest.mock import patch
1310

14-
import paho.mqtt.client as mqtt
1511
import pytest
1612

1713
from roborock.mqtt.roborock_session import create_mqtt_session
18-
from roborock.mqtt.session import MqttParams, MqttSession
14+
from roborock.mqtt.session import MqttSession
1915
from roborock.protocol import MessageParser
2016
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2117
from tests import mqtt_packet
22-
from tests.conftest import FakeSocketHandler
2318
from tests.mock_data import LOCAL_KEY
24-
25-
FAKE_PARAMS = MqttParams(
26-
host="localhost",
27-
port=1883,
28-
tls=False,
29-
username="username",
30-
password="password",
31-
timeout=10.0,
19+
from tests.mqtt.common import (
20+
FAKE_PARAMS,
21+
Subscriber,
3222
)
3323

3424

3525
@pytest.fixture(autouse=True)
36-
async def mock_client_fixture() -> AsyncGenerator[None, None]:
37-
"""Fixture to patch the MQTT underlying sync client.
38-
39-
The tests use fake sockets, so this ensures that the async mqtt client does not
40-
attempt to listen on them directly. We instead just poll the socket for
41-
data ourselves.
42-
"""
43-
44-
event_loop = asyncio.get_running_loop()
45-
46-
orig_class = mqtt.Client
47-
48-
async def poll_sockets(client: mqtt.Client) -> None:
49-
"""Poll the mqtt client sockets in a loop to pick up new data."""
50-
while True:
51-
event_loop.call_soon_threadsafe(client.loop_read)
52-
event_loop.call_soon_threadsafe(client.loop_write)
53-
await asyncio.sleep(0.01)
54-
55-
task: asyncio.Task[None] | None = None
56-
57-
def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
58-
"""Create a new mqtt client and start the socket polling task."""
59-
nonlocal task
60-
client = orig_class(*args, **kwargs)
61-
task = event_loop.create_task(poll_sockets(client))
62-
return client
63-
64-
with (
65-
patch("aiomqtt.client.Client._on_socket_open"),
66-
patch("aiomqtt.client.Client._on_socket_close"),
67-
patch("aiomqtt.client.Client._on_socket_register_write"),
68-
patch("aiomqtt.client.Client._on_socket_unregister_write"),
69-
patch("aiomqtt.client.mqtt.Client", side_effect=new_client),
70-
):
71-
yield
72-
if task:
73-
task.cancel()
26+
def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None:
27+
"""Automatically use the mock mqtt client fixture."""
7428

7529

7630
@pytest.fixture(autouse=True)
77-
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
78-
"""Fixture to mock the MQTT connection."""
31+
def auto_fast_backoff(fast_backoff_fixture: None) -> None:
32+
"""Automatically use the fast backoff fixture."""
7933

8034

8135
@pytest.fixture(autouse=True)
82-
def fast_backoff_fixture() -> Generator[None, None, None]:
83-
"""Fixture to speed up backoff."""
84-
with patch("roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL", 0.01):
85-
yield
86-
87-
88-
@pytest.fixture
89-
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
90-
"""Fixture to push a response to the client."""
91-
92-
def _push(data: bytes) -> None:
93-
response_queue.put(data)
94-
fake_socket_handler.push_response()
36+
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
37+
"""Fixture to mock the MQTT connection.
9538
96-
return _push
39+
This is here to pull in the mock socket pixtures into all tests used here.
40+
"""
9741

9842

9943
@pytest.fixture(name="session")
100-
async def session_fixture(push_response: Callable[[bytes], None]) -> AsyncGenerator[MqttSession, None]:
44+
async def session_fixture(
45+
push_response: Callable[[bytes], None],
46+
) -> AsyncGenerator[MqttSession, None]:
10147
"""Fixture to create a new connected MQTT session."""
48+
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
49+
session = await create_mqtt_session(FAKE_PARAMS)
50+
assert session.connected
10251
try:
103-
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
104-
session = await create_mqtt_session(FAKE_PARAMS)
105-
assert session.connected
10652
yield session
10753
finally:
10854
await session.close()
10955

11056

111-
class Subscriber:
112-
"""Mock subscriber class.
113-
114-
We use this to hold on to received messages for verification.
115-
"""
116-
117-
def __init__(self) -> None:
118-
self.messages: list[bytes] = []
119-
self._event = asyncio.Event()
120-
121-
def append(self, message: bytes) -> None:
122-
self.messages.append(message)
123-
self._event.set()
124-
125-
async def wait(self) -> None:
126-
await asyncio.wait_for(self._event.wait(), timeout=1.0)
127-
self._event.clear()
128-
129-
13057
async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None:
13158
"""Test receiving a real Roborock message through the session."""
13259
assert session.connected
@@ -163,11 +90,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None
16390

16491

16592
async def test_session_e2e_publish_message(
166-
push_response: Callable[[bytes], None], received_requests: Queue, session: MqttSession
93+
push_response: Callable[[bytes], None],
94+
received_requests: Queue,
95+
session: MqttSession,
16796
) -> None:
16897
"""Test publishing a real Roborock message."""
16998

170-
# Publish a message to the brokwer
99+
# Publish a message to the broker
171100
msg = RoborockMessage(
172101
protocol=RoborockMessageProtocol.RPC_REQUEST,
173102
payload=b'{"method":"get_status"}',

tests/mqtt/common.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Common code for MQTT tests."""
2+
3+
import asyncio
4+
import datetime
5+
from collections.abc import AsyncGenerator, Callable, Generator
6+
from queue import Queue
7+
from typing import Any
8+
from unittest.mock import patch
9+
10+
import paho.mqtt.client as mqtt
11+
import pytest
12+
13+
from roborock.mqtt.session import MqttParams
14+
from tests.conftest import FakeSocketHandler
15+
16+
FAKE_PARAMS = MqttParams(
17+
host="localhost",
18+
port=1883,
19+
tls=False,
20+
username="username",
21+
password="password",
22+
timeout=10.0,
23+
)
24+
25+
26+
class Subscriber:
27+
"""Mock subscriber class.
28+
29+
We use this to hold on to received messages for verification.
30+
"""
31+
32+
def __init__(self) -> None:
33+
self.messages: list[bytes] = []
34+
self._event = asyncio.Event()
35+
36+
def append(self, message: bytes) -> None:
37+
self.messages.append(message)
38+
self._event.set()
39+
40+
async def wait(self) -> None:
41+
await asyncio.wait_for(self._event.wait(), timeout=1.0)
42+
self._event.clear()
43+
44+
45+
@pytest.fixture
46+
async def mock_mqtt_client_fixture() -> AsyncGenerator[None, None]:
47+
"""Fixture to patch the MQTT underlying sync client.
48+
49+
The tests use fake sockets, so this ensures that the async mqtt client does not
50+
attempt to listen on them directly. We instead just poll the socket for
51+
data ourselves.
52+
"""
53+
54+
event_loop = asyncio.get_running_loop()
55+
56+
orig_class = mqtt.Client
57+
58+
async def poll_sockets(client: mqtt.Client) -> None:
59+
"""Poll the mqtt client sockets in a loop to pick up new data."""
60+
while True:
61+
event_loop.call_soon_threadsafe(client.loop_read)
62+
event_loop.call_soon_threadsafe(client.loop_write)
63+
await asyncio.sleep(0.01)
64+
65+
task: asyncio.Task[None] | None = None
66+
67+
def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
68+
"""Create a new mqtt client and start the socket polling task."""
69+
nonlocal task
70+
client = orig_class(*args, **kwargs)
71+
task = event_loop.create_task(poll_sockets(client))
72+
return client
73+
74+
with (
75+
patch("aiomqtt.client.Client._on_socket_open"),
76+
patch("aiomqtt.client.Client._on_socket_close"),
77+
patch("aiomqtt.client.Client._on_socket_register_write"),
78+
patch("aiomqtt.client.Client._on_socket_unregister_write"),
79+
patch("aiomqtt.client.mqtt.Client", side_effect=new_client),
80+
):
81+
yield
82+
if task:
83+
task.cancel()
84+
85+
86+
@pytest.fixture
87+
def fast_backoff_fixture() -> Generator[None, None, None]:
88+
"""Fixture to speed up backoff."""
89+
with patch(
90+
"roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL",
91+
datetime.timedelta(seconds=0.01),
92+
):
93+
yield
94+
95+
96+
@pytest.fixture
97+
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
98+
"""Fixture to push a response to the client."""
99+
100+
def _push(data: bytes) -> None:
101+
response_queue.put(data)
102+
fake_socket_handler.push_response()
103+
104+
return _push

tests/mqtt/test_roborock_session.py

Lines changed: 8 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,18 @@
22

33
import asyncio
44
import datetime
5-
from collections.abc import AsyncGenerator, Callable, Generator
6-
from queue import Queue
7-
from typing import Any
5+
from collections.abc import Callable, Generator
86
from unittest.mock import AsyncMock, Mock, patch
97

108
import aiomqtt
11-
import paho.mqtt.client as mqtt
129
import pytest
1310

1411
from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session
1512
from roborock.mqtt.session import MqttParams, MqttSessionException, MqttSessionUnauthorized
1613
from tests import mqtt_packet
17-
from tests.conftest import FakeSocketHandler
18-
19-
# We mock out the connection so these params are not used/verified
20-
FAKE_PARAMS = MqttParams(
21-
host="localhost",
22-
port=1883,
23-
tls=False,
24-
username="username",
25-
password="password",
26-
timeout=10.0,
14+
from tests.mqtt.common import (
15+
FAKE_PARAMS,
16+
Subscriber,
2717
)
2818

2919

@@ -33,51 +23,13 @@ def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None
3323

3424

3525
@pytest.fixture(autouse=True)
36-
async def mock_client_fixture() -> AsyncGenerator[None, None]:
37-
"""Fixture to patch the MQTT underlying sync client.
38-
39-
The tests use fake sockets, so this ensures that the async mqtt client does not
40-
attempt to listen on them directly. We instead just poll the socket for
41-
data ourselves.
42-
"""
43-
44-
event_loop = asyncio.get_running_loop()
45-
46-
orig_class = mqtt.Client
47-
48-
async def poll_sockets(client: mqtt.Client) -> None:
49-
"""Poll the mqtt client sockets in a loop to pick up new data."""
50-
while True:
51-
event_loop.call_soon_threadsafe(client.loop_read)
52-
event_loop.call_soon_threadsafe(client.loop_write)
53-
await asyncio.sleep(0.1)
54-
55-
task: asyncio.Task[None] | None = None
56-
57-
def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
58-
"""Create a new mqtt client and start the socket polling task."""
59-
nonlocal task
60-
client = orig_class(*args, **kwargs)
61-
task = event_loop.create_task(poll_sockets(client))
62-
return client
63-
64-
with (
65-
patch("aiomqtt.client.Client._on_socket_open"),
66-
patch("aiomqtt.client.Client._on_socket_close"),
67-
patch("aiomqtt.client.Client._on_socket_register_write"),
68-
patch("aiomqtt.client.Client._on_socket_unregister_write"),
69-
patch("aiomqtt.client.mqtt.Client", side_effect=new_client),
70-
):
71-
yield
72-
if task:
73-
task.cancel()
26+
def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None:
27+
"""Automatically use the mock mqtt client fixture."""
7428

7529

7630
@pytest.fixture(autouse=True)
77-
def fast_backoff_fixture() -> Generator[None, None, None]:
78-
"""Fixture to make backoff intervals fast."""
79-
with patch("roborock.mqtt.roborock_session.MIN_BACKOFF_INTERVAL", datetime.timedelta(seconds=0.01)):
80-
yield
31+
def auto_fast_backoff(fast_backoff_fixture: None) -> None:
32+
"""Automatically use the fast backoff fixture."""
8133

8234

8335
@pytest.fixture
@@ -97,40 +49,6 @@ def mock_mqtt_client() -> Generator[AsyncMock, None, None]:
9749
yield mock_client
9850

9951

100-
@pytest.fixture
101-
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
102-
"""Fixtures to push messages."""
103-
104-
def push(message: bytes) -> None:
105-
response_queue.put(message)
106-
fake_socket_handler.push_response()
107-
108-
return push
109-
110-
111-
class Subscriber:
112-
"""Mock subscriber class.
113-
114-
This will capture messages published on the session so the tests can verify
115-
they were received.
116-
"""
117-
118-
def __init__(self) -> None:
119-
"""Initialize the subscriber."""
120-
self.messages: list[bytes] = []
121-
self.event: asyncio.Event = asyncio.Event()
122-
123-
def append(self, message: bytes) -> None:
124-
"""Append a message to the subscriber."""
125-
self.messages.append(message)
126-
self.event.set()
127-
128-
async def wait(self) -> None:
129-
"""Wait for a message to be received."""
130-
await self.event.wait()
131-
self.event.clear()
132-
133-
13452
async def test_session(push_response: Callable[[bytes], None]) -> None:
13553
"""Test the MQTT session."""
13654

0 commit comments

Comments
 (0)