Skip to content

Commit 65bb013

Browse files
committed
fix: Update mqtt channel to correctly handle multiple subscribers
1 parent c31bf19 commit 65bb013

File tree

4 files changed

+177
-48
lines changed

4 files changed

+177
-48
lines changed

roborock/devices/mqtt_channel.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot:
3030

3131
self._decoder = create_mqtt_decoder(local_key)
3232
self._encoder = create_mqtt_encoder(local_key)
33-
self._mqtt_unsub: Callable[[], None] | None = None
3433

3534
@property
3635
def is_connected(self) -> bool:
37-
"""Return true if the channel is connected."""
38-
return (self._mqtt_unsub is not None) and self._mqtt_session.connected
36+
"""Return true if the channel is connected.
37+
38+
This passes through the underlying MQTT session's connected state.
39+
"""
40+
return self._mqtt_session.connected
3941

4042
@property
4143
def _publish_topic(self) -> str:
@@ -52,9 +54,6 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
5254
5355
The callback will be called with the message payload when a message is received.
5456
55-
All messages received will be processed through the provided callback, even
56-
those sent in response to the `send_command` command.
57-
5857
Returns a callable that can be used to unsubscribe from the topic.
5958
"""
6059

@@ -69,14 +68,7 @@ def message_handler(payload: bytes) -> None:
6968
except Exception as e:
7069
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
7170

72-
self._mqtt_unsub = await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
73-
74-
def unsub_wrapper() -> None:
75-
if self._mqtt_unsub is not None:
76-
self._mqtt_unsub()
77-
self._mqtt_unsub = None
78-
79-
return unsub_wrapper
71+
return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
8072

8173
async def publish(self, message: RoborockMessage) -> None:
8274
"""Publish a command message.

roborock/devices/v1_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def is_local_connected(self) -> bool:
7979
@property
8080
def is_mqtt_connected(self) -> bool:
8181
"""Return whether MQTT connection is available."""
82-
return self._mqtt_unsub is not None
82+
return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
8383

8484
@property
8585
def rpc_channel(self) -> V1RpcChannel:

tests/devices/test_mqtt_channel.py

Lines changed: 132 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import json
55
import logging
6-
from collections.abc import Callable, Generator
6+
from collections.abc import AsyncGenerator, Callable, Generator
77
from unittest.mock import AsyncMock, Mock, patch
88

99
import pytest
@@ -63,23 +63,29 @@ def setup_mqtt_channel(mqtt_session: Mock) -> MqttChannel:
6363
)
6464

6565

66-
@pytest.fixture(name="received_messages", autouse=True)
67-
async def setup_subscribe_callback(mqtt_channel: MqttChannel) -> list[RoborockMessage]:
66+
@pytest.fixture(name="mqtt_subscribers", autouse=True)
67+
async def setup_subscribe_callback(mqtt_session: Mock) -> AsyncGenerator[list[Callable[[bytes], None]], None]:
6868
"""Fixture to record messages received by the subscriber."""
69-
messages: list[RoborockMessage] = []
70-
await mqtt_channel.subscribe(messages.append)
71-
return messages
69+
subscriber_callbacks = []
70+
71+
def mock_subscribe(_: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
72+
subscriber_callbacks.append(callback)
73+
return lambda: subscriber_callbacks.remove(callback)
74+
75+
mqtt_session.subscribe.side_effect = mock_subscribe
76+
yield subscriber_callbacks
77+
assert not subscriber_callbacks, "Not all subscribers were unsubscribed"
7278

7379

7480
@pytest.fixture(name="mqtt_message_handler")
75-
async def setup_message_handler(mqtt_session: Mock, mqtt_channel: MqttChannel) -> Callable[[bytes], None]:
81+
async def setup_message_handler(mqtt_subscribers: list[Callable[[bytes], None]]) -> Callable[[bytes], None]:
7682
"""Fixture to allow simulating incoming MQTT messages."""
77-
# Subscribe to set up message handling. We grab the message handler callback
78-
# and use it to simulate receiving a response.
79-
assert mqtt_session.subscribe
80-
subscribe_call_args = mqtt_session.subscribe.call_args
81-
message_handler = subscribe_call_args[0][1]
82-
return message_handler
83+
84+
def invoke_all_callbacks(message: bytes) -> None:
85+
for callback in mqtt_subscribers:
86+
callback(message)
87+
88+
return invoke_all_callbacks
8389

8490

8591
@pytest.fixture
@@ -106,23 +112,6 @@ async def mock_home_data() -> HomeData:
106112
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
107113

108114

109-
async def test_mqtt_channel(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None:
110-
"""Test MQTT channel setup."""
111-
112-
unsub = Mock()
113-
mqtt_session.subscribe.return_value = unsub
114-
115-
callback = Mock()
116-
result = await mqtt_channel.subscribe(callback)
117-
118-
assert mqtt_session.subscribe.called
119-
assert mqtt_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123"
120-
121-
unsub.assert_not_called()
122-
result()
123-
unsub.assert_called_once()
124-
125-
126115
async def test_publish_success(
127116
mqtt_session: Mock,
128117
mqtt_channel: MqttChannel,
@@ -148,13 +137,126 @@ async def test_publish_success(
148137

149138

150139
async def test_message_decode_error(
140+
mqtt_channel: MqttChannel,
151141
mqtt_message_handler: Callable[[bytes], None],
152142
caplog: pytest.LogCaptureFixture,
153143
) -> None:
154144
"""Test an error during message decoding."""
145+
callback = Mock()
146+
unsub = await mqtt_channel.subscribe(callback)
147+
155148
mqtt_message_handler(b"invalid_payload")
156149
await asyncio.sleep(0.01) # yield
157150

158151
assert len(caplog.records) == 1
159152
assert caplog.records[0].levelname == "WARNING"
160153
assert "Failed to decode MQTT message" in caplog.records[0].message
154+
unsub()
155+
156+
157+
async def test_concurrent_subscribers(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None:
158+
"""Test multiple concurrent subscribers receive all messages."""
159+
# Set up multiple subscribers
160+
subscriber1_messages: list[RoborockMessage] = []
161+
subscriber2_messages: list[RoborockMessage] = []
162+
subscriber3_messages: list[RoborockMessage] = []
163+
164+
unsub1 = await mqtt_channel.subscribe(subscriber1_messages.append)
165+
unsub2 = await mqtt_channel.subscribe(subscriber2_messages.append)
166+
unsub3 = await mqtt_channel.subscribe(subscriber3_messages.append)
167+
168+
# Verify that each subscription creates a separate call to the MQTT session
169+
assert mqtt_session.subscribe.call_count == 3
170+
171+
# All subscriptions should be to the same topic
172+
for call in mqtt_session.subscribe.call_args_list:
173+
assert call[0][0] == "rr/m/o/user123/username/abc123"
174+
175+
# Get the message handlers for each subscriber
176+
handler1 = mqtt_session.subscribe.call_args_list[0][0][1]
177+
handler2 = mqtt_session.subscribe.call_args_list[1][0][1]
178+
handler3 = mqtt_session.subscribe.call_args_list[2][0][1]
179+
180+
# Simulate receiving messages - each handler should decode the message independently
181+
handler1(ENCODER(TEST_REQUEST))
182+
handler2(ENCODER(TEST_REQUEST))
183+
handler3(ENCODER(TEST_REQUEST))
184+
await asyncio.sleep(0.01) # yield
185+
186+
# All subscribers should receive the message
187+
assert len(subscriber1_messages) == 1
188+
assert len(subscriber2_messages) == 1
189+
assert len(subscriber3_messages) == 1
190+
assert subscriber1_messages[0] == TEST_REQUEST
191+
assert subscriber2_messages[0] == TEST_REQUEST
192+
assert subscriber3_messages[0] == TEST_REQUEST
193+
194+
# Send another message to all handlers
195+
handler1(ENCODER(TEST_RESPONSE))
196+
handler2(ENCODER(TEST_RESPONSE))
197+
handler3(ENCODER(TEST_RESPONSE))
198+
await asyncio.sleep(0.01) # yield
199+
200+
# All subscribers should have received both messages
201+
assert len(subscriber1_messages) == 2
202+
assert len(subscriber2_messages) == 2
203+
assert len(subscriber3_messages) == 2
204+
assert subscriber1_messages == [TEST_REQUEST, TEST_RESPONSE]
205+
assert subscriber2_messages == [TEST_REQUEST, TEST_RESPONSE]
206+
assert subscriber3_messages == [TEST_REQUEST, TEST_RESPONSE]
207+
208+
# Test unsubscribing one subscriber
209+
unsub1()
210+
211+
# Send another message only to remaining handlers
212+
handler2(ENCODER(TEST_REQUEST2))
213+
handler3(ENCODER(TEST_REQUEST2))
214+
await asyncio.sleep(0.01) # yield
215+
216+
# First subscriber should not have received the new message
217+
assert len(subscriber1_messages) == 2
218+
assert len(subscriber2_messages) == 3
219+
assert len(subscriber3_messages) == 3
220+
assert subscriber2_messages[2] == TEST_REQUEST2
221+
assert subscriber3_messages[2] == TEST_REQUEST2
222+
223+
# Unsubscribe remaining subscribers
224+
unsub2()
225+
unsub3()
226+
227+
228+
async def test_concurrent_subscribers_with_callback_exception(
229+
mqtt_session: Mock, mqtt_channel: MqttChannel, caplog: pytest.LogCaptureFixture
230+
) -> None:
231+
"""Test that exception in one subscriber callback doesn't affect others."""
232+
caplog.set_level(logging.ERROR)
233+
234+
def failing_callback(message: RoborockMessage) -> None:
235+
raise ValueError("Callback error")
236+
237+
subscriber2_messages: list[RoborockMessage] = []
238+
239+
unsub1 = await mqtt_channel.subscribe(failing_callback)
240+
unsub2 = await mqtt_channel.subscribe(subscriber2_messages.append)
241+
242+
# Get the message handlers
243+
handler1 = mqtt_session.subscribe.call_args_list[0][0][1]
244+
handler2 = mqtt_session.subscribe.call_args_list[1][0][1]
245+
246+
# Simulate receiving a message - first handler will raise exception
247+
handler1(ENCODER(TEST_REQUEST))
248+
handler2(ENCODER(TEST_REQUEST))
249+
await asyncio.sleep(0.01) # yield
250+
251+
# Exception should be logged but other subscribers should still work
252+
assert len(subscriber2_messages) == 1
253+
assert subscriber2_messages[0] == TEST_REQUEST
254+
255+
# Check that exception was logged
256+
error_records = [record for record in caplog.records if record.levelname == "ERROR"]
257+
assert len(error_records) == 1
258+
assert "Uncaught error in message handler callback" in error_records[0].message
259+
260+
# Unsubscribe all remaining subscribers
261+
unsub1()
262+
unsub2()

tests/devices/test_v1_channel.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@
6363

6464

6565
@pytest.fixture(name="mock_mqtt_channel")
66-
def setup_mock_mqtt_channel() -> FakeChannel:
66+
async def setup_mock_mqtt_channel() -> FakeChannel:
6767
"""Mock MQTT channel for testing."""
68-
return FakeChannel()
68+
channel = FakeChannel()
69+
await channel.connect()
70+
return channel
6971

7072

7173
@pytest.fixture(name="mock_local_channel")
72-
def setup_mock_local_channel() -> FakeChannel:
74+
async def setup_mock_local_channel() -> FakeChannel:
7375
"""Mock Local channel for testing."""
7476
return FakeChannel()
7577

@@ -150,6 +152,39 @@ async def test_v1_channel_subscribe_mqtt_only_success(
150152
assert not mock_mqtt_channel.subscribers
151153

152154

155+
async def test_v1_channel_mqtt_disconnected(
156+
v1_channel: V1Channel,
157+
mock_mqtt_channel: FakeChannel,
158+
mock_local_session: Mock,
159+
mock_local_channel: FakeChannel,
160+
) -> None:
161+
"""Test successful subscription with MQTT only (local connection fails)."""
162+
# Setup: MQTT succeeds, local fails
163+
mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE)
164+
mock_local_channel.connect.side_effect = RoborockException("Connection failed")
165+
166+
callback = Mock()
167+
unsub = await v1_channel.subscribe(callback)
168+
169+
# Verify MQTT connection was established
170+
assert mock_mqtt_channel.subscribers
171+
172+
# Verify local connection was attempted but failed
173+
mock_local_session.assert_called_once_with(TEST_HOST)
174+
mock_local_channel.connect.assert_called_once()
175+
176+
# Simulate an MQTT disconnection where the channel is not healthy
177+
await mock_mqtt_channel.close()
178+
179+
# Verify properties
180+
assert not v1_channel.is_mqtt_connected
181+
assert not v1_channel.is_local_connected
182+
183+
# Test unsubscribe
184+
unsub()
185+
assert not mock_mqtt_channel.subscribers
186+
187+
153188
async def test_v1_channel_subscribe_both_connections_success(
154189
v1_channel: V1Channel,
155190
mock_mqtt_channel: Mock,

0 commit comments

Comments
 (0)