77
88from roborock .callbacks import CallbackList , decoder_callback
99from roborock .exceptions import RoborockConnectionException , RoborockException
10- from roborock .protocol import Decoder , Encoder , create_local_decoder , create_local_encoder
11- from roborock .roborock_message import RoborockMessage
10+ from roborock .protocol import create_local_decoder , create_local_encoder
11+ from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
1212
13+ from ..protocols .v1_protocol import LocalProtocolVersion
14+ from ..util import get_next_int
1315from .channel import Channel
1416
1517_LOGGER = logging .getLogger (__name__ )
1618_PORT = 58867
19+ _TIMEOUT = 10.0
1720
1821
1922@dataclass
@@ -39,18 +42,83 @@ class LocalChannel(Channel):
3942 format most parsing to higher-level components.
4043 """
4144
42- def __init__ (self , host : str , local_key : str ):
45+ def __init__ (self , host : str , local_key : str , local_protocol_version : LocalProtocolVersion | None = None ):
4346 self ._host = host
4447 self ._transport : asyncio .Transport | None = None
4548 self ._protocol : _LocalProtocol | None = None
4649 self ._subscribers : CallbackList [RoborockMessage ] = CallbackList (_LOGGER )
4750 self ._is_connected = False
48-
49- self ._decoder : Decoder = create_local_decoder (local_key )
50- self ._encoder : Encoder = create_local_encoder (local_key )
51+ self ._local_key = local_key
52+ self ._local_protocol_version = local_protocol_version
53+ self ._connect_nonce = get_next_int (10000 , 32767 )
54+ self ._ack_nonce : int | None = None
55+ self ._update_encoder_decoder ()
56+
57+ def _update_encoder_decoder (self ):
58+ self ._encoder = create_local_encoder (
59+ local_key = self ._local_key , connect_nonce = self ._connect_nonce , ack_nonce = self ._ack_nonce
60+ )
61+ self ._decoder = create_local_decoder (
62+ local_key = self ._local_key , connect_nonce = self ._connect_nonce , ack_nonce = self ._ack_nonce
63+ )
5164 # Callback to decode messages and dispatch to subscribers
5265 self ._data_received : Callable [[bytes ], None ] = decoder_callback (self ._decoder , self ._subscribers , _LOGGER )
5366
67+ async def _do_hello (self , local_protocol_version : LocalProtocolVersion ) -> bool :
68+ """Perform the initial handshaking."""
69+ _LOGGER .debug (
70+ "Attempting to use the %s protocol for client %s..." ,
71+ local_protocol_version ,
72+ self ._host ,
73+ )
74+ request = RoborockMessage (
75+ protocol = RoborockMessageProtocol .HELLO_REQUEST ,
76+ version = local_protocol_version .encode (),
77+ random = self ._connect_nonce ,
78+ seq = 1 ,
79+ )
80+ try :
81+ response = await self .send_message (
82+ roborock_message = request ,
83+ request_id = request .seq ,
84+ response_protocol = RoborockMessageProtocol .HELLO_RESPONSE ,
85+ )
86+ self ._ack_nonce = response .random
87+ self ._local_protocol_version = local_protocol_version
88+ self ._update_encoder_decoder ()
89+
90+ _LOGGER .debug (
91+ "Client %s speaks the %s protocol." ,
92+ self ._host ,
93+ local_protocol_version ,
94+ )
95+ return True
96+ except RoborockException as e :
97+ _LOGGER .debug (
98+ "Client %s did not respond or does not speak the %s protocol. %s" ,
99+ self ._host ,
100+ local_protocol_version ,
101+ e ,
102+ )
103+ return False
104+
105+ async def hello (self ):
106+ """Send hello to the device to negotiate protocol."""
107+ if self ._local_protocol_version :
108+ # version is forced - try it first, if it fails, try the opposite
109+ if not await self ._do_hello (self ._local_protocol_version ):
110+ if not await self ._do_hello (
111+ LocalProtocolVersion .V1
112+ if self ._local_protocol_version is not LocalProtocolVersion .V1
113+ else LocalProtocolVersion .L01
114+ ):
115+ raise RoborockException ("Failed to connect to device with any known protocol" )
116+ else :
117+ # try 1.0, then L01
118+ if not await self ._do_hello (LocalProtocolVersion .V1 ):
119+ if not await self ._do_hello (LocalProtocolVersion .L01 ):
120+ raise RoborockException ("Failed to connect to device with any known protocol" )
121+
54122 @property
55123 def is_connected (self ) -> bool :
56124 """Check if the channel is currently connected."""
@@ -113,6 +181,29 @@ async def publish(self, message: RoborockMessage) -> None:
113181 logging .exception ("Uncaught error sending command" )
114182 raise RoborockException (f"Failed to send message: { message } " ) from err
115183
184+ async def send_message (
185+ self ,
186+ roborock_message : RoborockMessage ,
187+ request_id : int ,
188+ response_protocol : int ,
189+ ) -> RoborockMessage :
190+ """Send a raw message and wait for a raw response."""
191+ future : asyncio .Future [RoborockMessage ] = asyncio .Future ()
192+
193+ def find_response (response_message : RoborockMessage ) -> None :
194+ if response_message .protocol == response_protocol and response_message .seq == request_id :
195+ future .set_result (response_message )
196+
197+ unsub = await self .subscribe (find_response )
198+ try :
199+ await self .publish (roborock_message )
200+ return await asyncio .wait_for (future , timeout = _TIMEOUT )
201+ except TimeoutError as ex :
202+ future .cancel ()
203+ raise RoborockException (f"Command timed out after { _TIMEOUT } s" ) from ex
204+ finally :
205+ unsub ()
206+
116207
117208# This module provides a factory function to create LocalChannel instances.
118209#
0 commit comments