Skip to content

Commit 50dcf11

Browse files
committed
feat: add protocol updates
1 parent 3468b05 commit 50dcf11

File tree

4 files changed

+111
-9
lines changed

4 files changed

+111
-9
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ docs/_build/
1919

2020
# GitHub App credentials
2121
gha-creds-*.json
22+
23+
# pickle files
24+
*.p
25+
*.pickle

roborock/devices/device.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,22 @@
66

77
import asyncio
88
import datetime
9+
import json
910
import logging
1011
from abc import ABC
1112
from collections.abc import Callable
1213
from typing import Any
1314

1415
from roborock.callbacks import CallbackList
15-
from roborock.data import HomeDataDevice, HomeDataProduct
16+
from roborock.data import HomeDataDevice, HomeDataProduct, RoborockErrorCode, RoborockStateCode
1617
from roborock.diagnostics import redact_device_data
1718
from roborock.exceptions import RoborockException
18-
from roborock.roborock_message import RoborockMessage
19+
from roborock.roborock_message import (
20+
ROBOROCK_DATA_STATUS_PROTOCOL,
21+
RoborockDataProtocol,
22+
RoborockMessage,
23+
RoborockMessageProtocol,
24+
)
1925
from roborock.util import RoborockLoggerAdapter
2026

2127
from .traits import Trait
@@ -219,8 +225,77 @@ async def close(self) -> None:
219225
self._unsub = None
220226

221227
def _on_message(self, message: RoborockMessage) -> None:
222-
"""Handle incoming messages from the device."""
228+
"""Handle incoming messages from the device.
229+
230+
Note: Protocol updates (data points) are only sent via cloud/MQTT, not local connection.
231+
"""
223232
self._logger.debug("Received message from device: %s", message)
233+
if self.v1_properties is None:
234+
# Ensure we are only doing below logic for set-up V1 devices.
235+
return
236+
237+
# Only process messages that can contain protocol updates
238+
# RPC_RESPONSE (102), GENERAL_REQUEST (4), and GENERAL_RESPONSE (5)
239+
if message.protocol not in {
240+
RoborockMessageProtocol.RPC_RESPONSE,
241+
RoborockMessageProtocol.GENERAL_RESPONSE,
242+
}:
243+
return
244+
245+
if not message.payload:
246+
return
247+
248+
try:
249+
payload = json.loads(message.payload.decode())
250+
dps = payload.get("dps", {})
251+
252+
if not dps:
253+
return
254+
255+
# Process each data point in the message
256+
for data_point_number, data_point in dps.items():
257+
# Skip RPC responses (102) as they're handled by the RPC channel
258+
if data_point_number == "102":
259+
continue
260+
261+
try:
262+
data_protocol = RoborockDataProtocol(int(data_point_number))
263+
self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}")
264+
self._handle_protocol_update(data_protocol, data_point)
265+
except ValueError:
266+
# Unknown protocol number
267+
self._logger.debug(
268+
f"Got unknown data protocol {data_point_number}, data: {data_point}. "
269+
f"This may allow for faster updates in the future."
270+
)
271+
except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex:
272+
self._logger.debug(f"Failed to parse protocol message: {ex}")
273+
274+
def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None:
275+
"""Handle a protocol update for a specific data protocol.
276+
277+
Args:
278+
protocol: The data protocol number.
279+
data_point: The data value for this protocol.
280+
"""
281+
# Handle status protocol updates
282+
if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status:
283+
# Update the specific field in the status trait
284+
match protocol:
285+
case RoborockDataProtocol.ERROR_CODE:
286+
self.v1_properties.status.error_code = RoborockErrorCode(data_point)
287+
case RoborockDataProtocol.STATE:
288+
self.v1_properties.status.state = RoborockStateCode(data_point)
289+
case RoborockDataProtocol.BATTERY:
290+
self.v1_properties.status.battery = data_point
291+
case RoborockDataProtocol.CHARGE_STATUS:
292+
self.v1_properties.status.charge_status = data_point
293+
case _:
294+
# There is also fan power and water box mode, but for now those are skipped
295+
return
296+
297+
self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point)
298+
self.v1_properties.status.notify_update()
224299

225300
def diagnostic_data(self) -> dict[str, Any]:
226301
"""Return diagnostics information about the device."""

roborock/devices/rpc/v1_channel.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
305305
loop = asyncio.get_running_loop()
306306
self._reconnect_task = loop.create_task(self._background_reconnect())
307307

308-
if not self.is_local_connected:
309-
# We were not able to connect locally, so fallback to MQTT and at least
310-
# establish that connection explicitly. If this fails then raise an
311-
# error and let the caller know we failed to subscribe.
312-
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
313-
self._logger.debug("V1Channel connected to device via MQTT")
308+
# Always subscribe to MQTT to receive protocol updates (data points)
309+
# even if we have a local connection. Protocol updates only come via cloud/MQTT.
310+
# Local connection is used for RPC commands, but push notifications come via MQTT.
311+
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
312+
if self.is_local_connected:
313+
self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)")
314+
else:
315+
self._logger.debug("V1Channel connected via MQTT only")
314316

315317
def unsub() -> None:
316318
"""Unsubscribe from all messages."""

roborock/devices/traits/v1/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,23 @@
33
This is an internal library and should not be used directly by consumers.
44
"""
55

6+
from __future__ import annotations
7+
68
import logging
79
from abc import ABC, abstractmethod
10+
from collections.abc import Callable
811
from dataclasses import dataclass, fields
912
from typing import ClassVar, Self
1013

14+
from roborock.callbacks import CallbackList
1115
from roborock.data import RoborockBase
1216
from roborock.protocols.v1_protocol import V1RpcChannel
1317
from roborock.roborock_typing import RoborockCommand
1418

1519
_LOGGER = logging.getLogger(__name__)
1620

1721
V1ResponseData = dict | list | int | str
22+
V1TraitUpdateCallback = Callable[["V1TraitMixin"], None]
1823

1924

2025
@dataclass
@@ -74,6 +79,7 @@ def __post_init__(self) -> None:
7479
device setup code.
7580
"""
7681
self._rpc_channel = None
82+
self._update_callbacks: CallbackList[V1TraitMixin] = CallbackList()
7783

7884
@property
7985
def rpc_channel(self) -> V1RpcChannel:
@@ -97,6 +103,21 @@ def _update_trait_values(self, new_data: RoborockBase) -> None:
97103
new_value = getattr(new_data, field.name, None)
98104
setattr(self, field.name, new_value)
99105

106+
def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]:
107+
"""Add a callback to be notified when the trait is updated.
108+
109+
The callback will be called with the updated trait instance whenever
110+
a protocol message updates the trait.
111+
112+
Returns:
113+
A callable that can be used to remove the callback.
114+
"""
115+
return self._update_callbacks.add_callback(callback)
116+
117+
def notify_update(self) -> None:
118+
"""Notify all registered callbacks that the trait has been updated."""
119+
self._update_callbacks(self)
120+
100121

101122
def _get_value_field(clazz: type[V1TraitMixin]) -> str:
102123
"""Get the name of the field marked as the main value of the RoborockValueBase."""

0 commit comments

Comments
 (0)