Skip to content

Commit 4501307

Browse files
committed
fix: Update the messages callback to not mutate the protocol once created.
1 parent 278f54c commit 4501307

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

roborock/devices/local_channel.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from roborock.protocol import create_local_decoder, create_local_encoder
1111
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
1212

13+
14+
1315
from ..protocols.v1_protocol import LocalProtocolVersion
1416
from ..util import get_next_int
1517
from .channel import Channel
@@ -51,20 +53,26 @@ class LocalChannel(Channel):
5153
format most parsing to higher-level components.
5254
"""
5355

54-
_protocol_cache: dict[str, LocalProtocolVersion] = {}
5556

5657
def __init__(self, host: str, local_key: str):
5758
self._host = host
5859
self._transport: asyncio.Transport | None = None
5960
self._protocol: _LocalProtocol | None = None
6061
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
6162
self._is_connected = False
62-
self._local_protocol_version: LocalProtocolVersion | None = self._protocol_cache.get(host)
63+
self._local_protocol_version: LocalProtocolVersion | None = None
6364
self._update_encoder_decoder(
6465
LocalChannelParams(local_key=local_key, connect_nonce=get_next_int(10000, 32767), ack_nonce=None)
6566
)
6667

67-
def _update_encoder_decoder(self, params: LocalChannelParams):
68+
def _update_encoder_decoder(self, params: LocalChannelParams) -> None:
69+
"""Update the encoder and decoder with new parameters.
70+
71+
This is invoked once with an initial set of values used for protocol
72+
negotiation. Once negotiation completes, it is updated again to set the
73+
correct nonces for the follow up communications and updates the encoder
74+
and decoder functions accordingly.
75+
"""
6876
self._params = params
6977
self._encoder = create_local_encoder(
7078
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
@@ -73,9 +81,7 @@ def _update_encoder_decoder(self, params: LocalChannelParams):
7381
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
7482
)
7583
# Callback to decode messages and dispatch to subscribers
76-
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)
77-
if self._protocol:
78-
self._protocol.messages_cb = self._data_received
84+
self._dispatch = decoder_callback(self._decoder, self._subscribers, _LOGGER)
7985

8086
async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
8187
"""Perform the initial handshaking and return encoder params if successful."""
@@ -125,7 +131,6 @@ async def _hello(self):
125131
if params is not None:
126132
self._local_protocol_version = version
127133
self._update_encoder_decoder(params)
128-
self._protocol_cache[self._host] = self._local_protocol_version
129134
return
130135

131136
raise RoborockException("Failed to connect to device with any known protocol")
@@ -169,6 +174,10 @@ async def connect(self) -> None:
169174
self.close()
170175
raise
171176

177+
def _data_received(self, data: bytes) -> None:
178+
"""Invoked when data is received on the stream."""
179+
self._dispatch(data)
180+
172181
def close(self) -> None:
173182
"""Disconnect from the device."""
174183
if self._transport:

0 commit comments

Comments
 (0)