22
33import asyncio
44import logging
5- from asyncio import Lock , Transport
5+ from asyncio import Lock , TimerHandle , Transport
66from typing import Optional
77
88import async_timeout
99
10- from . import DeviceData , RoborockBase
10+ from . import DeviceData
1111from .api import COMMANDS_SECURED , QUEUE_TIMEOUT , RoborockClient
12- from .exceptions import CommandVacuumError , RoborockConnectionException , RoborockException
12+ from .exceptions import (
13+ CommandVacuumError ,
14+ RoborockConnectionException ,
15+ RoborockException ,
16+ )
1317from .protocol import MessageParser
14- from .roborock_message import RoborockMessage
18+ from .roborock_message import RoborockMessage , RoborockMessageProtocol
1519from .roborock_typing import CommandInfoMap , RoborockCommand
1620from .util import get_running_loop_or_create_one
1721
@@ -30,7 +34,7 @@ def __init__(self, device_data: DeviceData):
3034 self .remaining = b""
3135 self .transport : Transport | None = None
3236 self ._mutex = Lock ()
33- self .keep_alive_func ()
37+ self .keep_alive_task : TimerHandle | None = None
3438
3539 def data_received (self , message ):
3640 if self .remaining :
@@ -46,33 +50,34 @@ def connection_lost(self, exc: Optional[Exception]):
4650 def is_connected (self ):
4751 return self .transport and self .transport .is_reading ()
4852
49- def keep_alive_func (self , _ = None ):
50- keep_alive_task = asyncio .gather (
51- asyncio .sleep (10 ),
52- self .ping (),
53- )
54- keep_alive_task .add_done_callback (self .keep_alive_func )
53+ async def keep_alive_func (self , _ = None ):
54+ await self .ping ()
55+ self .keep_alive_task = self .loop .call_later (10 , lambda : self .keep_alive_func ())
5556
5657 async def async_connect (self ) -> None :
5758 async with self ._mutex :
5859 try :
59- if not self .is_connected ():
60+ is_connected = self .is_connected ()
61+ if not is_connected :
6062 self .sync_disconnect ()
6163 async with async_timeout .timeout (QUEUE_TIMEOUT ):
6264 _LOGGER .info (f"Connecting to { self .host } " )
6365 self .transport , _ = await self .loop .create_connection ( # type: ignore
6466 lambda : self , self .host , 58867
6567 )
66- await self .hello ()
6768 _LOGGER .info (f"Connected to { self .host } " )
6869 except Exception as e :
69- _LOGGER .warning (f"Failed connecting to { self .host } : { e } " )
7070 raise RoborockConnectionException (f"Failed connecting to { self .host } " ) from e
71+ if not is_connected :
72+ await self .hello ()
73+ await self .keep_alive_func ()
7174
7275 def sync_disconnect (self ) -> None :
7376 if self .transport and self .loop .is_running ():
7477 _LOGGER .debug (f"Disconnecting from { self .host } " )
7578 self .transport .close ()
79+ if self .keep_alive_task :
80+ self .keep_alive_task .cancel ()
7681
7782 async def async_disconnect (self ) -> None :
7883 async with self ._mutex :
@@ -96,24 +101,28 @@ def build_roborock_message(self, method: RoborockCommand, params: Optional[list
96101 )
97102
98103 async def hello (self ):
99- request_id = 1
100- _LOGGER .debug (f"id={ request_id } Requesting method hello with None" )
101- try :
102- return await self .send_message (
103- RoborockMessage (protocol = 0 , payload = None , seq = request_id , version = b"1.0" , random = 22 )
104- )
105- except Exception as e :
106- _LOGGER .error (e )
104+ if self .is_connected ():
105+ request_id = 1
106+ protocol = RoborockMessageProtocol .HELLO_REQUEST
107+ _LOGGER .debug (f"id={ request_id } Requesting protocol { protocol .name } " )
108+ try :
109+ return await self .send_message (
110+ RoborockMessage (protocol = protocol , payload = None , seq = request_id , version = b"1.0" , random = 22 )
111+ )
112+ except Exception as e :
113+ _LOGGER .error (e )
107114
108115 async def ping (self ):
109- request_id = 2
110- _LOGGER .debug (f"id={ request_id } Requesting method ping with None" )
111- try :
112- return await self .send_message (
113- RoborockMessage (protocol = 2 , payload = None , seq = request_id , version = b"1.0" , random = 23 )
114- )
115- except Exception as e :
116- _LOGGER .error (e )
116+ if self .is_connected ():
117+ request_id = 2
118+ protocol = RoborockMessageProtocol .PING_REQUEST
119+ _LOGGER .debug (f"id={ request_id } Requesting protocol { protocol .name } " )
120+ try :
121+ return await self .send_message (
122+ RoborockMessage (protocol = protocol , payload = None , seq = request_id , version = b"1.0" , random = 23 )
123+ )
124+ except Exception as e :
125+ _LOGGER .error (e )
117126
118127 async def send_command (self , method : RoborockCommand , params : Optional [list | dict ] = None ):
119128 roborock_message = self .build_roborock_message (method , params )
@@ -133,8 +142,12 @@ async def async_local_response(self, roborock_message: RoborockMessage):
133142 (response , err ) = await self ._async_response (request_id , response_protocol )
134143 if err :
135144 raise CommandVacuumError ("" , err ) from err
136- method = roborock_message .get_method () if roborock_message .protocol != 2 else "ping"
137- _LOGGER .debug (f"id={ request_id } Response from method { method } : { response } " )
145+ method = (
146+ f"method { roborock_message .get_method ()} "
147+ if roborock_message .protocol == 4
148+ else f"protocol { RoborockMessageProtocol (roborock_message .protocol ).name } "
149+ )
150+ _LOGGER .debug (f"id={ request_id } Response from { method } : { response } " )
138151 return response
139152
140153 def _send_msg_raw (self , data : bytes ):
@@ -146,28 +159,22 @@ def _send_msg_raw(self, data: bytes):
146159 raise RoborockException (e ) from e
147160
148161 async def send_message (self , roborock_messages : list [RoborockMessage ] | RoborockMessage ):
149- try :
150- await self .validate_connection ()
151- if isinstance (roborock_messages , RoborockMessage ):
152- roborock_messages = [roborock_messages ]
153- local_key = self .device_info .device .local_key
154- msg = MessageParser .build (roborock_messages , local_key = local_key )
155- # Send the command to the Roborock device
156- self ._send_msg_raw (msg )
157-
158- responses = await asyncio .gather (
159- * [self .async_local_response (roborock_message ) for roborock_message in roborock_messages ],
160- return_exceptions = True ,
161- )
162- exception = next ((response for response in responses if isinstance (response , BaseException )), None )
163- if exception :
164- raise exception
165- is_cached = next (
166- (response for response in responses if isinstance (response , RoborockBase ) and response .is_cached ), None
167- )
168- if is_cached :
169- await self .async_disconnect ()
170- return responses
171- except Exception :
172- await self .async_disconnect ()
173- raise
162+ await self .validate_connection ()
163+ if isinstance (roborock_messages , RoborockMessage ):
164+ roborock_messages = [roborock_messages ]
165+ local_key = self .device_info .device .local_key
166+ msg = MessageParser .build (roborock_messages , local_key = local_key )
167+ # Send the command to the Roborock device
168+ self ._send_msg_raw (msg )
169+
170+ responses = await asyncio .gather (
171+ * [self .async_local_response (roborock_message ) for roborock_message in roborock_messages ],
172+ return_exceptions = True ,
173+ )
174+ exception = next (
175+ (response for response in responses if isinstance (response , BaseException )),
176+ None ,
177+ )
178+ if exception :
179+ raise exception
180+ return responses
0 commit comments