Skip to content

Commit ad3afc0

Browse files
fix: mapping prefix for all known commands
1 parent 17e72c3 commit ad3afc0

File tree

5 files changed

+394
-189
lines changed

5 files changed

+394
-189
lines changed

roborock/api.py

Lines changed: 28 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import asyncio
66
import base64
7-
import binascii
87
import gzip
98
import hashlib
109
import hmac
@@ -14,11 +13,12 @@
1413
import secrets
1514
import struct
1615
import time
16+
from random import randint
1717
from typing import Any, Callable
1818

1919
import aiohttp
2020
from Crypto.Cipher import AES
21-
from Crypto.Util.Padding import pad, unpad
21+
from Crypto.Util.Padding import unpad
2222

2323
from roborock.exceptions import (
2424
RoborockException, RoborockTimeout, VacuumError,
@@ -40,6 +40,7 @@
4040
DustCollectionMode,
4141

4242
)
43+
from .roborock_message import RoborockMessage
4344
from .roborock_queue import RoborockQueue
4445
from .typing import (
4546
RoborockDeviceProp,
@@ -61,34 +62,23 @@ def md5hex(message: str) -> str:
6162
return md5.hexdigest()
6263

6364

64-
def md5bin(message: str) -> bytes:
65-
md5 = hashlib.md5()
66-
md5.update(message.encode())
67-
return md5.digest()
68-
69-
70-
def encode_timestamp(_timestamp: int) -> str:
71-
hex_value = f"{_timestamp:x}".zfill(8)
72-
return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4])))
73-
74-
7565
class PreparedRequest:
7666
def __init__(self, base_url: str, base_headers: dict = None) -> None:
7767
self.base_url = base_url
7868
self.base_headers = base_headers or {}
7969

8070
async def request(
81-
self, method: str, url: str, params=None, data=None, headers=None
71+
self, method: str, url: str, params=None, data=None, headers=None
8272
) -> dict | list:
8373
_url = "/".join(s.strip("/") for s in [self.base_url, url])
8474
_headers = {**self.base_headers, **(headers or {})}
8575
async with aiohttp.ClientSession() as session:
8676
async with session.request(
87-
method,
88-
_url,
89-
params=params,
90-
data=data,
91-
headers=_headers,
77+
method,
78+
_url,
79+
params=params,
80+
data=data,
81+
headers=_headers,
9282
) as resp:
9383
return await resp.json()
9484

@@ -97,99 +87,24 @@ class RoborockClient:
9787

9888
def __init__(self, endpoint: str, devices_info: dict[str, RoborockDeviceInfo]) -> None:
9989
self.devices_info = devices_info
100-
self._seq = 1
101-
self._random = 4711
102-
self._id_counter = 10000
10390
self._salt = "TXdfu$jyZ#TZHsg4"
10491
self._endpoint = endpoint
10592
self._nonce = secrets.token_bytes(16)
10693
self._waiting_queue: dict[int, RoborockQueue] = {}
107-
self._status_listeners: list[Callable[[str, str], None]] = []
94+
self._status_listeners: list[Callable[[int, str], None]] = []
10895

109-
def add_status_listener(self, callback: Callable[[str, str], None]):
96+
def add_status_listener(self, callback: Callable[[int, str], None]):
11097
self._status_listeners.append(callback)
11198

11299
async def async_disconnect(self) -> Any:
113100
raise NotImplementedError
114101

115-
def _decode_msg(self, msg: bytes, local_key: str) -> list[dict[str, Any]]:
116-
prefix = None
117-
if msg[4:7] == "1.0".encode():
118-
prefix = int.from_bytes(msg[:4], 'big')
119-
msg = msg[4:]
120-
elif msg[0:3] != "1.0".encode():
121-
raise RoborockException(f"Unknown protocol version {msg[0:3]}")
122-
if len(msg) in [17, 21, 25]:
123-
[version, request_id, random, timestamp, protocol] = struct.unpack(
124-
"!3sIIIH", msg[0:17]
125-
)
126-
return [{
127-
"prefix": prefix,
128-
"version": version,
129-
"request_id": request_id,
130-
"random": random,
131-
"timestamp": timestamp,
132-
"protocol": protocol,
133-
}]
134-
index = 0
135-
[version, request_id, random, timestamp, protocol, payload_len] = struct.unpack(
136-
"!3sIIIHH", msg[index:index + 19]
137-
)
138-
[payload, expected_crc32] = struct.unpack_from(f"!{payload_len}sI", msg, index + 19)
139-
if payload_len == 0:
140-
index += 21
141-
else:
142-
crc32 = binascii.crc32(msg[index: index + 19 + payload_len])
143-
index += 23 + payload_len
144-
if crc32 != expected_crc32:
145-
raise RoborockException(f"Wrong CRC32 {crc32}, expected {expected_crc32}")
146-
decrypted_payload = None
147-
if payload:
148-
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
149-
decipher = AES.new(aes_key, AES.MODE_ECB)
150-
decrypted_payload = unpad(decipher.decrypt(payload), AES.block_size)
151-
return [{
152-
"prefix": prefix,
153-
"version": version,
154-
"request_id": request_id,
155-
"random": random,
156-
"timestamp": timestamp,
157-
"protocol": protocol,
158-
"payload": decrypted_payload
159-
}] + (self._decode_msg(msg[index:], local_key) if index < len(msg) else [])
160-
161-
def _encode_msg(self, device_id, request_id, protocol, timestamp, payload, prefix=None) -> bytes:
162-
local_key = self.devices_info[device_id].device.local_key
163-
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
164-
cipher = AES.new(aes_key, AES.MODE_ECB)
165-
encrypted = cipher.encrypt(pad(payload, AES.block_size))
166-
encrypted_len = len(encrypted)
167-
values = [
168-
"1.0".encode(),
169-
request_id,
170-
self._random,
171-
timestamp,
172-
protocol,
173-
encrypted_len,
174-
encrypted
175-
]
176-
if prefix:
177-
values = [prefix] + values
178-
msg = struct.pack(
179-
f"!{'I' if prefix else ''}3sIIIHH{encrypted_len}s",
180-
*values
181-
)
182-
crc32 = binascii.crc32(msg[4:] if prefix else msg)
183-
msg += struct.pack("!I", crc32)
184-
return msg
185-
186-
async def on_message(self, device_id, msg) -> None:
102+
async def on_message(self, messages: list[RoborockMessage]) -> None:
187103
try:
188-
messages = self._decode_msg(msg, self.devices_info[device_id].device.local_key)
189104
for data in messages:
190-
protocol = data.get("protocol")
105+
protocol = data.protocol
191106
if protocol == 102 or protocol == 4:
192-
payload = json.loads(data.get("payload").decode())
107+
payload = json.loads(data.payload.decode())
193108
for data_point_number, data_point in payload.get("dps").items():
194109
if data_point_number == "102":
195110
data_point_response = json.loads(data_point)
@@ -215,45 +130,30 @@ async def on_message(self, device_id, msg) -> None:
215130
await queue.async_put(
216131
(result, None), timeout=QUEUE_TIMEOUT
217132
)
218-
elif request_id < self._id_counter:
219-
_LOGGER.debug(
220-
f"id={request_id} Ignoring response: {data_point_response}"
221-
)
222133
elif data_point_number == "121":
223134
status = STATE_CODE_TO_STATUS.get(data_point)
224135
_LOGGER.debug(f"Status updated to {status}")
225136
for listener in self._status_listeners:
226-
listener(device_id, status)
137+
listener(data.seq, status)
227138
else:
228139
_LOGGER.debug(
229140
f"Unknown data point number received {data_point_number} with {data_point}"
230141
)
231142
elif protocol == 301:
232-
payload = data.get("payload")[0:24]
143+
payload = data.payload[0:24]
233144
[endpoint, _, request_id, _] = struct.unpack("<15sBH6s", payload)
234145
if endpoint.decode().startswith(self._endpoint):
235146
iv = bytes(AES.block_size)
236147
decipher = AES.new(self._nonce, AES.MODE_CBC, iv)
237148
decrypted = unpad(
238-
decipher.decrypt(data.get("payload")[24:]), AES.block_size
149+
decipher.decrypt(data.payload[24:]), AES.block_size
239150
)
240151
decrypted = gzip.decompress(decrypted)
241152
queue = self._waiting_queue.get(request_id)
242153
if queue:
243154
if isinstance(decrypted, list):
244155
decrypted = decrypted[0]
245156
await queue.async_put((decrypted, None), timeout=QUEUE_TIMEOUT)
246-
elif data.get('request_id'):
247-
request_id = data.get('request_id')
248-
queue = self._waiting_queue.get(request_id)
249-
if queue:
250-
protocol = data.get("protocol")
251-
if queue.protocol == protocol:
252-
await queue.async_put((None, None), timeout=QUEUE_TIMEOUT)
253-
elif request_id < self._id_counter and protocol != 5:
254-
_LOGGER.debug(
255-
f"id={request_id} Ignoring response: {data}"
256-
)
257157
except Exception as ex:
258158
_LOGGER.exception(ex)
259159

@@ -271,11 +171,10 @@ async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[
271171
del self._waiting_queue[request_id]
272172

273173
def _get_payload(
274-
self, method: RoborockCommand, params: list = None, secured=False
174+
self, method: RoborockCommand, params: list = None, secured=False
275175
):
276176
timestamp = math.floor(time.time())
277-
request_id = self._id_counter
278-
self._id_counter += 1
177+
request_id = randint(10000, 99999)
279178
inner = {
280179
"id": request_id,
281180
"method": method,
@@ -298,7 +197,7 @@ def _get_payload(
298197
return request_id, timestamp, payload
299198

300199
async def send_command(
301-
self, device_id: str, method: RoborockCommand, params: list = None
200+
self, device_id: str, method: RoborockCommand, params: list = None
302201
):
303202
raise NotImplementedError
304203

@@ -374,7 +273,14 @@ async def get_dock_summary(self, device_id: str, dock_type: RoborockDockType) ->
374273
commands = [self.get_dust_collection_mode(device_id)]
375274
if dock_type == RoborockDockType.EMPTY_WASH_FILL_DOCK:
376275
commands += [self.get_wash_towel_mode(device_id), self.get_smart_wash_params(device_id)]
377-
[dust_collection_mode, wash_towel_mode, smart_wash_params] = (list(await asyncio.gather(*commands)) + [None, None])[:3]
276+
[
277+
dust_collection_mode,
278+
wash_towel_mode,
279+
smart_wash_params
280+
] = (
281+
list(await asyncio.gather(*commands))
282+
+ [None, None]
283+
)[:3]
378284

379285
return RoborockDockSummary(dust_collection_mode, wash_towel_mode, smart_wash_params)
380286
except RoborockTimeout as e:

roborock/cloud_api.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
import paho.mqtt.client as mqtt
1212

13-
from roborock.api import md5hex, md5bin, RoborockClient, SPECIAL_COMMANDS
14-
from roborock.code_mappings import RoborockDockType
13+
from roborock.api import md5hex, RoborockClient, SPECIAL_COMMANDS
1514
from roborock.exceptions import (
1615
RoborockException,
1716
CommandVacuumError,
@@ -22,9 +21,8 @@
2221
RoborockDeviceInfo,
2322
)
2423
from .roborock_queue import RoborockQueue
25-
from .typing import (
26-
RoborockCommand, RoborockDeviceProp,
27-
)
24+
from .roborock_message import RoborockParser, md5bin, RoborockMessage
25+
from .typing import RoborockCommand
2826
from .util import run_in_executor
2927

3028
_LOGGER = logging.getLogger(__name__)
@@ -96,7 +94,8 @@ async def on_message(self, _client, _, msg, __=None) -> None:
9694
async with self._mutex:
9795
self._last_device_msg_in = mqtt.time_func()
9896
device_id = msg.topic.split("/").pop()
99-
await super().on_message(device_id, msg.payload)
97+
messages, _ = RoborockParser.decode(msg.payload, self.devices_info[device_id].device.local_key)
98+
await super().on_message(messages)
10099

101100
@run_in_executor()
102101
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
@@ -193,7 +192,13 @@ async def send_command(
193192
_LOGGER.debug(f"id={request_id} Requesting method {method} with {params}")
194193
request_protocol = 101
195194
response_protocol = 301 if method in SPECIAL_COMMANDS else 102
196-
msg = super()._encode_msg(device_id, request_id, request_protocol, timestamp, payload)
195+
roborock_message = RoborockMessage(
196+
timestamp=timestamp,
197+
protocol=request_protocol,
198+
payload=payload
199+
)
200+
local_key = self.devices_info[device_id].device.local_key
201+
msg = RoborockParser.encode(roborock_message, local_key)
197202
self._send_msg_raw(device_id, msg)
198203
(response, err) = await self._async_response(request_id, response_protocol)
199204
if err:

0 commit comments

Comments
 (0)