|
21 | 21 | SecurityData, |
22 | 22 | create_map_response_decoder, |
23 | 23 | decode_rpc_response, |
| 24 | + MapResponse, |
| 25 | + ResponseMessage, |
24 | 26 | ) |
25 | 27 | from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol |
26 | 28 |
|
@@ -114,155 +116,94 @@ async def _send_raw_command( |
114 | 116 | raise RoborockException("No available connection to send command") |
115 | 117 |
|
116 | 118 |
|
117 | | -class RpcPublisher: |
118 | | - """Helper to create send and receive messages on a channel.""" |
| 119 | +class PayloadEncodedV1RpcChannel(BaseV1RpcChannel): |
| 120 | + """Protocol for V1 channels that send encoded commands.""" |
119 | 121 |
|
120 | 122 | def __init__( |
121 | 123 | self, |
122 | 124 | name: str, |
123 | 125 | channel: MqttChannel | LocalChannel, |
124 | 126 | payload_encoder: Callable[[RequestMessage], RoborockMessage], |
| 127 | + decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse] | None = None, |
125 | 128 | ) -> None: |
126 | | - """Initialize the RPC publisher.""" |
| 129 | + """Initialize the channel with a raw channel and an encoder function.""" |
127 | 130 | self._name = name |
128 | 131 | self._channel = channel |
129 | 132 | self._payload_encoder = payload_encoder |
130 | | - |
131 | | - async def publish_and_wait( |
132 | | - self, |
133 | | - request_message: RequestMessage, |
134 | | - find_response: Callable[[RoborockMessage], None], |
135 | | - future: asyncio.Future[_V], |
136 | | - ) -> _V: |
137 | | - """Helper to send a message and wait for a future to complete. |
138 | | -
|
139 | | - The find_response function will be called for each incoming message. The |
140 | | - function should check the message and call future.set_result or |
141 | | - future.set_exception as appropriate when the response is found. |
142 | | - """ |
143 | | - _LOGGER.debug( |
144 | | - "Sending command (%s, request_id=%s): %s, params=%s", |
145 | | - self._name, |
146 | | - request_message.request_id, |
147 | | - request_message.method, |
148 | | - request_message.params, |
149 | | - ) |
150 | | - message = self._payload_encoder(request_message) |
151 | | - unsub = await self._channel.subscribe(find_response) |
152 | | - try: |
153 | | - await self._channel.publish(message) |
154 | | - return await asyncio.wait_for(future, timeout=_TIMEOUT) |
155 | | - except TimeoutError as ex: |
156 | | - future.cancel() |
157 | | - raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex |
158 | | - finally: |
159 | | - unsub() |
160 | | - |
161 | | - |
162 | | -class PayloadEncodedV1RpcChannel(BaseV1RpcChannel): |
163 | | - """Protocol for V1 channels that send encoded commands.""" |
164 | | - |
165 | | - def __init__(self, publisher: RpcPublisher) -> None: |
166 | | - """Initialize the channel with a raw channel and an encoder function.""" |
167 | | - self._name = publisher._name |
168 | | - self._publisher = publisher |
| 133 | + self._decoder = decoder |
169 | 134 |
|
170 | 135 | async def _send_raw_command( |
171 | 136 | self, |
172 | 137 | method: CommandType, |
173 | 138 | *, |
174 | 139 | params: ParamsType = None, |
175 | | - ) -> ResponseData: |
| 140 | + ) -> ResponseData | bytes: |
176 | 141 | """Send a command and return a parsed response RoborockBase type.""" |
177 | 142 | request_message = RequestMessage(method, params=params) |
| 143 | + _LOGGER.debug( |
| 144 | + "Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params |
| 145 | + ) |
| 146 | + message = self._payload_encoder(request_message) |
178 | 147 |
|
179 | | - future: asyncio.Future[ResponseData] = asyncio.Future() |
| 148 | + future: asyncio.Future[ResponseData | bytes] = asyncio.Future() |
180 | 149 |
|
181 | 150 | def find_response(response_message: RoborockMessage) -> None: |
182 | 151 | try: |
183 | | - decoded = decode_rpc_response(response_message) |
| 152 | + decoded = self._decoder(response_message) |
184 | 153 | except RoborockException as ex: |
185 | 154 | _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) |
186 | 155 | return |
187 | 156 | _LOGGER.debug("Received response (%s, request_id=%s)", self._name, decoded.request_id) |
188 | 157 | if decoded.request_id == request_message.request_id: |
189 | | - if decoded.api_error: |
| 158 | + if isinstance(decoded, ResponseMessage) and decoded.api_error: |
190 | 159 | future.set_exception(decoded.api_error) |
191 | 160 | else: |
192 | 161 | future.set_result(decoded.data) |
193 | 162 |
|
194 | | - return await self._publisher.publish_and_wait(request_message, find_response, future) |
195 | | - |
196 | | - |
197 | | -class MapRpcChannel(BaseV1RpcChannel): |
198 | | - """A V1 RPC channel that fetches and decodes map data.""" |
199 | | - |
200 | | - def __init__( |
201 | | - self, |
202 | | - publisher: RpcPublisher, |
203 | | - security_data: SecurityData, |
204 | | - ) -> None: |
205 | | - """Initialize the map RPC channel.""" |
206 | | - self._publisher = publisher |
207 | | - self._decoder = create_map_response_decoder(security_data=security_data) |
208 | | - |
209 | | - async def _send_raw_command( |
210 | | - self, |
211 | | - method: CommandType, |
212 | | - *, |
213 | | - params: ParamsType = None, |
214 | | - ) -> Any: |
215 | | - """Send a command and return a parsed response RoborockBase type.""" |
216 | | - request_message = RequestMessage(method, params=params) |
217 | | - |
218 | | - future: asyncio.Future[bytes] = asyncio.Future() |
219 | | - |
220 | | - def find_response(response_message: RoborockMessage) -> None: |
221 | | - try: |
222 | | - decoded = self._decoder(response_message) |
223 | | - except RoborockException as ex: |
224 | | - _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) |
225 | | - return |
226 | | - if decoded is None: |
227 | | - return |
228 | | - _LOGGER.debug("Received response (map), request_id=%s)", decoded.request_id) |
229 | | - if decoded.request_id == request_message.request_id: |
230 | | - future.set_result(decoded.data) |
231 | | - |
232 | | - return await self._publisher.publish_and_wait(request_message, find_response, future) |
233 | | - |
| 163 | + message = self._payload_encoder(request_message) |
| 164 | + unsub = await self._channel.subscribe(find_response) |
| 165 | + try: |
| 166 | + await self._channel.publish(message) |
| 167 | + return await asyncio.wait_for(future, timeout=_TIMEOUT) |
| 168 | + except TimeoutError as ex: |
| 169 | + future.cancel() |
| 170 | + raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex |
| 171 | + finally: |
| 172 | + unsub() |
234 | 173 |
|
235 | 174 | def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel: |
236 | 175 | """Create a V1 RPC channel using an MQTT channel.""" |
237 | | - publisher = RpcPublisher( |
| 176 | + return PayloadEncodedV1RpcChannel( |
238 | 177 | "mqtt", |
239 | 178 | mqtt_channel, |
240 | 179 | lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data), |
| 180 | + decode_rpc_response, |
241 | 181 | ) |
242 | | - return PayloadEncodedV1RpcChannel(publisher) |
243 | 182 |
|
244 | 183 |
|
245 | 184 | def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel: |
246 | 185 | """Create a V1 RPC channel using a local channel.""" |
247 | | - publisher = RpcPublisher( |
248 | | - "local", local_channel, lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST) |
| 186 | + return PayloadEncodedV1RpcChannel( |
| 187 | + "local", |
| 188 | + local_channel, |
| 189 | + lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST), |
| 190 | + decode_rpc_response, |
249 | 191 | ) |
250 | | - return PayloadEncodedV1RpcChannel(publisher) |
251 | 192 |
|
252 | 193 |
|
253 | 194 | def create_map_rpc_channel( |
254 | 195 | mqtt_channel: MqttChannel, |
255 | 196 | security_data: SecurityData, |
256 | | -) -> MapRpcChannel: |
| 197 | +) -> V1RpcChannel: |
257 | 198 | """Create a V1 RPC channel that fetches map data. |
258 | 199 |
|
259 | 200 | This will prefer local channels when available, falling back to MQTT |
260 | 201 | channels if not. If neither is available, an exception will be raised |
261 | 202 | when trying to send a command. |
262 | 203 | """ |
263 | | - publisher = RpcPublisher( |
| 204 | + return PayloadEncodedV1RpcChannel( |
264 | 205 | "map", |
265 | 206 | mqtt_channel, |
266 | 207 | lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data), |
| 208 | + create_map_response_decoder(security_data=security_data), |
267 | 209 | ) |
268 | | - return MapRpcChannel(publisher, security_data) |
0 commit comments