Skip to content

Commit 8e9f1ba

Browse files
committed
feat: Add a DnD trait and fix bugs in the rpc channels
1 parent b227911 commit 8e9f1ba

File tree

11 files changed

+234
-24
lines changed

11 files changed

+234
-24
lines changed

roborock/containers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def from_dict(cls, data: dict[str, Any]):
134134
return None
135135
field_types = {field.name: field.type for field in dataclasses.fields(cls)}
136136
result: dict[str, Any] = {}
137-
for key, value in data.items():
138-
key = _decamelize(key)
137+
for orig_key, value in data.items():
138+
key = _decamelize(orig_key)
139139
if (field_type := field_types.get(key)) is None:
140140
continue
141141
if value == "None" or value is None:
@@ -178,16 +178,18 @@ class RoborockBaseTimer(RoborockBase):
178178
end_hour: int | None = None
179179
end_minute: int | None = None
180180
enabled: int | None = None
181-
start_time: datetime.time | None = None
182-
end_time: datetime.time | None = None
183181

184-
def __post_init__(self) -> None:
185-
self.start_time = (
182+
@property
183+
def start_time(self) -> datetime.time | None:
184+
return (
186185
datetime.time(hour=self.start_hour, minute=self.start_minute)
187186
if self.start_hour is not None and self.start_minute is not None
188187
else None
189188
)
190-
self.end_time = (
189+
190+
@property
191+
def end_time(self) -> datetime.time | None:
192+
return (
191193
datetime.time(hour=self.end_hour, minute=self.end_minute)
192194
if self.end_hour is not None and self.end_minute is not None
193195
else None

roborock/devices/device_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .channel import Channel
2323
from .mqtt_channel import create_mqtt_channel
2424
from .traits.b01.props import B01PropsApi
25+
from .traits.dnd import DoNotDisturbTrait
2526
from .traits.dyad import DyadApi
2627
from .traits.status import StatusTrait
2728
from .traits.trait import Trait
@@ -152,6 +153,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
152153
case DeviceVersion.V1:
153154
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
154155
traits.append(StatusTrait(product, channel.rpc_channel))
156+
traits.append(DoNotDisturbTrait(channel.rpc_channel))
155157
case DeviceVersion.A01:
156158
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
157159
match product.category:

roborock/devices/traits/dnd.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Module for Roborock V1 devices.
2+
3+
This interface is experimental and subject to breaking changes without notice
4+
until the API is stable.
5+
"""
6+
7+
import logging
8+
from collections.abc import Callable
9+
10+
from roborock.containers import DnDTimer
11+
from roborock.devices.v1_rpc_channel import V1RpcChannel
12+
from roborock.roborock_typing import RoborockCommand
13+
14+
from .trait import Trait
15+
16+
_LOGGER = logging.getLogger(__name__)
17+
18+
__all__ = [
19+
"DoNotDisturbTrait",
20+
]
21+
22+
23+
class DoNotDisturbTrait(Trait):
24+
"""Trait for managing Do Not Disturb (DND) settings on Roborock devices."""
25+
26+
name = "do_not_disturb"
27+
28+
def __init__(self, rpc_channel: Callable[[], V1RpcChannel]) -> None:
29+
"""Initialize the DoNotDisturbTrait."""
30+
self._rpc_channel = rpc_channel
31+
32+
async def get_dnd_timer(self) -> DnDTimer:
33+
"""Get the current Do Not Disturb (DND) timer settings of the device."""
34+
return await self._rpc_channel().send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)
35+
36+
async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None:
37+
"""Set the Do Not Disturb (DND) timer settings of the device."""
38+
await self._rpc_channel().send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict())
39+
40+
async def clear_dnd_timer(self) -> None:
41+
"""Clear the Do Not Disturb (DND) timer settings of the device."""
42+
await self._rpc_channel().send_command(RoborockCommand.CLOSE_DND_TIMER)

roborock/devices/traits/status.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,32 @@
55
"""
66

77
import logging
8+
from collections.abc import Callable
89

910
from roborock.containers import (
1011
HomeDataProduct,
1112
ModelStatus,
1213
S7MaxVStatus,
1314
Status,
1415
)
16+
from roborock.devices.v1_rpc_channel import V1RpcChannel
1517
from roborock.roborock_typing import RoborockCommand
1618

17-
from ..v1_rpc_channel import V1RpcChannel
1819
from .trait import Trait
1920

2021
_LOGGER = logging.getLogger(__name__)
2122

2223
__all__ = [
23-
"Status",
24+
"StatusTrait",
2425
]
2526

2627

2728
class StatusTrait(Trait):
28-
"""Unified Roborock device class with automatic connection setup."""
29+
"""Trait for managing the status of Roborock devices."""
2930

3031
name = "status"
3132

32-
def __init__(self, product_info: HomeDataProduct, rpc_channel: V1RpcChannel) -> None:
33+
def __init__(self, product_info: HomeDataProduct, rpc_channel: Callable[[], V1RpcChannel]) -> None:
3334
"""Initialize the StatusTrait."""
3435
self._product_info = product_info
3536
self._rpc_channel = rpc_channel
@@ -40,4 +41,4 @@ async def get_status(self) -> Status:
4041
This is a placeholder command and will likely be changed/moved in the future.
4142
"""
4243
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
43-
return await self._rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)
44+
return await self._rpc_channel().send_command(RoborockCommand.GET_STATUS, response_type=status_type)

roborock/devices/v1_channel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ def is_mqtt_connected(self) -> bool:
8181
"""Return whether MQTT connection is available."""
8282
return self._mqtt_unsub is not None and self._mqtt_channel.is_connected
8383

84-
@property
8584
def rpc_channel(self) -> V1RpcChannel:
86-
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
85+
"""Return the combined RPC channel prefers local with a fallback to MQTT.
86+
87+
This is dynamic based on the current connection status. That is, it may return
88+
a different channel depending on whether local or MQTT is available.
89+
"""
8790
return self._combined_rpc_channel or self._mqtt_rpc_channel
8891

8992
@property

roborock/devices/v1_rpc_channel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,21 @@ async def _send_raw_command(
132132
params: ParamsType = None,
133133
) -> Any:
134134
"""Send a command and return a parsed response RoborockBase type."""
135-
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
136135
request_message = RequestMessage(method, params=params)
136+
_LOGGER.debug(
137+
"Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params
138+
)
137139
message = self._payload_encoder(request_message)
138140

139141
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
140142

141143
def find_response(response_message: RoborockMessage) -> None:
142144
try:
143145
decoded = decode_rpc_response(response_message)
144-
except RoborockException:
146+
except RoborockException as ex:
147+
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
145148
return
149+
_LOGGER.debug("Received response (request_id=%s): %s", self._name, decoded.request_id)
146150
if decoded.request_id == request_message.request_id:
147151
future.set_result(decoded.data)
148152

roborock/protocols/v1_protocol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class ResponseMessage:
109109
def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
110110
"""Decode a V1 RPC_RESPONSE message."""
111111
if not message.payload:
112-
raise RoborockException("Invalid V1 message format: missing payload")
112+
return ResponseMessage(request_id=message.seq, data={})
113113
try:
114114
payload = json.loads(message.payload.decode())
115115
except (json.JSONDecodeError, TypeError) as e:
@@ -141,6 +141,8 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
141141
_LOGGER.debug("Decoded V1 message result: %s", result)
142142
if isinstance(result, list) and result:
143143
result = result[0]
144+
if isinstance(result, str) and result == "ok":
145+
result = {}
144146
if not isinstance(result, dict):
145147
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
146148
return ResponseMessage(request_id=request_id, data=result)

tests/devices/test_v1_channel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ async def test_v1_channel_send_command_local_preferred(
254254

255255
# Send command
256256
mock_local_channel.response_queue.append(TEST_RESPONSE)
257-
result = await v1_channel.rpc_channel.send_command(
257+
result = await v1_channel.rpc_channel().send_command(
258258
RoborockCommand.CHANGE_SOUND_VOLUME,
259259
response_type=S5MaxStatus,
260260
)
@@ -280,7 +280,7 @@ async def test_v1_channel_send_command_local_fails(
280280

281281
# Send command
282282
with pytest.raises(RoborockException, match="Local failed"):
283-
await v1_channel.rpc_channel.send_command(
283+
await v1_channel.rpc_channel().send_command(
284284
RoborockCommand.CHANGE_SOUND_VOLUME,
285285
response_type=S5MaxStatus,
286286
)
@@ -300,7 +300,7 @@ async def test_v1_channel_send_decoded_command_mqtt_only(
300300

301301
# Send command
302302
mock_mqtt_channel.response_queue.append(TEST_RESPONSE)
303-
result = await v1_channel.rpc_channel.send_command(
303+
result = await v1_channel.rpc_channel().send_command(
304304
RoborockCommand.CHANGE_SOUND_VOLUME,
305305
response_type=S5MaxStatus,
306306
)
@@ -322,7 +322,7 @@ async def test_v1_channel_send_decoded_command_with_params(
322322
# Send command with params
323323
mock_local_channel.response_queue.append(TEST_RESPONSE)
324324
test_params = {"volume": 80}
325-
await v1_channel.rpc_channel.send_command(
325+
await v1_channel.rpc_channel().send_command(
326326
RoborockCommand.CHANGE_SOUND_VOLUME,
327327
response_type=S5MaxStatus,
328328
params=test_params,
@@ -444,7 +444,7 @@ async def test_v1_channel_command_encoding_validation(
444444

445445
# Send local command and capture the request
446446
mock_local_channel.response_queue.append(TEST_RESPONSE_2)
447-
await v1_channel.rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
447+
await v1_channel.rpc_channel().send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50})
448448
assert mock_local_channel.published_messages
449449
local_message = mock_local_channel.published_messages[0]
450450

@@ -512,7 +512,7 @@ async def test_v1_channel_full_subscribe_and_command_flow(
512512

513513
# Send a command (should use local)
514514
mock_local_channel.response_queue.append(TEST_RESPONSE)
515-
result = await v1_channel.rpc_channel.send_command(
515+
result = await v1_channel.rpc_channel().send_command(
516516
RoborockCommand.GET_STATUS,
517517
response_type=S5MaxStatus,
518518
)

tests/devices/test_v1_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def traits_fixture(rpc_channel: AsyncMock) -> list[Trait]:
4444
return [
4545
StatusTrait(
4646
product_info=HOME_DATA.products[0],
47-
rpc_channel=rpc_channel,
47+
rpc_channel=lambda: rpc_channel,
4848
)
4949
]
5050

tests/devices/traits/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for device traits."""

0 commit comments

Comments
 (0)