Skip to content

Commit 5f17c2b

Browse files
committed
Simplify future usage within the api clients
1 parent 6777dd7 commit 5f17c2b

File tree

10 files changed

+86
-70
lines changed

10 files changed

+86
-70
lines changed

roborock/api.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import secrets
99
import time
10-
from collections.abc import Callable, Coroutine
10+
from collections.abc import Callable
1111
from typing import Any
1212

1313
from .containers import (
@@ -16,7 +16,6 @@
1616
from .exceptions import (
1717
RoborockTimeout,
1818
UnknownMethodError,
19-
VacuumError,
2019
)
2120
from .roborock_future import RoborockFuture
2221
from .roborock_message import (
@@ -98,20 +97,18 @@ async def validate_connection(self) -> None:
9897
await self.async_disconnect()
9998
await self.async_connect()
10099

101-
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> tuple[Any, VacuumError | None]:
100+
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
102101
try:
103-
(response, err) = await queue.async_get(self.queue_timeout)
102+
response = await queue.async_get(self.queue_timeout)
104103
if response == "unknown_method":
105104
raise UnknownMethodError("Unknown method")
106-
return response, err
105+
return response
107106
except (asyncio.TimeoutError, asyncio.CancelledError):
108107
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
109108
finally:
110109
self._waiting_queue.pop(request_id, None)
111110

112-
def _async_response(
113-
self, request_id: int, protocol_id: int = 0
114-
) -> Coroutine[Any, Any, tuple[Any, VacuumError | None]]:
111+
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
115112
queue = RoborockFuture(protocol_id)
116113
if request_id in self._waiting_queue:
117114
new_id = get_next_int(10000, 32767)
@@ -121,7 +118,7 @@ def _async_response(
121118
)
122119
request_id = new_id
123120
self._waiting_queue[request_id] = queue
124-
return self._wait_response(request_id, queue)
121+
return asyncio.ensure_future(self._wait_response(request_id, queue))
125122

126123
async def send_message(self, roborock_message: RoborockMessage):
127124
raise NotImplementedError

roborock/cloud_api.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import base64
54
import logging
65
import threading
76
import typing
87
import uuid
9-
from asyncio import Lock, Task
8+
from asyncio import Lock
109
from typing import Any
1110
from urllib.parse import urlparse
1211

@@ -65,7 +64,7 @@ def on_connect(self, *args, **kwargs):
6564
message = f"Failed to connect ({mqtt.error_string(rc)})"
6665
self._logger.error(message)
6766
if connection_queue:
68-
connection_queue.resolve((None, VacuumError(message)))
67+
connection_queue.set_exception(VacuumError(message))
6968
return
7069
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
7170
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
@@ -74,11 +73,11 @@ def on_connect(self, *args, **kwargs):
7473
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
7574
self._logger.error(message)
7675
if connection_queue:
77-
connection_queue.resolve((None, VacuumError(message)))
76+
connection_queue.set_exception(VacuumError(message))
7877
return
7978
self._logger.info(f"Subscribed to topic {topic}")
8079
if connection_queue:
81-
connection_queue.resolve((True, None))
80+
connection_queue.set_result(True)
8281

8382
def on_message(self, *args, **kwargs):
8483
client, __, msg = args
@@ -97,7 +96,7 @@ def on_disconnect(self, *args, **kwargs):
9796
self.update_client_id()
9897
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
9998
if connection_queue:
100-
connection_queue.resolve((True, None))
99+
connection_queue.set_result(True)
101100
except Exception as ex:
102101
self._logger.exception(ex)
103102

@@ -115,53 +114,53 @@ def sync_start_loop(self) -> None:
115114
self._logger.info("Starting mqtt loop")
116115
super().loop_start()
117116

118-
def sync_disconnect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
117+
def sync_disconnect(self) -> Any:
119118
if not self.is_connected():
120-
return False, None
119+
return None
121120

122121
self._logger.info("Disconnecting from mqtt")
123-
disconnected_future = asyncio.ensure_future(self._async_response(DISCONNECT_REQUEST_ID))
122+
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
124123
rc = super().disconnect()
125124

126125
if rc == mqtt.MQTT_ERR_NO_CONN:
127126
disconnected_future.cancel()
128-
return False, None
127+
return None
129128

130129
if rc != mqtt.MQTT_ERR_SUCCESS:
131130
disconnected_future.cancel()
132131
raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})")
133132

134-
return True, disconnected_future
133+
return disconnected_future
135134

136-
def sync_connect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]:
135+
def sync_connect(self) -> Any:
137136
if self.is_connected():
138137
self.sync_start_loop()
139-
return False, None
138+
return None
140139

141140
if self._mqtt_port is None or self._mqtt_host is None:
142141
raise RoborockException("Mqtt information was not entered. Cannot connect.")
143142

144143
self._logger.debug("Connecting to mqtt")
145-
connected_future = asyncio.ensure_future(self._async_response(CONNECT_REQUEST_ID))
144+
connected_future = self._async_response(CONNECT_REQUEST_ID)
146145
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
147146

148147
self.sync_start_loop()
149-
return True, connected_future
148+
return connected_future
150149

151150
async def async_disconnect(self) -> None:
152151
async with self._mutex:
153-
(disconnecting, disconnected_future) = self.sync_disconnect()
154-
if disconnecting and disconnected_future:
155-
(_, err) = await disconnected_future
156-
if err:
152+
if disconnected_future := self.sync_disconnect():
153+
try:
154+
await disconnected_future
155+
except VacuumError as err:
157156
raise RoborockException(err) from err
158157

159158
async def async_connect(self) -> None:
160159
async with self._mutex:
161-
(connecting, connected_future) = self.sync_connect()
162-
if connecting and connected_future:
163-
(_, err) = await connected_future
164-
if err:
160+
if connected_future := self.sync_connect():
161+
try:
162+
await connected_future
163+
except VacuumError as err:
165164
raise RoborockException(err) from err
166165

167166
def _send_msg_raw(self, msg: bytes) -> None:

roborock/roborock_future.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ def __init__(self, protocol: int):
1414
self.fut: Future = Future()
1515
self.loop = self.fut.get_loop()
1616

17-
def _resolve(self, item: tuple[Any, VacuumError | None]) -> None:
17+
def _set_result(self, item: Any) -> None:
1818
if not self.fut.cancelled():
1919
self.fut.set_result(item)
2020

21-
def resolve(self, item: tuple[Any, VacuumError | None]) -> None:
22-
self.loop.call_soon_threadsafe(self._resolve, item)
21+
def set_result(self, item: Any) -> None:
22+
self.loop.call_soon_threadsafe(self._set_result, item)
23+
24+
def _set_exception(self, exc: VacuumError) -> None:
25+
if not self.fut.cancelled():
26+
self.fut.set_exception(exc)
27+
28+
def set_exception(self, exc: VacuumError) -> None:
29+
self.loop.call_soon_threadsafe(self._set_exception, exc)
2330

2431
async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
2532
try:

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,20 +377,17 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
377377
if queue and queue.protocol == protocol:
378378
error = data_point_response.get("error")
379379
if error:
380-
queue.resolve(
381-
(
382-
None,
383-
VacuumError(
384-
error.get("code"),
385-
error.get("message"),
386-
),
387-
)
380+
queue.set_exception(
381+
VacuumError(
382+
error.get("code"),
383+
error.get("message"),
384+
),
388385
)
389386
else:
390387
result = data_point_response.get("result")
391388
if isinstance(result, list) and len(result) == 1:
392389
result = result[0]
393-
queue.resolve((result, None))
390+
queue.set_result(result)
394391
else:
395392
try:
396393
data_protocol = RoborockDataProtocol(int(data_point_number))
@@ -442,11 +439,11 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
442439
if queue:
443440
if isinstance(decompressed, list):
444441
decompressed = decompressed[0]
445-
queue.resolve((decompressed, None))
442+
queue.set_result(decompressed)
446443
else:
447444
queue = self._waiting_queue.get(data.seq)
448445
if queue:
449-
queue.resolve((data.payload, None))
446+
queue.set_result(data.payload)
450447
except Exception as ex:
451448
self._logger.exception(ex)
452449

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import asyncio
2-
31
from roborock.local_api import RoborockLocalClient
42

53
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
4+
from ..exceptions import VacuumError
65
from ..protocol import MessageParser
76
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
87
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -52,16 +51,21 @@ async def send_message(self, roborock_message: RoborockMessage):
5251
if method:
5352
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
5453
# Send the command to the Roborock device
55-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
54+
async_response = self._async_response(request_id, response_protocol)
5655
self._send_msg_raw(msg)
57-
(response, err) = await async_response
58-
self._diagnostic_data[method if method is not None else "unknown"] = {
56+
diagnostic_key = method if method is not None else "unknown"
57+
try:
58+
response = await async_response
59+
except VacuumError as err:
60+
self._diagnostic_data[diagnostic_key] = {
61+
"params": roborock_message.get_params(),
62+
"error": err,
63+
}
64+
raise CommandVacuumError(method, err) from err
65+
self._diagnostic_data[diagnostic_key] = {
5966
"params": roborock_message.get_params(),
6067
"response": response,
61-
"error": err,
6268
}
63-
if err:
64-
raise CommandVacuumError(method, err) from err
6569
if roborock_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST:
6670
self._logger.debug(f"id={request_id} Response from method {roborock_message.get_method()}: {response}")
6771
if response == "retry":

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import base64
32

43
import paho.mqtt.client as mqtt
@@ -10,7 +9,7 @@
109
from roborock.cloud_api import RoborockMqttClient
1110

1211
from ..containers import DeviceData, UserData
13-
from ..exceptions import CommandVacuumError, RoborockException
12+
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1413
from ..protocol import MessageParser, Utils
1514
from ..roborock_message import (
1615
RoborockMessage,
@@ -49,16 +48,21 @@ async def send_message(self, roborock_message: RoborockMessage):
4948
local_key = self.device_info.device.local_key
5049
msg = MessageParser.build(roborock_message, local_key, False)
5150
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
52-
async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol))
51+
async_response = self._async_response(request_id, response_protocol)
5352
self._send_msg_raw(msg)
54-
(response, err) = await async_response
55-
self._diagnostic_data[method if method is not None else "unknown"] = {
53+
diagnostic_key = method if method is not None else "unknown"
54+
try:
55+
response = await async_response
56+
except VacuumError as err:
57+
self._diagnostic_data[diagnostic_key] = {
58+
"params": roborock_message.get_params(),
59+
"error": err,
60+
}
61+
raise CommandVacuumError(method, err) from err
62+
self._diagnostic_data[diagnostic_key] = {
5663
"params": roborock_message.get_params(),
5764
"response": response,
58-
"error": err,
5965
}
60-
if err:
61-
raise CommandVacuumError(method, err) from err
6266
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
6367
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
6468
else:

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
135135
converted_response = entries[data_point_protocol].post_process_fn(data_point)
136136
queue = self._waiting_queue.get(int(data_point_number))
137137
if queue and queue.protocol == protocol:
138-
queue.resolve((converted_response, None))
138+
queue.set_result(converted_response)
139139

140140
async def update_values(
141141
self, dyad_data_protocols: list[RoborockDyadDataProtocol | RoborockZeoProtocol]

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def send_message(self, roborock_message: RoborockMessage):
4343
futures = []
4444
if "10000" in payload["dps"]:
4545
for dps in json.loads(payload["dps"]["10000"]):
46-
futures.append(asyncio.ensure_future(self._async_response(dps, response_protocol)))
46+
futures.append(self._async_response(dps, response_protocol))
4747
self._send_msg_raw(m)
4848
responses = await asyncio.gather(*futures, return_exceptions=True)
4949
dps_responses: dict[int, typing.Any] = {}
@@ -54,7 +54,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5454
self._logger.warning("Timed out get req for %s after %s s", dps, self.queue_timeout)
5555
dps_responses[dps] = None
5656
else:
57-
dps_responses[dps] = response[0]
57+
dps_responses[dps] = response
5858
return dps_responses
5959

6060
async def update_values(

tests/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def test_can_create_mqtt_roborock():
3434
async def test_sync_connect(mqtt_client):
3535
with patch("paho.mqtt.client.Client.connect", return_value=mqtt.MQTT_ERR_SUCCESS):
3636
with patch("paho.mqtt.client.Client.loop_start", return_value=mqtt.MQTT_ERR_SUCCESS):
37-
connecting, connected_future = mqtt_client.sync_connect()
38-
assert connecting is True
37+
connected_future = mqtt_client.sync_connect()
3938
assert connected_future is not None
4039

4140
connected_future.cancel()

tests/test_queue.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from roborock.exceptions import VacuumError
56
from roborock.roborock_future import RoborockFuture
67

78

@@ -10,10 +11,18 @@ def test_can_create():
1011

1112

1213
@pytest.mark.asyncio
13-
async def test_put():
14+
async def test_set_result():
1415
rq = RoborockFuture(1)
15-
rq.resolve(("test", None))
16-
assert await rq.async_get(1) == ("test", None)
16+
rq.set_result("test")
17+
assert await rq.async_get(1) == "test"
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_set_exception():
22+
rq = RoborockFuture(1)
23+
rq.set_exception(VacuumError("test"))
24+
with pytest.raises(VacuumError):
25+
assert await rq.async_get(1)
1726

1827

1928
@pytest.mark.asyncio

0 commit comments

Comments
 (0)