Skip to content

Commit c92f29c

Browse files
committed
chore: Small tweaks to test fixtures
These are improvements factored out of a large change to add more e2e tests for device manager.
1 parent 6293a67 commit c92f29c

File tree

5 files changed

+20
-10
lines changed

5 files changed

+20
-10
lines changed

roborock/devices/local_channel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def connection_lost(self, exc: Exception | None) -> None:
4545
self.connection_lost_cb(exc)
4646

4747

48+
def get_running_loop() -> asyncio.AbstractEventLoop:
49+
"""Get the running event loop, extracted for mocking purposes."""
50+
return asyncio.get_running_loop()
51+
52+
4853
class LocalChannel(Channel):
4954
"""Simple RPC-style channel for communicating with a device over a local network.
5055
@@ -179,7 +184,7 @@ async def connect(self) -> None:
179184
if self._is_connected:
180185
self._logger.debug("Unexpected call to connect when already connected")
181186
return
182-
loop = asyncio.get_running_loop()
187+
loop = get_running_loop()
183188
protocol = _LocalProtocol(self._data_received, self._connection_lost)
184189
try:
185190
self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)

tests/devices/test_local_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]:
5252
loop = Mock()
5353
loop.create_connection = AsyncMock(return_value=(mock_transport, Mock()))
5454

55-
with patch("asyncio.get_running_loop", return_value=loop):
55+
with patch("roborock.devices.local_channel.get_running_loop", return_value=loop):
5656
yield loop
5757

5858

tests/fixtures/aiomqtt_fixtures.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]:
2828

2929
async def poll_sockets(client: mqtt.Client) -> None:
3030
"""Poll the mqtt client sockets in a loop to pick up new data."""
31-
while True:
32-
event_loop.call_soon_threadsafe(client.loop_read)
33-
event_loop.call_soon_threadsafe(client.loop_write)
34-
await asyncio.sleep(0.01)
31+
try:
32+
while True:
33+
event_loop.call_soon_threadsafe(client.loop_read)
34+
event_loop.call_soon_threadsafe(client.loop_write)
35+
await asyncio.sleep(0.01)
36+
except asyncio.CancelledError:
37+
pass
3538

3639
task: asyncio.Task[None] | None = None
3740

@@ -52,6 +55,7 @@ def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
5255
yield
5356
if task:
5457
task.cancel()
58+
await task
5559

5660

5761
@pytest.fixture

tests/fixtures/local_async_fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def start_handle_write(data: bytes) -> None:
7979

8080
return (mock_transport, protocol)
8181

82-
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
82+
with patch("roborock.devices.local_channel.get_running_loop") as mock_loop:
8383
mock_loop.return_value.create_connection.side_effect = create_connection
8484
yield
8585

tests/fixtures/mqtt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
self.handle_request = handle_request
3333
self.response_queue = response_queue
3434
self.log = log
35+
self.client_connected = False
3536

3637
def pending(self) -> int:
3738
"""Return the number of bytes in the response buffer."""
@@ -45,31 +46,31 @@ def handle_socket_recv(self, read_size: int) -> bytes:
4546
self.response_buf.seek(0)
4647
data = self.response_buf.read(read_size)
4748
_LOGGER.debug("Response: 0x%s", data.hex())
49+
self.log.add_log_entry("[mqtt <]", data)
4850
# Consume the rest of the data in the buffer
4951
remaining_data = self.response_buf.read()
5052
self.response_buf = io.BytesIO(remaining_data)
5153
return data
5254

5355
def handle_socket_send(self, client_request: bytes) -> int:
5456
"""Receive an incoming request from the client."""
57+
self.client_connected = True
5558
_LOGGER.debug("Request: 0x%s", client_request.hex())
5659
self.log.add_log_entry("[mqtt >]", client_request)
5760
if (response := self.handle_request(client_request)) is not None:
5861
# Enqueue a response to be sent back to the client in the buffer.
5962
# The buffer will be emptied when the client calls recv() on the socket
6063
_LOGGER.debug("Queued: 0x%s", response.hex())
61-
self.log.add_log_entry("[mqtt <]", response)
6264
self.response_buf.write(response)
6365
return len(client_request)
6466

6567
def push_response(self) -> None:
6668
"""Push a response to the client."""
67-
if not self.response_queue.empty():
69+
if not self.response_queue.empty() and self.client_connected:
6870
response = self.response_queue.get()
6971
# Enqueue a response to be sent back to the client in the buffer.
7072
# The buffer will be emptied when the client calls recv() on the socket
7173
_LOGGER.debug("Queued: 0x%s", response.hex())
72-
self.log.add_log_entry("[mqtt <]", response)
7374
self.response_buf.write(response)
7475

7576

0 commit comments

Comments
 (0)