Skip to content

Commit e15c3be

Browse files
committed
chore: Separate V1 API connection logic from encoding logic
1 parent c1bdac0 commit e15c3be

File tree

9 files changed

+262
-126
lines changed

9 files changed

+262
-126
lines changed

roborock/devices/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ async def get_status(self) -> Status:
117117
This is a placeholder command and will likely be changed/moved in the future.
118118
"""
119119
status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus)
120-
return await self._v1_channel.send_decoded_command(RoborockCommand.GET_STATUS, response_type=status_type)
120+
return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type)

roborock/devices/local_channel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def __init__(self, host: str, local_key: str):
5050
self._encoder: Encoder = create_local_encoder(local_key)
5151
self._queue_lock = asyncio.Lock()
5252

53+
@property
54+
def is_connected(self) -> bool:
55+
"""Check if the channel is currently connected."""
56+
return self._is_connected
57+
5358
async def connect(self) -> None:
5459
"""Connect to the device."""
5560
if self._is_connected:
@@ -113,7 +118,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
113118
else:
114119
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
115120

116-
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
121+
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
117122
"""Send a command message and wait for the response message."""
118123
if not self._transport or not self._is_connected:
119124
raise RoborockConnectionException("Not connected to device")

roborock/devices/mqtt_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
8080
else:
8181
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)
8282

83-
async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
83+
async def send_message(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
8484
"""Send a command message and wait for the response message.
8585
8686
Returns the raw response message - caller is responsible for parsing.

roborock/devices/v1_channel.py

Lines changed: 18 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,21 @@
66

77
import logging
88
from collections.abc import Callable
9-
from typing import Any, TypeVar
9+
from typing import TypeVar
1010

1111
from roborock.containers import HomeDataDevice, NetworkInfo, RoborockBase, UserData
1212
from roborock.exceptions import RoborockException
1313
from roborock.mqtt.session import MqttParams, MqttSession
1414
from roborock.protocols.v1_protocol import (
15-
CommandType,
16-
ParamsType,
1715
SecurityData,
18-
create_mqtt_payload_encoder,
1916
create_security_data,
20-
decode_rpc_response,
21-
encode_local_payload,
2217
)
2318
from roborock.roborock_message import RoborockMessage
2419
from roborock.roborock_typing import RoborockCommand
2520

2621
from .local_channel import LocalChannel, LocalSession, create_local_session
2722
from .mqtt_channel import MqttChannel
23+
from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel
2824

2925
_LOGGER = logging.getLogger(__name__)
3026

@@ -58,9 +54,10 @@ def __init__(
5854
"""
5955
self._device_uid = device_uid
6056
self._mqtt_channel = mqtt_channel
61-
self._mqtt_payload_encoder = create_mqtt_payload_encoder(security_data)
57+
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
6258
self._local_session = local_session
6359
self._local_channel: LocalChannel | None = None
60+
self._combined_rpc_channel: V1RpcChannel | None = None
6461
self._mqtt_unsub: Callable[[], None] | None = None
6562
self._local_unsub: Callable[[], None] | None = None
6663
self._callback: Callable[[RoborockMessage], None] | None = None
@@ -76,6 +73,16 @@ def is_mqtt_connected(self) -> bool:
7673
"""Return whether MQTT connection is available."""
7774
return self._mqtt_unsub is not None
7875

76+
@property
77+
def rpc_channel(self) -> V1RpcChannel:
78+
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
79+
return self._combined_rpc_channel or self._mqtt_rpc_channel
80+
81+
@property
82+
def mqtt_rpc_channel(self) -> V1RpcChannel:
83+
"""Return the MQTT RPC channel."""
84+
return self._mqtt_rpc_channel
85+
7986
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
8087
"""Subscribe to all messages from the device.
8188
@@ -119,7 +126,9 @@ async def _get_networking_info(self) -> NetworkInfo:
119126
This is a cloud only command used to get the local device's IP address.
120127
"""
121128
try:
122-
return await self._send_mqtt_decoded_command(RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo)
129+
return await self._mqtt_rpc_channel.send_command(
130+
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
131+
)
123132
except RoborockException as e:
124133
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
125134

@@ -136,59 +145,9 @@ async def _local_connect(self) -> Callable[[], None]:
136145
except RoborockException as e:
137146
self._local_channel = None
138147
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
139-
148+
self._combined_rpc_channel = create_combined_rpc_channel(self._local_channel, self._mqtt_rpc_channel)
140149
return await self._local_channel.subscribe(self._on_local_message)
141150

142-
async def send_decoded_command(
143-
self,
144-
method: CommandType,
145-
*,
146-
response_type: type[_T],
147-
params: ParamsType = None,
148-
) -> _T:
149-
"""Send a command using the best available transport.
150-
151-
Will prefer local connection if available, falling back to MQTT.
152-
"""
153-
connection = "local" if self.is_local_connected else "mqtt"
154-
_LOGGER.debug("Sending command (%s): %s, params=%s", connection, method, params)
155-
if self._local_channel:
156-
return await self._send_local_decoded_command(method, response_type=response_type, params=params)
157-
return await self._send_mqtt_decoded_command(method, response_type=response_type, params=params)
158-
159-
async def _send_mqtt_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
160-
"""Send a raw command and return a raw unparsed response."""
161-
message = self._mqtt_payload_encoder(method, params)
162-
_LOGGER.debug("Sending MQTT message for device %s: %s", self._device_uid, message)
163-
response = await self._mqtt_channel.send_command(message)
164-
return decode_rpc_response(response)
165-
166-
async def _send_mqtt_decoded_command(
167-
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
168-
) -> _T:
169-
"""Send a command over MQTT and decode the response."""
170-
decoded_response = await self._send_mqtt_raw_command(method, params)
171-
return response_type.from_dict(decoded_response)
172-
173-
async def _send_local_raw_command(self, method: CommandType, params: ParamsType | None = None) -> dict[str, Any]:
174-
"""Send a raw command over local connection."""
175-
if not self._local_channel:
176-
raise RoborockException("Local channel is not connected")
177-
178-
message = encode_local_payload(method, params)
179-
_LOGGER.debug("Sending local message for device %s: %s", self._device_uid, message)
180-
response = await self._local_channel.send_command(message)
181-
return decode_rpc_response(response)
182-
183-
async def _send_local_decoded_command(
184-
self, method: CommandType, *, response_type: type[_T], params: ParamsType | None = None
185-
) -> _T:
186-
"""Send a command over local connection and decode the response."""
187-
if not self._local_channel:
188-
raise RoborockException("Local channel is not connected")
189-
decoded_response = await self._send_local_raw_command(method, params)
190-
return response_type.from_dict(decoded_response)
191-
192151
def _on_mqtt_message(self, message: RoborockMessage) -> None:
193152
"""Handle incoming MQTT messages."""
194153
_LOGGER.debug("V1Channel received MQTT message from device %s: %s", self._device_uid, message)

roborock/devices/v1_rpc_channel.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""V1 Rpc Channel for Roborock devices.
2+
3+
This is a wrapper around the V1 channel that provides a higher level interface
4+
for sending typed commands and receiving typed responses. This also provides
5+
a simple interface for sending commands and receiving responses over both MQTT
6+
and local connections, preferring local when available.
7+
"""
8+
9+
import logging
10+
from collections.abc import Callable
11+
from typing import Any, Protocol, TypeVar, overload
12+
13+
from roborock.containers import RoborockBase
14+
from roborock.protocols.v1_protocol import (
15+
CommandType,
16+
ParamsType,
17+
SecurityData,
18+
create_mqtt_payload_encoder,
19+
decode_rpc_response,
20+
encode_local_payload,
21+
)
22+
from roborock.roborock_message import RoborockMessage
23+
24+
from .local_channel import LocalChannel
25+
from .mqtt_channel import MqttChannel
26+
27+
_LOGGER = logging.getLogger(__name__)
28+
29+
30+
_T = TypeVar("_T", bound=RoborockBase)
31+
32+
33+
class V1RpcChannel(Protocol):
34+
"""Protocol for V1 RPC channels.
35+
36+
This is a wrapper around a raw channel that provides a high-level interface
37+
for sending commands and receiving responses.
38+
"""
39+
40+
@overload
41+
async def send_command(
42+
self,
43+
method: CommandType,
44+
*,
45+
params: ParamsType = None,
46+
) -> Any:
47+
"""Send a command and return a decoded response."""
48+
...
49+
50+
@overload
51+
async def send_command(
52+
self,
53+
method: CommandType,
54+
*,
55+
response_type: type[_T],
56+
params: ParamsType = None,
57+
) -> _T:
58+
"""Send a command and return a parsed response RoborockBase type."""
59+
...
60+
61+
async def send_command(
62+
self,
63+
method: CommandType,
64+
*,
65+
response_type: type[_T] | None = None,
66+
params: ParamsType = None,
67+
) -> _T | Any:
68+
"""Send a command and return either a decoded or parsed response."""
69+
...
70+
71+
72+
class BaseV1RpcChannel:
73+
"""Base implementation that provides the typed response logic."""
74+
75+
async def send_command(
76+
self,
77+
method: CommandType,
78+
*,
79+
response_type: type[_T] | None = None,
80+
params: ParamsType = None,
81+
) -> _T | Any:
82+
"""Send a command and return either a decoded or parsed response."""
83+
decoded_response = await self._send_raw_command(method, params=params)
84+
85+
if response_type is not None:
86+
return response_type.from_dict(decoded_response)
87+
return decoded_response
88+
89+
async def _send_raw_command(
90+
self,
91+
method: CommandType,
92+
*,
93+
params: ParamsType = None,
94+
) -> Any:
95+
"""Send a raw command and return the decoded response. Must be implemented by subclasses."""
96+
raise NotImplementedError
97+
98+
99+
class CombinedChannel(BaseV1RpcChannel):
100+
"""A V1 RPC channel that can use both local and MQTT channels, preferring local when available."""
101+
102+
def __init__(
103+
self, local_channel: LocalChannel, local_rpc_channel: V1RpcChannel, mqtt_channel: V1RpcChannel
104+
) -> None:
105+
"""Initialize the combined channel with local and MQTT channels."""
106+
self._local_channel = local_channel
107+
self._local_rpc_channel = local_rpc_channel
108+
self._mqtt_rpc_channel = mqtt_channel
109+
110+
async def _send_raw_command(
111+
self,
112+
method: CommandType,
113+
*,
114+
params: ParamsType = None,
115+
) -> Any:
116+
"""Send a command and return a parsed response RoborockBase type."""
117+
if self._local_channel.is_connected:
118+
return await self._local_rpc_channel.send_command(method, params=params)
119+
return await self._mqtt_rpc_channel.send_command(method, params=params)
120+
121+
122+
class PayloadEncodedV1Channel(BaseV1RpcChannel):
123+
"""Protocol for V1 channels that send encoded commands."""
124+
125+
def __init__(
126+
self,
127+
name: str,
128+
channel: MqttChannel | LocalChannel,
129+
payload_encoder: Callable[[CommandType, ParamsType], RoborockMessage],
130+
) -> None:
131+
"""Initialize the channel with a raw channel and an encoder function."""
132+
self._name = name
133+
self._channel = channel
134+
self._payload_encoder = payload_encoder
135+
136+
async def _send_raw_command(
137+
self,
138+
method: CommandType,
139+
*,
140+
params: ParamsType = None,
141+
) -> Any:
142+
"""Send a command and return a parsed response RoborockBase type."""
143+
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
144+
message = self._payload_encoder(method, params)
145+
response = await self._channel.send_message(message)
146+
return decode_rpc_response(response)
147+
148+
149+
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
150+
"""Create a V1 RPC channel using an MQTT channel."""
151+
payload_encoder = create_mqtt_payload_encoder(security_data)
152+
return PayloadEncodedV1Channel("mqtt", mqtt_channel, payload_encoder)
153+
154+
155+
def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
156+
"""Create a V1 RPC channel that combines local and MQTT channels."""
157+
local_rpc_channel = PayloadEncodedV1Channel("local", local_channel, encode_local_payload)
158+
return CombinedChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)

tests/devices/test_device.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ async def test_device_connection(device: RoborockDevice, channel: AsyncMock) ->
5656
async def test_device_get_status_command(device: RoborockDevice, channel: AsyncMock) -> None:
5757
"""Test the device get_status command."""
5858
# Mock response for get_status command
59-
channel.send_decoded_command.return_value = STATUS
59+
channel.rpc_channel.send_command.return_value = STATUS
6060

6161
# Test get_status and verify the command was sent
6262
status = await device.get_status()
63-
assert channel.send_decoded_command.called
63+
assert channel.rpc_channel.send_command.called
6464

6565
# Verify the result
6666
assert status is not None

0 commit comments

Comments
 (0)