Skip to content

Commit 7705cda

Browse files
committed
chore: move broadcast_protocol to t's own file
1 parent d966b84 commit 7705cda

File tree

2 files changed

+72
-51
lines changed

2 files changed

+72
-51
lines changed

roborock/broadcast_protocol.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import json
5+
import logging
6+
from asyncio import BaseTransport, Lock
7+
8+
from construct import ( # type: ignore
9+
Bytes,
10+
Checksum,
11+
Int16ub,
12+
Int32ub,
13+
RawCopy,
14+
Struct,
15+
)
16+
17+
from roborock.containers import BroadcastMessage
18+
from roborock.protocol import EncryptionAdapter, Utils, _Parser
19+
20+
_LOGGER = logging.getLogger(__name__)
21+
22+
BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"
23+
24+
25+
class RoborockProtocol(asyncio.DatagramProtocol):
26+
def __init__(self, timeout: int = 5):
27+
self.timeout = timeout
28+
self.transport: BaseTransport | None = None
29+
self.devices_found: list[BroadcastMessage] = []
30+
self._mutex = Lock()
31+
32+
def __del__(self):
33+
self.close()
34+
35+
def datagram_received(self, data, _):
36+
[broadcast_message], _ = BroadcastParser.parse(data)
37+
if broadcast_message.payload:
38+
parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload))
39+
_LOGGER.debug(f"Received broadcast: {parsed_message}")
40+
self.devices_found.append(parsed_message)
41+
42+
async def discover(self):
43+
async with self._mutex:
44+
try:
45+
loop = asyncio.get_event_loop()
46+
self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
47+
await asyncio.sleep(self.timeout)
48+
return self.devices_found
49+
finally:
50+
self.close()
51+
self.devices_found = []
52+
53+
def close(self):
54+
self.transport.close() if self.transport else None
55+
56+
57+
_BroadcastMessage = Struct(
58+
"message"
59+
/ RawCopy(
60+
Struct(
61+
"version" / Bytes(3),
62+
"seq" / Int32ub,
63+
"protocol" / Int16ub,
64+
"payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN),
65+
)
66+
),
67+
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
68+
)
69+
70+
71+
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)

roborock/protocol.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import binascii
54
import gzip
65
import hashlib
7-
import json
86
import logging
9-
from asyncio import BaseTransport, Lock
107
from collections.abc import Callable
118
from urllib.parse import urlparse
129

@@ -31,7 +28,7 @@
3128
from Crypto.Cipher import AES
3229
from Crypto.Util.Padding import pad, unpad
3330

34-
from roborock.containers import BroadcastMessage, RRiot
31+
from roborock.containers import RRiot
3532
from roborock.exceptions import RoborockException
3633
from roborock.mqtt.session import MqttParams
3734
from roborock.roborock_message import RoborockMessage
@@ -40,7 +37,6 @@
4037
SALT = b"TXdfu$jyZ#TZHsg4"
4138
A01_HASH = "726f626f726f636b2d67a6d6da"
4239
B01_HASH = "5wwh9ikChRjASpMU8cxg7o1d2E"
43-
BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"
4440
AP_CONFIG = 1
4541
SOCK_DISCOVERY = 2
4642

@@ -51,38 +47,6 @@ def md5hex(message: str) -> str:
5147
return md5.hexdigest()
5248

5349

54-
class RoborockProtocol(asyncio.DatagramProtocol):
55-
def __init__(self, timeout: int = 5):
56-
self.timeout = timeout
57-
self.transport: BaseTransport | None = None
58-
self.devices_found: list[BroadcastMessage] = []
59-
self._mutex = Lock()
60-
61-
def __del__(self):
62-
self.close()
63-
64-
def datagram_received(self, data, _):
65-
[broadcast_message], _ = BroadcastParser.parse(data)
66-
if broadcast_message.payload:
67-
parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload))
68-
_LOGGER.debug(f"Received broadcast: {parsed_message}")
69-
self.devices_found.append(parsed_message)
70-
71-
async def discover(self):
72-
async with self._mutex:
73-
try:
74-
loop = asyncio.get_event_loop()
75-
self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
76-
await asyncio.sleep(self.timeout)
77-
return self.devices_found
78-
finally:
79-
self.close()
80-
self.devices_found = []
81-
82-
def close(self):
83-
self.transport.close() if self.transport else None
84-
85-
8650
class Utils:
8751
"""Util class for protocol manipulation."""
8852

@@ -324,19 +288,6 @@ def _build(self, obj, stream, context, path):
324288
"remaining" / Optional(GreedyBytes),
325289
)
326290

327-
_BroadcastMessage = Struct(
328-
"message"
329-
/ RawCopy(
330-
Struct(
331-
"version" / Bytes(3),
332-
"seq" / Int32ub,
333-
"protocol" / Int16ub,
334-
"payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN),
335-
)
336-
),
337-
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
338-
)
339-
340291

341292
class _Parser:
342293
def __init__(self, con: Construct, required_local_key: bool):
@@ -390,7 +341,6 @@ def build(
390341

391342

392343
MessageParser: _Parser = _Parser(_Messages, True)
393-
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
394344

395345

396346
def create_mqtt_params(rriot: RRiot) -> MqttParams:

0 commit comments

Comments
 (0)