Skip to content

Commit 6308f30

Browse files
committed
feat: Add a local channel, similar to the MQTT channel
1 parent 509ff6a commit 6308f30

File tree

2 files changed

+469
-0
lines changed

2 files changed

+469
-0
lines changed

roborock/devices/local_channel.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Module for communicating with Roborock devices over a local network."""
2+
3+
import asyncio
4+
import logging
5+
from collections.abc import Callable
6+
from dataclasses import dataclass
7+
from json import JSONDecodeError
8+
9+
from roborock.exceptions import RoborockConnectionException, RoborockException
10+
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
11+
from roborock.roborock_message import RoborockMessage
12+
13+
_LOGGER = logging.getLogger(__name__)
14+
_PORT = 58867
15+
16+
17+
@dataclass
18+
class _LocalProtocol(asyncio.Protocol):
19+
"""Callbacks for the Roborock local client transport."""
20+
21+
messages_cb: Callable[[bytes], None]
22+
connection_lost_cb: Callable[[Exception | None], None]
23+
24+
def data_received(self, data: bytes) -> None:
25+
"""Called when data is received from the transport."""
26+
self.messages_cb(data)
27+
28+
def connection_lost(self, exc: Exception | None) -> None:
29+
"""Called when the transport connection is lost."""
30+
self.connection_lost_cb(exc)
31+
32+
33+
class LocalChannel:
34+
"""Simple RPC-style channel for communicating with a device over a local network.
35+
36+
Handles request/response correlation and timeouts, but leaves message
37+
format most parsing to higher-level components.
38+
"""
39+
40+
def __init__(self, host: str, local_key: str):
41+
self._host = host
42+
self._transport: asyncio.Transport | None = None
43+
self._protocol: _LocalProtocol | None = None
44+
self._subscribers: list[Callable[[RoborockMessage], None]] = []
45+
self._is_connected = False
46+
47+
# RPC support
48+
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
49+
self._decoder: Decoder = create_local_decoder(local_key)
50+
self._encoder: Encoder = create_local_encoder(local_key)
51+
self._queue_lock = asyncio.Lock()
52+
53+
async def connect(self) -> None:
54+
"""Connect to the device."""
55+
if self._is_connected:
56+
_LOGGER.warning("Already connected")
57+
return
58+
_LOGGER.debug("Connecting to %s:%s", self._host, _PORT)
59+
loop = asyncio.get_running_loop()
60+
protocol = _LocalProtocol(self._data_received, self._connection_lost)
61+
try:
62+
self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)
63+
self._is_connected = True
64+
except OSError as e:
65+
raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e
66+
67+
async def close(self) -> None:
68+
"""Disconnect from the device."""
69+
if self._transport:
70+
self._transport.close()
71+
self._is_connected = False
72+
73+
def _data_received(self, data: bytes) -> None:
74+
"""Handle incoming data from the transport."""
75+
if not (messages := self._decoder(data)):
76+
_LOGGER.warning("Failed to decode local message: %s", data)
77+
return
78+
for message in messages:
79+
_LOGGER.debug("Received message: %s", message)
80+
asyncio.create_task(self._resolve_future_with_lock(message))
81+
for callback in self._subscribers:
82+
try:
83+
callback(message)
84+
except Exception as e:
85+
_LOGGER.exception("Uncaught error in message handler callback: %s", e)
86+
87+
def _connection_lost(self, exc: Exception | None) -> None:
88+
"""Handle connection loss."""
89+
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
90+
self._transport = None
91+
self._is_connected = False
92+
93+
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
94+
"""Subscribe to all messages from the device."""
95+
self._subscribers.append(callback)
96+
97+
def unsubscribe() -> None:
98+
self._subscribers.remove(callback)
99+
100+
return unsubscribe
101+
102+
async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
103+
"""Resolve waiting future with proper locking."""
104+
if (request_id := message.get_request_id()) is None:
105+
_LOGGER.debug("Received message with no request_id")
106+
return
107+
async with self._queue_lock:
108+
if (future := self._waiting_queue.pop(request_id, None)) is not None:
109+
future.set_result(message)
110+
else:
111+
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
112+
113+
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
114+
"""Send a command message and wait for the response message."""
115+
if not self._transport or not self._is_connected:
116+
raise RoborockConnectionException("Not connected to device")
117+
118+
try:
119+
if (request_id := message.get_request_id()) is None:
120+
raise RoborockException("Message must have a request_id for RPC calls")
121+
except (ValueError, JSONDecodeError) as err:
122+
_LOGGER.exception("Error getting request_id from message: %s", err)
123+
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err
124+
125+
future: asyncio.Future[RoborockMessage] = asyncio.Future()
126+
async with self._queue_lock:
127+
if request_id in self._waiting_queue:
128+
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
129+
self._waiting_queue[request_id] = future
130+
131+
try:
132+
encoded_msg = self._encoder(message)
133+
self._transport.write(encoded_msg)
134+
return await asyncio.wait_for(future, timeout=timeout)
135+
except asyncio.TimeoutError as ex:
136+
async with self._queue_lock:
137+
self._waiting_queue.pop(request_id, None)
138+
raise RoborockException(f"Command timed out after {timeout}s") from ex
139+
except Exception:
140+
logging.exception("Uncaught error sending command")
141+
async with self._queue_lock:
142+
self._waiting_queue.pop(request_id, None)
143+
raise

0 commit comments

Comments
 (0)