Skip to content

Commit f480b51

Browse files
committed
chore: some misc changes
1 parent 518ffca commit f480b51

File tree

5 files changed

+140
-34
lines changed

5 files changed

+140
-34
lines changed

roborock/device_trait.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import datetime
2+
from abc import ABC, abstractmethod
3+
from collections.abc import Awaitable, Callable
4+
from dataclasses import dataclass
5+
6+
from . import RoborockCommand
7+
from .containers import Consumable, DeviceFeatures, DnDTimer, RoborockBase
8+
9+
10+
@dataclass
11+
class DeviceTrait(ABC):
12+
handle_command: RoborockCommand
13+
_status_type: type[RoborockBase] = RoborockBase
14+
15+
def __init__(self, send_command: Callable[..., Awaitable[None]]):
16+
self.send_command = send_command
17+
self.status: RoborockBase | None = None
18+
self.subscriptions = []
19+
20+
@classmethod
21+
@abstractmethod
22+
def supported(cls, features: DeviceFeatures) -> bool:
23+
raise NotImplementedError
24+
25+
def on_message(self, data: dict) -> None:
26+
self.status = self._status_type.from_dict(data)
27+
for callback in self.subscriptions:
28+
callback(self.status)
29+
30+
def subscribe(self, callable: Callable):
31+
# Maybe needs to handle async too?
32+
self.subscriptions.append(callable)
33+
34+
@abstractmethod
35+
def get(self):
36+
raise NotImplementedError
37+
38+
39+
class DndTrait(DeviceTrait):
40+
handle_command: RoborockCommand = RoborockCommand.GET_DND_TIMER
41+
_status_type: type[DnDTimer] = DnDTimer
42+
status: DnDTimer
43+
44+
def __init__(self, send_command: Callable[..., Awaitable[None]]):
45+
super().__init__(send_command)
46+
47+
@classmethod
48+
def supported(cls, features: DeviceFeatures) -> bool:
49+
return features.is_support_custom_dnd
50+
51+
async def update_dnd(self, enabled: bool, start_time: datetime.time, end_time: datetime.time) -> None:
52+
if self.status.enabled and not enabled:
53+
await self.send_command(RoborockCommand.CLOSE_DND_TIMER)
54+
else:
55+
start = start_time if start_time is not None else self.status.start_time
56+
end = end_time if end_time is not None else self.status.end_time
57+
await self.send_command(RoborockCommand.SET_DND_TIMER, [start.hour, start.minute, end.hour, end.minute])
58+
59+
async def get(self) -> None:
60+
await self.send_command(RoborockCommand.GET_DND_TIMER)
61+
62+
63+
class ConsumableTrait(DeviceTrait):
64+
handle_command = RoborockCommand.GET_CONSUMABLE
65+
_status_type: type[Consumable] = DnDTimer
66+
status: Consumable
67+
68+
def __init__(self, send_command: Callable[..., Awaitable[None]]):
69+
super().__init__(send_command)
70+
71+
@classmethod
72+
def supported(cls, features: DeviceFeatures) -> bool:
73+
return True
74+
75+
async def reset_consumable(self, consumable: str) -> None:
76+
await self.send_command(RoborockCommand.RESET_CONSUMABLE, [consumable])
77+
78+
async def get(self) -> None:
79+
await self.send_command(RoborockCommand.GET_CONSUMABLE)

roborock/device_traits/__init__.py

Whitespace-only changes.

roborock/mqtt_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ async def unsubscribe(device: DeviceData):
8989
device_id = message.topic.value.split("/")[-1]
9090
device = device_map[device_id]
9191
message = MessageParser.parse(message.payload, device.device.local_key)
92-
callbacks[device_id](message)
92+
for m in message[0]:
93+
callbacks[device_id](m)
9394
except Exception:
9495
...
9596

roborock/roborock_device.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import time
77

88
from . import RoborockCommand
9-
from .containers import DeviceData, UserData
9+
from .containers import DeviceData, ModelStatus, S7MaxVStatus, Status, UserData
10+
from .device_trait import ConsumableTrait, DeviceTrait, DndTrait
1011
from .mqtt_manager import RoborockMqttManager
1112
from .protocol import MessageParser, Utils
1213
from .roborock_message import RoborockMessage, RoborockMessageProtocol
@@ -25,35 +26,34 @@ def __init__(self, user_data: UserData, device_info: DeviceData):
2526
self._local_endpoint = "abc"
2627
self._nonce = secrets.token_bytes(16)
2728
self.manager = RoborockMqttManager()
28-
self.update_commands = self.determine_supported_commands()
29-
30-
def determine_supported_commands(self):
31-
# All devices support these
32-
supported_commands = {
33-
RoborockCommand.GET_CONSUMABLE,
34-
RoborockCommand.GET_STATUS,
35-
RoborockCommand.GET_CLEAN_SUMMARY,
36-
}
37-
# Get what features we can from the feature_set info.
38-
39-
# If a command is not described in feature_set, we should just add it anyways and then let it fail on the first call and remove it.
40-
robot_new_features = int(self.device_info.device.feature_set)
41-
new_feature_info_str = self.device_info.device.new_feature_set
42-
if 33554432 & int(robot_new_features):
43-
supported_commands.add(RoborockCommand.GET_DUST_COLLECTION_MODE)
44-
if 2 & int(new_feature_info_str[-8:], 16):
45-
# TODO: May not be needed as i think this can just be found in Status, but just POC
46-
supported_commands.add(RoborockCommand.APP_GET_CLEAN_ESTIMATE_INFO)
47-
return supported_commands
29+
self._message_id_types: dict[int, DeviceTrait] = {}
30+
self._command_to_trait = {}
31+
self._all_supported_traits = []
32+
self._dnd_trait: DndTrait | None = self.determine_supported_traits(DndTrait)
33+
self._consumable_trait: ConsumableTrait | None = self.determine_supported_traits(ConsumableTrait)
34+
self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus)
35+
36+
def determine_supported_traits(self, trait: type[DeviceTrait]):
37+
def _send_command(
38+
method: RoborockCommand | str, params: list | dict | int | None = None, use_cloud: bool = True
39+
):
40+
return self.send_message(method, params, use_cloud)
41+
42+
if trait.supported(self.device_info.device_features):
43+
trait_instance = trait(_send_command)
44+
self._all_supported_traits.append(trait(_send_command))
45+
self._command_to_trait[trait.handle_command] = trait_instance
46+
return trait_instance
47+
return None
4848

4949
async def connect(self):
5050
"""Connect via MQTT and Local if possible."""
5151
await self.manager.subscribe(self.user_data, self.device_info, self.on_message)
5252
await self.update()
5353

5454
async def update(self):
55-
for cmd in self.update_commands:
56-
await self.send_message(method=cmd)
55+
for trait in self._all_supported_traits:
56+
await trait.get()
5757

5858
def _get_payload(
5959
self,
@@ -91,7 +91,9 @@ async def send_message(
9191
request_id, timestamp, payload = self._get_payload(method, params, True, use_cloud)
9292
request_protocol = RoborockMessageProtocol.RPC_REQUEST
9393
roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload)
94-
94+
if request_id in self._message_id_types:
95+
raise Exception("Duplicate id!")
96+
self._message_id_types[request_id] = self._command_to_trait[method]
9597
local_key = self.device_info.device.local_key
9698
msg = MessageParser.build(roborock_message, local_key, False)
9799
if use_cloud:
@@ -101,6 +103,19 @@ async def send_message(
101103
pass
102104

103105
def on_message(self, message: RoborockMessage):
106+
message_payload = message.get_payload()
107+
message_id = message.get_request_id()
108+
for data_point_number, data_point in message_payload.get("dps").items():
109+
if data_point_number == "102":
110+
data_point_response = json.loads(data_point)
111+
result = data_point_response.get("result")
112+
if isinstance(result, list) and len(result) == 1:
113+
result = result[0]
114+
if result and (trait := self._message_id_types.get(message_id)) is not None:
115+
trait.on_message(result)
116+
if (error := result.get("error")) is not None:
117+
print(error)
118+
print()
104119
# If message is command not supported - remove from self.update_commands
105120

106121
# If message is an error - log it?
@@ -115,3 +130,11 @@ def on_message(self, message: RoborockMessage):
115130

116131
# This should also probably be split with on_cloud_message and on_local_message.
117132
print(message)
133+
134+
@property
135+
def dnd(self) -> DndTrait | None:
136+
return self._dnd_trait
137+
138+
@property
139+
def consumable(self) -> ConsumableTrait | None:
140+
return self._consumable_trait

roborock/roborock_message.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,16 @@ class RoborockMessage:
161161
random: int = field(default_factory=lambda: get_next_int(10000, 99999))
162162
timestamp: int = field(default_factory=lambda: math.floor(time.time()))
163163
message_retry: MessageRetry | None = None
164+
_parsed_payload: dict | None = None
165+
166+
def get_payload(self) -> dict | None:
167+
if self.payload and not self._parsed_payload:
168+
self._parsed_payload = json.loads(self.payload.decode())
169+
return self._parsed_payload
164170

165171
def get_request_id(self) -> int | None:
166-
if self.payload:
167-
payload = json.loads(self.payload.decode())
168-
for data_point_number, data_point in payload.get("dps").items():
172+
if self._parsed_payload:
173+
for data_point_number, data_point in self._parsed_payload.get("dps").items():
169174
if data_point_number in ["101", "102"]:
170175
data_point_response = json.loads(data_point)
171176
return data_point_response.get("id")
@@ -180,19 +185,17 @@ def get_method(self) -> str | None:
180185
if self.message_retry:
181186
return self.message_retry.method
182187
protocol = self.protocol
183-
if self.payload and protocol in [4, 5, 101, 102]:
184-
payload = json.loads(self.payload.decode())
185-
for data_point_number, data_point in payload.get("dps").items():
188+
if self._parsed_payload and protocol in [4, 5, 101, 102]:
189+
for data_point_number, data_point in self._parsed_payload.get("dps").items():
186190
if data_point_number in ["101", "102"]:
187191
data_point_response = json.loads(data_point)
188192
return data_point_response.get("method")
189193
return None
190194

191195
def get_params(self) -> list | dict | None:
192196
protocol = self.protocol
193-
if self.payload and protocol in [4, 101, 102]:
194-
payload = json.loads(self.payload.decode())
195-
for data_point_number, data_point in payload.get("dps").items():
197+
if self._parsed_payload and protocol in [4, 101, 102]:
198+
for data_point_number, data_point in self._parsed_payload.get("dps").items():
196199
if data_point_number in ["101", "102"]:
197200
data_point_response = json.loads(data_point)
198201
return data_point_response.get("params")

0 commit comments

Comments
 (0)