|
5 | 5 | import logging |
6 | 6 | import threading |
7 | 7 | import uuid |
8 | | -from asyncio import Lock |
9 | | -from typing import Optional |
| 8 | +from asyncio import Lock, Task |
| 9 | +from typing import Any, Optional |
10 | 10 | from urllib.parse import urlparse |
11 | 11 |
|
12 | 12 | import paho.mqtt.client as mqtt |
@@ -112,40 +112,52 @@ def sync_start_loop(self) -> None: |
112 | 112 | self._logger.info("Starting mqtt loop") |
113 | 113 | super().loop_start() |
114 | 114 |
|
115 | | - def sync_disconnect(self) -> bool: |
116 | | - rc = mqtt.MQTT_ERR_AGAIN |
| 115 | + def sync_disconnect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]: |
| 116 | + if not self.is_connected(): |
| 117 | + return False, None |
| 118 | + |
| 119 | + self._logger.info("Disconnecting from mqtt") |
| 120 | + disconnected_future = asyncio.ensure_future(self._async_response(DISCONNECT_REQUEST_ID)) |
| 121 | + rc = super().disconnect() |
| 122 | + |
| 123 | + if rc == mqtt.MQTT_ERR_NO_CONN: |
| 124 | + disconnected_future.cancel() |
| 125 | + return False, None |
| 126 | + |
| 127 | + if rc != mqtt.MQTT_ERR_SUCCESS: |
| 128 | + disconnected_future.cancel() |
| 129 | + raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})") |
| 130 | + |
| 131 | + return True, disconnected_future |
| 132 | + |
| 133 | + def sync_connect(self) -> tuple[bool, Task[tuple[Any, VacuumError | None]] | None]: |
117 | 134 | if self.is_connected(): |
118 | | - self._logger.info("Disconnecting from mqtt") |
119 | | - rc = super().disconnect() |
120 | | - if rc not in [mqtt.MQTT_ERR_SUCCESS, mqtt.MQTT_ERR_NO_CONN]: |
121 | | - raise RoborockException(f"Failed to disconnect ({mqtt.error_string(rc)})") |
122 | | - return rc == mqtt.MQTT_ERR_SUCCESS |
123 | | - |
124 | | - def sync_connect(self) -> bool: |
125 | | - should_connect = not self.is_connected() |
126 | | - if should_connect: |
127 | | - if self._mqtt_port is None or self._mqtt_host is None: |
128 | | - raise RoborockException("Mqtt information was not entered. Cannot connect.") |
129 | | - self._logger.info("Connecting to mqtt") |
130 | | - super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE) |
| 135 | + self.sync_start_loop() |
| 136 | + return False, None |
| 137 | + |
| 138 | + if self._mqtt_port is None or self._mqtt_host is None: |
| 139 | + raise RoborockException("Mqtt information was not entered. Cannot connect.") |
| 140 | + |
| 141 | + self._logger.info("Connecting to mqtt") |
| 142 | + connected_future = asyncio.ensure_future(self._async_response(CONNECT_REQUEST_ID)) |
| 143 | + super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE) |
| 144 | + |
131 | 145 | self.sync_start_loop() |
132 | | - return should_connect |
| 146 | + return True, connected_future |
133 | 147 |
|
134 | 148 | async def async_disconnect(self) -> None: |
135 | 149 | async with self._mutex: |
136 | | - async_response = asyncio.ensure_future(self._async_response(DISCONNECT_REQUEST_ID)) |
137 | | - disconnecting = self.sync_disconnect() |
138 | | - if disconnecting: |
139 | | - (_, err) = await async_response |
| 150 | + (disconnecting, disconnected_future) = self.sync_disconnect() |
| 151 | + if disconnecting and disconnected_future: |
| 152 | + (_, err) = await disconnected_future |
140 | 153 | if err: |
141 | 154 | raise RoborockException(err) from err |
142 | 155 |
|
143 | 156 | async def async_connect(self) -> None: |
144 | 157 | async with self._mutex: |
145 | | - async_response = asyncio.ensure_future(self._async_response(CONNECT_REQUEST_ID)) |
146 | | - connecting = self.sync_connect() |
147 | | - if connecting: |
148 | | - (_, err) = await async_response |
| 158 | + (connecting, connected_future) = self.sync_connect() |
| 159 | + if connecting and connected_future: |
| 160 | + (_, err) = await connected_future |
149 | 161 | if err: |
150 | 162 | raise RoborockException(err) from err |
151 | 163 |
|
|
0 commit comments