44
55import asyncio
66import base64
7- import binascii
87import gzip
98import hashlib
109import hmac
1413import secrets
1514import struct
1615import time
16+ from random import randint
1717from typing import Any , Callable
1818
1919import aiohttp
2020from Crypto .Cipher import AES
21- from Crypto .Util .Padding import pad , unpad
21+ from Crypto .Util .Padding import unpad
2222
2323from roborock .exceptions import (
2424 RoborockException , RoborockTimeout , VacuumError ,
4040 DustCollectionMode ,
4141
4242)
43+ from .roborock_message import RoborockMessage
4344from .roborock_queue import RoborockQueue
4445from .typing import (
4546 RoborockDeviceProp ,
@@ -61,34 +62,23 @@ def md5hex(message: str) -> str:
6162 return md5 .hexdigest ()
6263
6364
64- def md5bin (message : str ) -> bytes :
65- md5 = hashlib .md5 ()
66- md5 .update (message .encode ())
67- return md5 .digest ()
68-
69-
70- def encode_timestamp (_timestamp : int ) -> str :
71- hex_value = f"{ _timestamp :x} " .zfill (8 )
72- return "" .join (list (map (lambda idx : hex_value [idx ], [5 , 6 , 3 , 7 , 1 , 2 , 0 , 4 ])))
73-
74-
7565class PreparedRequest :
7666 def __init__ (self , base_url : str , base_headers : dict = None ) -> None :
7767 self .base_url = base_url
7868 self .base_headers = base_headers or {}
7969
8070 async def request (
81- self , method : str , url : str , params = None , data = None , headers = None
71+ self , method : str , url : str , params = None , data = None , headers = None
8272 ) -> dict | list :
8373 _url = "/" .join (s .strip ("/" ) for s in [self .base_url , url ])
8474 _headers = {** self .base_headers , ** (headers or {})}
8575 async with aiohttp .ClientSession () as session :
8676 async with session .request (
87- method ,
88- _url ,
89- params = params ,
90- data = data ,
91- headers = _headers ,
77+ method ,
78+ _url ,
79+ params = params ,
80+ data = data ,
81+ headers = _headers ,
9282 ) as resp :
9383 return await resp .json ()
9484
@@ -97,99 +87,24 @@ class RoborockClient:
9787
9888 def __init__ (self , endpoint : str , devices_info : dict [str , RoborockDeviceInfo ]) -> None :
9989 self .devices_info = devices_info
100- self ._seq = 1
101- self ._random = 4711
102- self ._id_counter = 10000
10390 self ._salt = "TXdfu$jyZ#TZHsg4"
10491 self ._endpoint = endpoint
10592 self ._nonce = secrets .token_bytes (16 )
10693 self ._waiting_queue : dict [int , RoborockQueue ] = {}
107- self ._status_listeners : list [Callable [[str , str ], None ]] = []
94+ self ._status_listeners : list [Callable [[int , str ], None ]] = []
10895
109- def add_status_listener (self , callback : Callable [[str , str ], None ]):
96+ def add_status_listener (self , callback : Callable [[int , str ], None ]):
11097 self ._status_listeners .append (callback )
11198
11299 async def async_disconnect (self ) -> Any :
113100 raise NotImplementedError
114101
115- def _decode_msg (self , msg : bytes , local_key : str ) -> list [dict [str , Any ]]:
116- prefix = None
117- if msg [4 :7 ] == "1.0" .encode ():
118- prefix = int .from_bytes (msg [:4 ], 'big' )
119- msg = msg [4 :]
120- elif msg [0 :3 ] != "1.0" .encode ():
121- raise RoborockException (f"Unknown protocol version { msg [0 :3 ]} " )
122- if len (msg ) in [17 , 21 , 25 ]:
123- [version , request_id , random , timestamp , protocol ] = struct .unpack (
124- "!3sIIIH" , msg [0 :17 ]
125- )
126- return [{
127- "prefix" : prefix ,
128- "version" : version ,
129- "request_id" : request_id ,
130- "random" : random ,
131- "timestamp" : timestamp ,
132- "protocol" : protocol ,
133- }]
134- index = 0
135- [version , request_id , random , timestamp , protocol , payload_len ] = struct .unpack (
136- "!3sIIIHH" , msg [index :index + 19 ]
137- )
138- [payload , expected_crc32 ] = struct .unpack_from (f"!{ payload_len } sI" , msg , index + 19 )
139- if payload_len == 0 :
140- index += 21
141- else :
142- crc32 = binascii .crc32 (msg [index : index + 19 + payload_len ])
143- index += 23 + payload_len
144- if crc32 != expected_crc32 :
145- raise RoborockException (f"Wrong CRC32 { crc32 } , expected { expected_crc32 } " )
146- decrypted_payload = None
147- if payload :
148- aes_key = md5bin (encode_timestamp (timestamp ) + local_key + self ._salt )
149- decipher = AES .new (aes_key , AES .MODE_ECB )
150- decrypted_payload = unpad (decipher .decrypt (payload ), AES .block_size )
151- return [{
152- "prefix" : prefix ,
153- "version" : version ,
154- "request_id" : request_id ,
155- "random" : random ,
156- "timestamp" : timestamp ,
157- "protocol" : protocol ,
158- "payload" : decrypted_payload
159- }] + (self ._decode_msg (msg [index :], local_key ) if index < len (msg ) else [])
160-
161- def _encode_msg (self , device_id , request_id , protocol , timestamp , payload , prefix = None ) -> bytes :
162- local_key = self .devices_info [device_id ].device .local_key
163- aes_key = md5bin (encode_timestamp (timestamp ) + local_key + self ._salt )
164- cipher = AES .new (aes_key , AES .MODE_ECB )
165- encrypted = cipher .encrypt (pad (payload , AES .block_size ))
166- encrypted_len = len (encrypted )
167- values = [
168- "1.0" .encode (),
169- request_id ,
170- self ._random ,
171- timestamp ,
172- protocol ,
173- encrypted_len ,
174- encrypted
175- ]
176- if prefix :
177- values = [prefix ] + values
178- msg = struct .pack (
179- f"!{ 'I' if prefix else '' } 3sIIIHH{ encrypted_len } s" ,
180- * values
181- )
182- crc32 = binascii .crc32 (msg [4 :] if prefix else msg )
183- msg += struct .pack ("!I" , crc32 )
184- return msg
185-
186- async def on_message (self , device_id , msg ) -> None :
102+ async def on_message (self , messages : list [RoborockMessage ]) -> None :
187103 try :
188- messages = self ._decode_msg (msg , self .devices_info [device_id ].device .local_key )
189104 for data in messages :
190- protocol = data .get ( " protocol" )
105+ protocol = data .protocol
191106 if protocol == 102 or protocol == 4 :
192- payload = json .loads (data .get ( " payload" ) .decode ())
107+ payload = json .loads (data .payload .decode ())
193108 for data_point_number , data_point in payload .get ("dps" ).items ():
194109 if data_point_number == "102" :
195110 data_point_response = json .loads (data_point )
@@ -215,45 +130,30 @@ async def on_message(self, device_id, msg) -> None:
215130 await queue .async_put (
216131 (result , None ), timeout = QUEUE_TIMEOUT
217132 )
218- elif request_id < self ._id_counter :
219- _LOGGER .debug (
220- f"id={ request_id } Ignoring response: { data_point_response } "
221- )
222133 elif data_point_number == "121" :
223134 status = STATE_CODE_TO_STATUS .get (data_point )
224135 _LOGGER .debug (f"Status updated to { status } " )
225136 for listener in self ._status_listeners :
226- listener (device_id , status )
137+ listener (data . seq , status )
227138 else :
228139 _LOGGER .debug (
229140 f"Unknown data point number received { data_point_number } with { data_point } "
230141 )
231142 elif protocol == 301 :
232- payload = data .get ( " payload" ) [0 :24 ]
143+ payload = data .payload [0 :24 ]
233144 [endpoint , _ , request_id , _ ] = struct .unpack ("<15sBH6s" , payload )
234145 if endpoint .decode ().startswith (self ._endpoint ):
235146 iv = bytes (AES .block_size )
236147 decipher = AES .new (self ._nonce , AES .MODE_CBC , iv )
237148 decrypted = unpad (
238- decipher .decrypt (data .get ( " payload" ) [24 :]), AES .block_size
149+ decipher .decrypt (data .payload [24 :]), AES .block_size
239150 )
240151 decrypted = gzip .decompress (decrypted )
241152 queue = self ._waiting_queue .get (request_id )
242153 if queue :
243154 if isinstance (decrypted , list ):
244155 decrypted = decrypted [0 ]
245156 await queue .async_put ((decrypted , None ), timeout = QUEUE_TIMEOUT )
246- elif data .get ('request_id' ):
247- request_id = data .get ('request_id' )
248- queue = self ._waiting_queue .get (request_id )
249- if queue :
250- protocol = data .get ("protocol" )
251- if queue .protocol == protocol :
252- await queue .async_put ((None , None ), timeout = QUEUE_TIMEOUT )
253- elif request_id < self ._id_counter and protocol != 5 :
254- _LOGGER .debug (
255- f"id={ request_id } Ignoring response: { data } "
256- )
257157 except Exception as ex :
258158 _LOGGER .exception (ex )
259159
@@ -271,11 +171,10 @@ async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[
271171 del self ._waiting_queue [request_id ]
272172
273173 def _get_payload (
274- self , method : RoborockCommand , params : list = None , secured = False
174+ self , method : RoborockCommand , params : list = None , secured = False
275175 ):
276176 timestamp = math .floor (time .time ())
277- request_id = self ._id_counter
278- self ._id_counter += 1
177+ request_id = randint (10000 , 99999 )
279178 inner = {
280179 "id" : request_id ,
281180 "method" : method ,
@@ -298,7 +197,7 @@ def _get_payload(
298197 return request_id , timestamp , payload
299198
300199 async def send_command (
301- self , device_id : str , method : RoborockCommand , params : list = None
200+ self , device_id : str , method : RoborockCommand , params : list = None
302201 ):
303202 raise NotImplementedError
304203
@@ -374,7 +273,14 @@ async def get_dock_summary(self, device_id: str, dock_type: RoborockDockType) ->
374273 commands = [self .get_dust_collection_mode (device_id )]
375274 if dock_type == RoborockDockType .EMPTY_WASH_FILL_DOCK :
376275 commands += [self .get_wash_towel_mode (device_id ), self .get_smart_wash_params (device_id )]
377- [dust_collection_mode , wash_towel_mode , smart_wash_params ] = (list (await asyncio .gather (* commands )) + [None , None ])[:3 ]
276+ [
277+ dust_collection_mode ,
278+ wash_towel_mode ,
279+ smart_wash_params
280+ ] = (
281+ list (await asyncio .gather (* commands ))
282+ + [None , None ]
283+ )[:3 ]
378284
379285 return RoborockDockSummary (dust_collection_mode , wash_towel_mode , smart_wash_params )
380286 except RoborockTimeout as e :
0 commit comments