33import asyncio
44import json
55import logging
6- from collections .abc import Callable , Generator
6+ from collections .abc import AsyncGenerator , Callable , Generator
77from unittest .mock import AsyncMock , Mock , patch
88
99import pytest
@@ -63,23 +63,29 @@ def setup_mqtt_channel(mqtt_session: Mock) -> MqttChannel:
6363 )
6464
6565
66- @pytest .fixture (name = "received_messages " , autouse = True )
67- async def setup_subscribe_callback (mqtt_channel : MqttChannel ) -> list [RoborockMessage ]:
66+ @pytest .fixture (name = "mqtt_subscribers " , autouse = True )
67+ async def setup_subscribe_callback (mqtt_session : Mock ) -> AsyncGenerator [ list [Callable [[ bytes ], None ]], None ]:
6868 """Fixture to record messages received by the subscriber."""
69- messages : list [RoborockMessage ] = []
70- await mqtt_channel .subscribe (messages .append )
71- return messages
69+ subscriber_callbacks = []
70+
71+ def mock_subscribe (_ : str , callback : Callable [[bytes ], None ]) -> Callable [[], None ]:
72+ subscriber_callbacks .append (callback )
73+ return lambda : subscriber_callbacks .remove (callback )
74+
75+ mqtt_session .subscribe .side_effect = mock_subscribe
76+ yield subscriber_callbacks
77+ assert not subscriber_callbacks , "Not all subscribers were unsubscribed"
7278
7379
7480@pytest .fixture (name = "mqtt_message_handler" )
75- async def setup_message_handler (mqtt_session : Mock , mqtt_channel : MqttChannel ) -> Callable [[bytes ], None ]:
81+ async def setup_message_handler (mqtt_subscribers : list [ Callable [[ bytes ], None ]] ) -> Callable [[bytes ], None ]:
7682 """Fixture to allow simulating incoming MQTT messages."""
77- # Subscribe to set up message handling. We grab the message handler callback
78- # and use it to simulate receiving a response.
79- assert mqtt_session . subscribe
80- subscribe_call_args = mqtt_session . subscribe . call_args
81- message_handler = subscribe_call_args [ 0 ][ 1 ]
82- return message_handler
83+
84+ def invoke_all_callbacks ( message : bytes ) -> None :
85+ for callback in mqtt_subscribers :
86+ callback ( message )
87+
88+ return invoke_all_callbacks
8389
8490
8591@pytest .fixture
@@ -106,23 +112,6 @@ async def mock_home_data() -> HomeData:
106112 return HomeData .from_dict (mock_data .HOME_DATA_RAW )
107113
108114
109- async def test_mqtt_channel (mqtt_session : Mock , mqtt_channel : MqttChannel ) -> None :
110- """Test MQTT channel setup."""
111-
112- unsub = Mock ()
113- mqtt_session .subscribe .return_value = unsub
114-
115- callback = Mock ()
116- result = await mqtt_channel .subscribe (callback )
117-
118- assert mqtt_session .subscribe .called
119- assert mqtt_session .subscribe .call_args [0 ][0 ] == "rr/m/o/user123/username/abc123"
120-
121- unsub .assert_not_called ()
122- result ()
123- unsub .assert_called_once ()
124-
125-
126115async def test_publish_success (
127116 mqtt_session : Mock ,
128117 mqtt_channel : MqttChannel ,
@@ -148,13 +137,126 @@ async def test_publish_success(
148137
149138
150139async def test_message_decode_error (
140+ mqtt_channel : MqttChannel ,
151141 mqtt_message_handler : Callable [[bytes ], None ],
152142 caplog : pytest .LogCaptureFixture ,
153143) -> None :
154144 """Test an error during message decoding."""
145+ callback = Mock ()
146+ unsub = await mqtt_channel .subscribe (callback )
147+
155148 mqtt_message_handler (b"invalid_payload" )
156149 await asyncio .sleep (0.01 ) # yield
157150
158151 assert len (caplog .records ) == 1
159152 assert caplog .records [0 ].levelname == "WARNING"
160153 assert "Failed to decode MQTT message" in caplog .records [0 ].message
154+ unsub ()
155+
156+
157+ async def test_concurrent_subscribers (mqtt_session : Mock , mqtt_channel : MqttChannel ) -> None :
158+ """Test multiple concurrent subscribers receive all messages."""
159+ # Set up multiple subscribers
160+ subscriber1_messages : list [RoborockMessage ] = []
161+ subscriber2_messages : list [RoborockMessage ] = []
162+ subscriber3_messages : list [RoborockMessage ] = []
163+
164+ unsub1 = await mqtt_channel .subscribe (subscriber1_messages .append )
165+ unsub2 = await mqtt_channel .subscribe (subscriber2_messages .append )
166+ unsub3 = await mqtt_channel .subscribe (subscriber3_messages .append )
167+
168+ # Verify that each subscription creates a separate call to the MQTT session
169+ assert mqtt_session .subscribe .call_count == 3
170+
171+ # All subscriptions should be to the same topic
172+ for call in mqtt_session .subscribe .call_args_list :
173+ assert call [0 ][0 ] == "rr/m/o/user123/username/abc123"
174+
175+ # Get the message handlers for each subscriber
176+ handler1 = mqtt_session .subscribe .call_args_list [0 ][0 ][1 ]
177+ handler2 = mqtt_session .subscribe .call_args_list [1 ][0 ][1 ]
178+ handler3 = mqtt_session .subscribe .call_args_list [2 ][0 ][1 ]
179+
180+ # Simulate receiving messages - each handler should decode the message independently
181+ handler1 (ENCODER (TEST_REQUEST ))
182+ handler2 (ENCODER (TEST_REQUEST ))
183+ handler3 (ENCODER (TEST_REQUEST ))
184+ await asyncio .sleep (0.01 ) # yield
185+
186+ # All subscribers should receive the message
187+ assert len (subscriber1_messages ) == 1
188+ assert len (subscriber2_messages ) == 1
189+ assert len (subscriber3_messages ) == 1
190+ assert subscriber1_messages [0 ] == TEST_REQUEST
191+ assert subscriber2_messages [0 ] == TEST_REQUEST
192+ assert subscriber3_messages [0 ] == TEST_REQUEST
193+
194+ # Send another message to all handlers
195+ handler1 (ENCODER (TEST_RESPONSE ))
196+ handler2 (ENCODER (TEST_RESPONSE ))
197+ handler3 (ENCODER (TEST_RESPONSE ))
198+ await asyncio .sleep (0.01 ) # yield
199+
200+ # All subscribers should have received both messages
201+ assert len (subscriber1_messages ) == 2
202+ assert len (subscriber2_messages ) == 2
203+ assert len (subscriber3_messages ) == 2
204+ assert subscriber1_messages == [TEST_REQUEST , TEST_RESPONSE ]
205+ assert subscriber2_messages == [TEST_REQUEST , TEST_RESPONSE ]
206+ assert subscriber3_messages == [TEST_REQUEST , TEST_RESPONSE ]
207+
208+ # Test unsubscribing one subscriber
209+ unsub1 ()
210+
211+ # Send another message only to remaining handlers
212+ handler2 (ENCODER (TEST_REQUEST2 ))
213+ handler3 (ENCODER (TEST_REQUEST2 ))
214+ await asyncio .sleep (0.01 ) # yield
215+
216+ # First subscriber should not have received the new message
217+ assert len (subscriber1_messages ) == 2
218+ assert len (subscriber2_messages ) == 3
219+ assert len (subscriber3_messages ) == 3
220+ assert subscriber2_messages [2 ] == TEST_REQUEST2
221+ assert subscriber3_messages [2 ] == TEST_REQUEST2
222+
223+ # Unsubscribe remaining subscribers
224+ unsub2 ()
225+ unsub3 ()
226+
227+
228+ async def test_concurrent_subscribers_with_callback_exception (
229+ mqtt_session : Mock , mqtt_channel : MqttChannel , caplog : pytest .LogCaptureFixture
230+ ) -> None :
231+ """Test that exception in one subscriber callback doesn't affect others."""
232+ caplog .set_level (logging .ERROR )
233+
234+ def failing_callback (message : RoborockMessage ) -> None :
235+ raise ValueError ("Callback error" )
236+
237+ subscriber2_messages : list [RoborockMessage ] = []
238+
239+ unsub1 = await mqtt_channel .subscribe (failing_callback )
240+ unsub2 = await mqtt_channel .subscribe (subscriber2_messages .append )
241+
242+ # Get the message handlers
243+ handler1 = mqtt_session .subscribe .call_args_list [0 ][0 ][1 ]
244+ handler2 = mqtt_session .subscribe .call_args_list [1 ][0 ][1 ]
245+
246+ # Simulate receiving a message - first handler will raise exception
247+ handler1 (ENCODER (TEST_REQUEST ))
248+ handler2 (ENCODER (TEST_REQUEST ))
249+ await asyncio .sleep (0.01 ) # yield
250+
251+ # Exception should be logged but other subscribers should still work
252+ assert len (subscriber2_messages ) == 1
253+ assert subscriber2_messages [0 ] == TEST_REQUEST
254+
255+ # Check that exception was logged
256+ error_records = [record for record in caplog .records if record .levelname == "ERROR" ]
257+ assert len (error_records ) == 1
258+ assert "Uncaught error in message handler callback" in error_records [0 ].message
259+
260+ # Unsubscribe all remaining subscribers
261+ unsub1 ()
262+ unsub2 ()
0 commit comments