Skip to content

Commit 5add0da

Browse files
authored
chore: remove level of inheritance in mqtt client (#286)
1 parent 1f5a9ec commit 5add0da

File tree

1 file changed

+53
-29
lines changed

1 file changed

+53
-29
lines changed

roborock/cloud_api.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,44 @@
2121
DISCONNECT_REQUEST_ID = 1
2222

2323

24-
class RoborockMqttClient(RoborockClient, mqtt.Client, ABC):
24+
class _Mqtt(mqtt.Client):
25+
"""Internal MQTT client.
26+
27+
This is a subclass of the Paho MQTT client that adds some additional functionality
28+
for error cases where things get stuck.
29+
"""
30+
2531
_thread: threading.Thread
2632
_client_id: str
2733

34+
def __init__(self) -> None:
35+
"""Initialize the MQTT client."""
36+
super().__init__(protocol=mqtt.MQTTv5)
37+
self.reset_client_id()
38+
39+
def reset_client_id(self):
40+
"""Generate a new client id to make a new session when reconnecting."""
41+
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)
42+
43+
def maybe_restart_loop(self) -> None:
44+
"""Ensure that the MQTT loop is running in case it previously exited."""
45+
if not self._thread or not self._thread.is_alive():
46+
if self._thread:
47+
_LOGGER.info("Stopping mqtt loop")
48+
super().loop_stop()
49+
_LOGGER.info("Starting mqtt loop")
50+
super().loop_start()
51+
52+
53+
class RoborockMqttClient(RoborockClient, ABC):
54+
"""Roborock MQTT client base class."""
55+
2856
def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: int = 10) -> None:
57+
"""Initialize the Roborock MQTT client."""
2958
rriot = user_data.rriot
3059
if rriot is None:
3160
raise RoborockException("Got no rriot data from user_data")
3261
RoborockClient.__init__(self, device_info, queue_timeout)
33-
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
3462
self._mqtt_user = rriot.u
3563
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
3664
url = urlparse(rriot.r.m)
@@ -39,16 +67,21 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
3967
self._mqtt_host = str(url.hostname)
4068
self._mqtt_port = url.port
4169
self._mqtt_ssl = url.scheme == "ssl"
70+
71+
self._mqtt_client = _Mqtt()
72+
self._mqtt_client.on_connect = self._mqtt_on_connect
73+
self._mqtt_client.on_message = self._mqtt_on_message
74+
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
4275
if self._mqtt_ssl:
43-
super().tls_set()
76+
self._mqtt_client.tls_set()
77+
4478
self._mqtt_password = rriot.s
4579
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
46-
super().username_pw_set(self._hashed_user, self._hashed_password)
80+
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
4781
self._waiting_queue: dict[int, RoborockFuture] = {}
4882
self._mutex = Lock()
49-
self.update_client_id()
5083

51-
def on_connect(self, *args, **kwargs):
84+
def _mqtt_on_connect(self, *args, **kwargs):
5285
_, __, ___, rc, ____ = args
5386
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
5487
if rc != mqtt.MQTT_ERR_SUCCESS:
@@ -59,7 +92,7 @@ def on_connect(self, *args, **kwargs):
5992
return
6093
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
6194
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
62-
(result, mid) = self.subscribe(topic)
95+
(result, mid) = self._mqtt_client.subscribe(topic)
6396
if result != 0:
6497
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
6598
self._logger.error(message)
@@ -70,48 +103,38 @@ def on_connect(self, *args, **kwargs):
70103
if connection_queue:
71104
connection_queue.set_result(True)
72105

73-
def on_message(self, *args, **kwargs):
106+
def _mqtt_on_message(self, *args, **kwargs):
74107
client, __, msg = args
75108
try:
76109
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
77110
super().on_message_received(messages)
78111
except Exception as ex:
79112
self._logger.exception(ex)
80113

81-
def on_disconnect(self, *args, **kwargs):
114+
def _mqtt_on_disconnect(self, *args, **kwargs):
82115
_, __, rc, ___ = args
83116
try:
84117
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
85118
super().on_connection_lost(exc)
86119
if rc == mqtt.MQTT_ERR_PROTOCOL:
87-
self.update_client_id()
120+
self._mqtt_client.reset_client_id()
88121
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
89122
if connection_queue:
90123
connection_queue.set_result(True)
91124
except Exception as ex:
92125
self._logger.exception(ex)
93126

94-
def update_client_id(self):
95-
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)
96-
97-
def sync_stop_loop(self) -> None:
98-
if self._thread:
99-
self._logger.info("Stopping mqtt loop")
100-
super().loop_stop()
101-
102-
def sync_start_loop(self) -> None:
103-
if not self._thread or not self._thread.is_alive():
104-
self.sync_stop_loop()
105-
self._logger.info("Starting mqtt loop")
106-
super().loop_start()
127+
def is_connected(self) -> bool:
128+
"""Check if the mqtt client is connected."""
129+
return self._mqtt_client.is_connected()
107130

108131
def sync_disconnect(self) -> Any:
109132
if not self.is_connected():
110133
return None
111134

112135
self._logger.info("Disconnecting from mqtt")
113136
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
114-
rc = super().disconnect()
137+
rc = self._mqtt_client.disconnect()
115138

116139
if rc == mqtt.MQTT_ERR_NO_CONN:
117140
disconnected_future.cancel()
@@ -125,17 +148,16 @@ def sync_disconnect(self) -> Any:
125148

126149
def sync_connect(self) -> Any:
127150
if self.is_connected():
128-
self.sync_start_loop()
151+
self._mqtt_client.maybe_restart_loop()
129152
return None
130153

131154
if self._mqtt_port is None or self._mqtt_host is None:
132155
raise RoborockException("Mqtt information was not entered. Cannot connect.")
133156

134157
self._logger.debug("Connecting to mqtt")
135158
connected_future = self._async_response(CONNECT_REQUEST_ID)
136-
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
137-
138-
self.sync_start_loop()
159+
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
160+
self._mqtt_client.maybe_restart_loop()
139161
return connected_future
140162

141163
async def async_disconnect(self) -> None:
@@ -155,6 +177,8 @@ async def async_connect(self) -> None:
155177
raise RoborockException(err) from err
156178

157179
def _send_msg_raw(self, msg: bytes) -> None:
158-
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
180+
info = self._mqtt_client.publish(
181+
f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg
182+
)
159183
if info.rc != mqtt.MQTT_ERR_SUCCESS:
160184
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")

0 commit comments

Comments
 (0)