Skip to content

Commit 01e6744

Browse files
committed
chore: Refactor to reuse the same payload functions
1 parent 1e7c4ef commit 01e6744

File tree

1 file changed

+35
-94
lines changed

1 file changed

+35
-94
lines changed

roborock/devices/v1_rpc_channel.py

Lines changed: 35 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
SecurityData,
2222
create_map_response_decoder,
2323
decode_rpc_response,
24+
MapResponse,
25+
ResponseMessage,
2426
)
2527
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2628

@@ -114,155 +116,94 @@ async def _send_raw_command(
114116
raise RoborockException("No available connection to send command")
115117

116118

117-
class RpcPublisher:
118-
"""Helper to create send and receive messages on a channel."""
119+
class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
120+
"""Protocol for V1 channels that send encoded commands."""
119121

120122
def __init__(
121123
self,
122124
name: str,
123125
channel: MqttChannel | LocalChannel,
124126
payload_encoder: Callable[[RequestMessage], RoborockMessage],
127+
decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse] | None = None,
125128
) -> None:
126-
"""Initialize the RPC publisher."""
129+
"""Initialize the channel with a raw channel and an encoder function."""
127130
self._name = name
128131
self._channel = channel
129132
self._payload_encoder = payload_encoder
130-
131-
async def publish_and_wait(
132-
self,
133-
request_message: RequestMessage,
134-
find_response: Callable[[RoborockMessage], None],
135-
future: asyncio.Future[_V],
136-
) -> _V:
137-
"""Helper to send a message and wait for a future to complete.
138-
139-
The find_response function will be called for each incoming message. The
140-
function should check the message and call future.set_result or
141-
future.set_exception as appropriate when the response is found.
142-
"""
143-
_LOGGER.debug(
144-
"Sending command (%s, request_id=%s): %s, params=%s",
145-
self._name,
146-
request_message.request_id,
147-
request_message.method,
148-
request_message.params,
149-
)
150-
message = self._payload_encoder(request_message)
151-
unsub = await self._channel.subscribe(find_response)
152-
try:
153-
await self._channel.publish(message)
154-
return await asyncio.wait_for(future, timeout=_TIMEOUT)
155-
except TimeoutError as ex:
156-
future.cancel()
157-
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
158-
finally:
159-
unsub()
160-
161-
162-
class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
163-
"""Protocol for V1 channels that send encoded commands."""
164-
165-
def __init__(self, publisher: RpcPublisher) -> None:
166-
"""Initialize the channel with a raw channel and an encoder function."""
167-
self._name = publisher._name
168-
self._publisher = publisher
133+
self._decoder = decoder
169134

170135
async def _send_raw_command(
171136
self,
172137
method: CommandType,
173138
*,
174139
params: ParamsType = None,
175-
) -> ResponseData:
140+
) -> ResponseData | bytes:
176141
"""Send a command and return a parsed response RoborockBase type."""
177142
request_message = RequestMessage(method, params=params)
143+
_LOGGER.debug(
144+
"Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params
145+
)
146+
message = self._payload_encoder(request_message)
178147

179-
future: asyncio.Future[ResponseData] = asyncio.Future()
148+
future: asyncio.Future[ResponseData | bytes] = asyncio.Future()
180149

181150
def find_response(response_message: RoborockMessage) -> None:
182151
try:
183-
decoded = decode_rpc_response(response_message)
152+
decoded = self._decoder(response_message)
184153
except RoborockException as ex:
185154
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
186155
return
187156
_LOGGER.debug("Received response (%s, request_id=%s)", self._name, decoded.request_id)
188157
if decoded.request_id == request_message.request_id:
189-
if decoded.api_error:
158+
if isinstance(decoded, ResponseMessage) and decoded.api_error:
190159
future.set_exception(decoded.api_error)
191160
else:
192161
future.set_result(decoded.data)
193162

194-
return await self._publisher.publish_and_wait(request_message, find_response, future)
195-
196-
197-
class MapRpcChannel(BaseV1RpcChannel):
198-
"""A V1 RPC channel that fetches and decodes map data."""
199-
200-
def __init__(
201-
self,
202-
publisher: RpcPublisher,
203-
security_data: SecurityData,
204-
) -> None:
205-
"""Initialize the map RPC channel."""
206-
self._publisher = publisher
207-
self._decoder = create_map_response_decoder(security_data=security_data)
208-
209-
async def _send_raw_command(
210-
self,
211-
method: CommandType,
212-
*,
213-
params: ParamsType = None,
214-
) -> Any:
215-
"""Send a command and return a parsed response RoborockBase type."""
216-
request_message = RequestMessage(method, params=params)
217-
218-
future: asyncio.Future[bytes] = asyncio.Future()
219-
220-
def find_response(response_message: RoborockMessage) -> None:
221-
try:
222-
decoded = self._decoder(response_message)
223-
except RoborockException as ex:
224-
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
225-
return
226-
if decoded is None:
227-
return
228-
_LOGGER.debug("Received response (map), request_id=%s)", decoded.request_id)
229-
if decoded.request_id == request_message.request_id:
230-
future.set_result(decoded.data)
231-
232-
return await self._publisher.publish_and_wait(request_message, find_response, future)
233-
163+
message = self._payload_encoder(request_message)
164+
unsub = await self._channel.subscribe(find_response)
165+
try:
166+
await self._channel.publish(message)
167+
return await asyncio.wait_for(future, timeout=_TIMEOUT)
168+
except TimeoutError as ex:
169+
future.cancel()
170+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
171+
finally:
172+
unsub()
234173

235174
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
236175
"""Create a V1 RPC channel using an MQTT channel."""
237-
publisher = RpcPublisher(
176+
return PayloadEncodedV1RpcChannel(
238177
"mqtt",
239178
mqtt_channel,
240179
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
180+
decode_rpc_response,
241181
)
242-
return PayloadEncodedV1RpcChannel(publisher)
243182

244183

245184
def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel:
246185
"""Create a V1 RPC channel using a local channel."""
247-
publisher = RpcPublisher(
248-
"local", local_channel, lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)
186+
return PayloadEncodedV1RpcChannel(
187+
"local",
188+
local_channel,
189+
lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST),
190+
decode_rpc_response,
249191
)
250-
return PayloadEncodedV1RpcChannel(publisher)
251192

252193

253194
def create_map_rpc_channel(
254195
mqtt_channel: MqttChannel,
255196
security_data: SecurityData,
256-
) -> MapRpcChannel:
197+
) -> V1RpcChannel:
257198
"""Create a V1 RPC channel that fetches map data.
258199
259200
This will prefer local channels when available, falling back to MQTT
260201
channels if not. If neither is available, an exception will be raised
261202
when trying to send a command.
262203
"""
263-
publisher = RpcPublisher(
204+
return PayloadEncodedV1RpcChannel(
264205
"map",
265206
mqtt_channel,
266207
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
208+
create_map_response_decoder(security_data=security_data),
267209
)
268-
return MapRpcChannel(publisher, security_data)

0 commit comments

Comments
 (0)