Skip to content

Commit a802f66

Browse files
Lash-LCopilot
andauthored
feat: Add pinging to local client (#627)
* feat: Add pinging to local client * fix: reset keep_alive_task to None * chore: copilot test * Initial plan * feat: Add comprehensive test coverage for keep-alive functionality Co-authored-by: Lash-L <20257911+Lash-L@users.noreply.github.com> * refactor: Address code review feedback on keep-alive tests Co-authored-by: Lash-L <20257911+Lash-L@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Lash-L <20257911+Lash-L@users.noreply.github.com> --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Lash-L <20257911+Lash-L@users.noreply.github.com>
1 parent 799a5c4 commit a802f66

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

roborock/devices/local_channel.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_LOGGER = logging.getLogger(__name__)
1818
_PORT = 58867
1919
_TIMEOUT = 5.0
20+
_PING_INTERVAL = 10
2021

2122

2223
@dataclass
@@ -58,6 +59,7 @@ def __init__(self, host: str, local_key: str):
5859
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
5960
self._is_connected = False
6061
self._local_protocol_version: LocalProtocolVersion | None = None
62+
self._keep_alive_task: asyncio.Task[None] | None = None
6163
self._update_encoder_decoder(
6264
LocalChannelParams(local_key=local_key, connect_nonce=get_next_int(10000, 32767), ack_nonce=None)
6365
)
@@ -132,6 +134,28 @@ async def _hello(self):
132134

133135
raise RoborockException("Failed to connect to device with any known protocol")
134136

137+
async def _ping(self) -> None:
138+
ping_message = RoborockMessage(
139+
protocol=RoborockMessageProtocol.PING_REQUEST, version=self.protocol_version.encode()
140+
)
141+
await self._send_message(
142+
roborock_message=ping_message,
143+
request_id=ping_message.seq,
144+
response_protocol=RoborockMessageProtocol.PING_RESPONSE,
145+
)
146+
147+
async def _keep_alive_loop(self) -> None:
148+
while self._is_connected:
149+
try:
150+
await asyncio.sleep(_PING_INTERVAL)
151+
if self._is_connected:
152+
await self._ping()
153+
except asyncio.CancelledError:
154+
break
155+
except Exception:
156+
_LOGGER.debug("Keep-alive ping failed", exc_info=True)
157+
# Retry next interval
158+
135159
@property
136160
def protocol_version(self) -> LocalProtocolVersion:
137161
"""Return the negotiated local protocol version, or a sensible default."""
@@ -166,6 +190,7 @@ async def connect(self) -> None:
166190
# Perform protocol negotiation
167191
try:
168192
await self._hello()
193+
self._keep_alive_task = asyncio.create_task(self._keep_alive_loop())
169194
except RoborockException:
170195
# If protocol negotiation fails, clean up the connection state
171196
self.close()
@@ -177,6 +202,9 @@ def _data_received(self, data: bytes) -> None:
177202

178203
def close(self) -> None:
179204
"""Disconnect from the device."""
205+
if self._keep_alive_task:
206+
self._keep_alive_task.cancel()
207+
self._keep_alive_task = None
180208
if self._transport:
181209
self._transport.close()
182210
else:
@@ -187,6 +215,9 @@ def close(self) -> None:
187215
def _connection_lost(self, exc: Exception | None) -> None:
188216
"""Handle connection loss."""
189217
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
218+
if self._keep_alive_task:
219+
self._keep_alive_task.cancel()
220+
self._keep_alive_task = None
190221
self._transport = None
191222
self._is_connected = False
192223

tests/devices/test_local_channel.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,112 @@ async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalCh
366366
assert channel._params is not None
367367
assert channel._params.ack_nonce == 11111
368368
assert channel._is_connected is True
369+
370+
371+
async def test_keep_alive_task_created_on_connect(local_channel: LocalChannel, mock_loop: Mock) -> None:
372+
"""Test that _keep_alive_task is created when connect() is called."""
373+
# Before connecting, task should be None
374+
assert local_channel._keep_alive_task is None
375+
376+
await local_channel.connect()
377+
378+
# After connecting, task should be created and not done
379+
assert local_channel._keep_alive_task is not None
380+
assert isinstance(local_channel._keep_alive_task, asyncio.Task)
381+
assert not local_channel._keep_alive_task.done()
382+
383+
384+
async def test_keep_alive_task_canceled_on_close(local_channel: LocalChannel, mock_loop: Mock) -> None:
385+
"""Test that the keep-alive task is properly canceled when close() is called."""
386+
await local_channel.connect()
387+
388+
# Verify task exists
389+
task = local_channel._keep_alive_task
390+
assert task is not None
391+
assert not task.done()
392+
393+
# Close the connection
394+
local_channel.close()
395+
396+
# Give the task a moment to be cancelled
397+
await asyncio.sleep(0.01)
398+
399+
# Task should be canceled and reset to None
400+
assert task.cancelled() or task.done()
401+
assert local_channel._keep_alive_task is None
402+
403+
404+
async def test_keep_alive_task_canceled_on_connection_lost(local_channel: LocalChannel, mock_loop: Mock) -> None:
405+
"""Test that the keep-alive task is properly canceled when _connection_lost() is called."""
406+
await local_channel.connect()
407+
408+
# Verify task exists
409+
task = local_channel._keep_alive_task
410+
assert task is not None
411+
assert not task.done()
412+
413+
# Simulate connection loss
414+
local_channel._connection_lost(None)
415+
416+
# Give the task a moment to be cancelled
417+
await asyncio.sleep(0.01)
418+
419+
# Task should be canceled and reset to None
420+
assert task.cancelled() or task.done()
421+
assert local_channel._keep_alive_task is None
422+
423+
424+
async def test_keep_alive_ping_loop_executes_periodically(local_channel: LocalChannel, mock_loop: Mock) -> None:
425+
"""Test that the ping loop continues to execute periodically while connected."""
426+
await local_channel.connect()
427+
428+
# Verify the task is running and connected
429+
assert local_channel._keep_alive_task is not None
430+
assert not local_channel._keep_alive_task.done()
431+
assert local_channel._is_connected
432+
433+
434+
async def test_keep_alive_ping_exceptions_handled_gracefully(
435+
local_channel: LocalChannel, mock_loop: Mock, caplog: pytest.LogCaptureFixture
436+
) -> None:
437+
"""Test that exceptions in the ping loop are handled gracefully without stopping the loop."""
438+
from roborock.devices.local_channel import _PING_INTERVAL
439+
440+
# Set log level to capture DEBUG messages
441+
caplog.set_level("DEBUG")
442+
443+
ping_call_count = 0
444+
445+
# Mock the _ping method to always fail
446+
async def mock_ping() -> None:
447+
nonlocal ping_call_count
448+
ping_call_count += 1
449+
raise Exception("Test ping failure")
450+
451+
# Also need to mock asyncio.sleep to avoid waiting the full interval
452+
original_sleep = asyncio.sleep
453+
454+
async def mock_sleep(delay: float) -> None:
455+
# Only sleep briefly for test speed when waiting for ping interval
456+
if delay >= _PING_INTERVAL:
457+
await original_sleep(0.01)
458+
else:
459+
await original_sleep(delay)
460+
461+
with patch("asyncio.sleep", side_effect=mock_sleep):
462+
setattr(local_channel, "_ping", mock_ping)
463+
464+
await local_channel.connect()
465+
466+
# Wait for multiple ping attempts
467+
await original_sleep(0.1)
468+
469+
# Verify the task is still running despite the exception
470+
assert local_channel._keep_alive_task is not None
471+
assert not local_channel._keep_alive_task.done()
472+
473+
# Verify ping was called at least once
474+
assert ping_call_count >= 1
475+
476+
# Verify the exception was logged but didn't crash the loop
477+
assert any("Keep-alive ping failed" in record.message for record in caplog.records)

0 commit comments

Comments
 (0)