Skip to content

Commit 622d342

Browse files
committed
feat: Add diagnostics library for tracking stats/counters
Add a generic library that can be used to track counters or elapsed time (e.g. how long it takes on average to connect to MQTT). This is a copy of https://github.com/allenporter/python-google-nest-sdm/blob/main/google_nest_sdm/diagnostics.py. This is an initial pass that adds a few example metrics for MQTT; we can add more as we need fine-grained details in diagnostics.
1 parent cd7ef7c commit 622d342

File tree

7 files changed

+293
-18
lines changed

7 files changed

+293
-18
lines changed

roborock/devices/device_manager.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import asyncio
44
import enum
55
import logging
6-
from collections.abc import Callable
6+
from collections.abc import Callable, Mapping
77
from dataclasses import dataclass
8+
from typing import Any
89

910
import aiohttp
1011

@@ -16,6 +17,7 @@
1617
)
1718
from roborock.devices.device import DeviceReadyCallback, RoborockDevice
1819
from roborock.exceptions import RoborockException
20+
from roborock.diagnostics import Diagnostics
1921
from roborock.map.map_parser import MapParserConfig
2022
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
2123
from roborock.mqtt.session import MqttSession
@@ -58,6 +60,7 @@ def __init__(
5860
device_creator: DeviceCreator,
5961
mqtt_session: MqttSession,
6062
cache: Cache,
63+
diagnostics: Diagnostics,
6164
) -> None:
6265
"""Initialize the DeviceManager with user data and optional cache storage.
6366
@@ -68,12 +71,15 @@ def __init__(
6871
self._device_creator = device_creator
6972
self._devices: dict[str, RoborockDevice] = {}
7073
self._mqtt_session = mqtt_session
74+
self._diagnostics = diagnostics
7175

7276
async def discover_devices(self, prefer_cache: bool = True) -> list[RoborockDevice]:
7377
"""Discover all devices for the logged-in user."""
78+
self._diagnostics.increment("discover_devices")
7479
cache_data = await self._cache.get()
7580
if not cache_data.home_data or not prefer_cache:
7681
_LOGGER.debug("Fetching home data (prefer_cache=%s)", prefer_cache)
82+
self._diagnostics.increment("fetch_home_data")
7783
try:
7884
cache_data.home_data = await self._web_api.get_home_data()
7985
except RoborockException as ex:
@@ -116,6 +122,10 @@ async def close(self) -> None:
116122
tasks.append(self._mqtt_session.close())
117123
await asyncio.gather(*tasks)
118124

125+
def diagnostic_data(self) -> Mapping[str, Any]:
126+
"""Return diagnostics information about the device manager."""
127+
return self._diagnostics.as_dict()
128+
119129

120130
@dataclass
121131
class UserParams:
@@ -182,7 +192,10 @@ async def create_device_manager(
182192
web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
183193
user_data = user_params.user_data
184194

195+
diagnostics = Diagnostics()
196+
185197
mqtt_params = create_mqtt_params(user_data.rriot)
198+
mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
186199
mqtt_session = await create_lazy_mqtt_session(mqtt_params)
187200

188201
def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
@@ -226,6 +239,6 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
226239
dev.add_ready_callback(ready_callback)
227240
return dev
228241

229-
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache)
242+
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics)
230243
await manager.discover_devices()
231244
return manager

roborock/diagnostics.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Diagnostics for debugging.
2+
3+
A Diagnostics object can be used to track counts and latencies of various
4+
operations within a module. This can be useful for debugging performance issues
5+
or understanding usage patterns.
6+
7+
This is an internal facing module and is not intended for public use. Diagnostics
8+
data is collected and exposed to clients via higher level APIs like the
9+
DeviceManager.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import time
15+
from collections import Counter
16+
from collections.abc import Generator, Mapping
17+
from contextlib import contextmanager
18+
from typing import Any
19+
20+
21+
class Diagnostics:
    """Hold diagnostics information for a module.

    Use this class to keep named counters and to record timing information
    that can be exported as a dictionary for debugging purposes. Diagnostics
    objects can be nested via :meth:`subkey`; nested stats appear under their
    subkey when exported with :meth:`as_dict`.
    """

    def __init__(self) -> None:
        """Initialize Diagnostics with no counters and no subkeys."""
        self._counter: Counter[str] = Counter()
        self._subkeys: dict[str, Diagnostics] = {}

    def increment(self, key: str, count: int = 1) -> None:
        """Increment the counter for the specified key/event.

        Args:
            key: The name of the counter to update.
            count: The amount to add to the counter (defaults to 1).
        """
        # A missing key reads as 0 on a Counter, so this also creates the entry.
        self._counter[key] += count

    def elapsed(self, key_prefix: str, elapsed_ms: int = 1) -> None:
        """Track a latency event for the specified key/event prefix.

        Each call adds 1 to ``<key_prefix>_count`` and ``elapsed_ms`` to
        ``<key_prefix>_sum``, so the average latency is ``sum / count``.

        Args:
            key_prefix: The prefix used to build the two counter names.
            elapsed_ms: The elapsed time to add to the running sum.
        """
        self.increment(f"{key_prefix}_count", 1)
        self.increment(f"{key_prefix}_sum", elapsed_ms)

    def as_dict(self) -> Mapping[str, Any]:
        """Return diagnostics as a debug dictionary.

        Returns:
            A new dictionary mapping each counter name to its value, plus one
            nested dictionary per subkey. Subkeys whose own export is empty
            are omitted.
        """
        data: dict[str, Any] = dict(self._counter)
        for name, child in self._subkeys.items():
            child_data = child.as_dict()
            # Skip sub-diagnostics that have not recorded anything yet.
            if child_data:
                data[name] = child_data
        return data

    def subkey(self, key: str) -> "Diagnostics":
        """Return the sub-Diagnostics object for the specified subkey.

        This will create a new Diagnostics object if one does not already exist
        for the specified subkey. Stats from the sub-Diagnostics will be included
        in the parent Diagnostics when exported as a dictionary.

        Args:
            key: The subkey for the diagnostics.

        Returns:
            The Diagnostics object for the specified subkey.
        """
        if key not in self._subkeys:
            self._subkeys[key] = Diagnostics()
        return self._subkeys[key]

    @contextmanager
    def timer(self, key_prefix: str) -> Generator[None, None, None]:
        """Record the timing of the wrapped operation as a diagnostic.

        The elapsed time is recorded via :meth:`elapsed` when the ``with``
        body exits, whether it completes normally or raises.

        Args:
            key_prefix: The prefix passed to :meth:`elapsed`.
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            end = time.perf_counter()
            # Truncated to whole milliseconds; sub-millisecond work records 0.
            ms = int((end - start) * 1000)
            self.elapsed(key_prefix, ms)

    def reset(self) -> None:
        """Clear all diagnostics, for testing.

        Counters are cleared on this object and on every existing subkey.
        The subkey objects themselves are kept, so references handed out
        earlier by :meth:`subkey` remain attached to this object.
        """
        self._counter = Counter()
        for child in self._subkeys.values():
            child.reset()

roborock/mqtt/roborock_session.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from aiomqtt import MqttCodeError, MqttError, TLSParameters
1919

2020
from roborock.callbacks import CallbackMap
21+
from roborock.diagnostics import Diagnostics
2122

2223
from .health_manager import HealthManager
2324
from .session import MqttParams, MqttSession, MqttSessionException, MqttSessionUnauthorized
@@ -76,6 +77,7 @@ def __init__(
7677
self._connection_task: asyncio.Task[None] | None = None
7778
self._topic_idle_timeout = topic_idle_timeout
7879
self._idle_timers: dict[str, asyncio.Task[None]] = {}
80+
self._diagnostics = params.diagnostics
7981
self._health_manager = HealthManager(self.restart)
8082

8183
@property
@@ -96,24 +98,30 @@ async def start(self) -> None:
9698
handle the failure and retry if desired itself. Once connected,
9799
the session will retry connecting in the background.
98100
"""
101+
self._diagnostics.increment("start_attempt")
99102
start_future: asyncio.Future[None] = asyncio.Future()
100103
loop = asyncio.get_event_loop()
101104
self._reconnect_task = loop.create_task(self._run_reconnect_loop(start_future))
102105
try:
103106
await start_future
104107
except MqttCodeError as err:
108+
self._diagnostics.increment(f"start_failure:{err.rc}")
105109
if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED:
106110
raise MqttSessionUnauthorized(f"Authorization error starting MQTT session: {err}") from err
107111
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
108112
except MqttError as err:
113+
self._diagnostics.increment("start_failure:unknown")
109114
raise MqttSessionException(f"Error starting MQTT session: {err}") from err
110115
except Exception as err:
116+
self._diagnostics.increment("start_failure:uncaught")
111117
raise MqttSessionException(f"Unexpected error starting session: {err}") from err
112118
else:
119+
self._diagnostics.increment("start_success")
113120
_LOGGER.debug("MQTT session started successfully")
114121

115122
async def close(self) -> None:
116123
"""Cancels the MQTT loop and shutdown the client library."""
124+
self._diagnostics.increment("close")
117125
self._stop = True
118126
tasks = [task for task in [self._connection_task, self._reconnect_task, *self._idle_timers.values()] if task]
119127
self._connection_task = None
@@ -136,6 +144,7 @@ async def restart(self) -> None:
136144
the reconnect loop. This is a no-op if there is no active connection.
137145
"""
138146
_LOGGER.info("Forcing MQTT session restart")
147+
self._diagnostics.increment("restart")
139148
if self._connection_task:
140149
self._connection_task.cancel()
141150
else:
@@ -144,6 +153,7 @@ async def restart(self) -> None:
144153
async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -> None:
145154
"""Run the MQTT loop."""
146155
_LOGGER.info("Starting MQTT session")
156+
self._diagnostics.increment("start_loop")
147157
while True:
148158
try:
149159
self._connection_task = asyncio.create_task(self._run_connection(start_future))
@@ -164,6 +174,7 @@ async def _run_reconnect_loop(self, start_future: asyncio.Future[None] | None) -
164174
_LOGGER.debug("MQTT session closed, stopping retry loop")
165175
return
166176
_LOGGER.info("MQTT session disconnected, retrying in %s seconds", self._backoff.total_seconds())
177+
self._diagnostics.increment("reconnect_wait")
167178
await asyncio.sleep(self._backoff.total_seconds())
168179
self._backoff = min(self._backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
169180

@@ -175,17 +186,19 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No
175186
is lost, this method will exit.
176187
"""
177188
try:
178-
async with self._mqtt_client(self._params) as client:
179-
self._backoff = MIN_BACKOFF_INTERVAL
180-
self._healthy = True
181-
_LOGGER.info("MQTT Session connected.")
182-
if start_future and not start_future.done():
183-
start_future.set_result(None)
184-
185-
_LOGGER.debug("Processing MQTT messages")
186-
async for message in client.messages:
187-
_LOGGER.debug("Received message: %s", message)
188-
self._listeners(message.topic.value, message.payload)
189+
with self._diagnostics.timer("connection"):
190+
async with self._mqtt_client(self._params) as client:
191+
self._backoff = MIN_BACKOFF_INTERVAL
192+
self._healthy = True
193+
_LOGGER.info("MQTT Session connected.")
194+
if start_future and not start_future.done():
195+
start_future.set_result(None)
196+
197+
_LOGGER.debug("Processing MQTT messages")
198+
async for message in client.messages:
199+
_LOGGER.debug("Received message: %s", message)
200+
with self._diagnostics.timer("dispatch_message"):
201+
self._listeners(message.topic.value, message.payload)
189202
except MqttError as err:
190203
if start_future and not start_future.done():
191204
_LOGGER.info("MQTT error starting session: %s", err)
@@ -251,6 +264,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
251264

252265
# If there is an idle timer for this topic, cancel it (reuse subscription)
253266
if idle_timer := self._idle_timers.pop(topic, None):
267+
self._diagnostics.increment("unsubscribe_idle_cancel")
254268
idle_timer.cancel()
255269
_LOGGER.debug("Cancelled idle timer for topic %s (reused subscription)", topic)
256270

@@ -262,13 +276,15 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
262276
if self._client:
263277
_LOGGER.debug("Establishing subscription to topic %s", topic)
264278
try:
265-
await self._client.subscribe(topic)
279+
with self._diagnostics.timer("subscribe"):
280+
await self._client.subscribe(topic)
266281
except MqttError as err:
267282
# Clean up the callback if subscription fails
268283
unsub()
269284
self._client_subscribed_topics.discard(topic)
270285
raise MqttSessionException(f"Error subscribing to topic: {err}") from err
271286
else:
287+
self._diagnostics.increment("subscribe_pending")
272288
_LOGGER.debug("Client not connected, will establish subscription later")
273289

274290
def schedule_unsubscribe() -> None:
@@ -301,10 +317,10 @@ async def idle_unsubscribe():
301317
self._idle_timers[topic] = task
302318

303319
def delayed_unsub():
320+
self._diagnostics.increment("unsubscribe")
304321
unsub() # Remove the callback from CallbackMap
305322
# If no more callbacks for this topic, start idle timer
306323
if not self._listeners.get_callbacks(topic):
307-
_LOGGER.debug("Unsubscribing topic %s, starting idle timer", topic)
308324
schedule_unsubscribe()
309325
else:
310326
_LOGGER.debug("Unsubscribing topic %s, still have active callbacks", topic)
@@ -320,7 +336,8 @@ async def publish(self, topic: str, message: bytes) -> None:
320336
raise MqttSessionException("Could not publish message, MQTT client not connected")
321337
client = self._client
322338
try:
323-
await client.publish(topic, message)
339+
with self._diagnostics.timer("publish"):
340+
await client.publish(topic, message)
324341
except MqttError as err:
325342
raise MqttSessionException(f"Error publishing message: {err}") from err
326343

@@ -333,11 +350,12 @@ class LazyMqttSession(MqttSession):
333350
is made.
334351
"""
335352

336-
def __init__(self, session: RoborockMqttSession) -> None:
353+
def __init__(self, session: RoborockMqttSession, diagnostics: Diagnostics) -> None:
337354
"""Initialize the lazy session with an existing session."""
338355
self._lock = asyncio.Lock()
339356
self._started = False
340357
self._session = session
358+
self._diagnostics = diagnostics
341359

342360
@property
343361
def connected(self) -> bool:
@@ -353,6 +371,7 @@ async def _maybe_start(self) -> None:
353371
"""Start the MQTT session if not already started."""
354372
async with self._lock:
355373
if not self._started:
374+
self._diagnostics.increment("start")
356375
await self._session.start()
357376
self._started = True
358377

@@ -403,4 +422,4 @@ async def create_lazy_mqtt_session(params: MqttParams) -> MqttSession:
403422
This function is a factory for creating an MQTT session that will
404423
only connect when the first attempt to subscribe or publish is made.
405424
"""
406-
return LazyMqttSession(RoborockMqttSession(params))
425+
return LazyMqttSession(RoborockMqttSession(params), diagnostics=params.diagnostics.subkey("lazy_mqtt"))

roborock/mqtt/session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Callable
55
from dataclasses import dataclass
66

7+
from roborock.diagnostics import Diagnostics
78
from roborock.exceptions import RoborockException
89
from roborock.mqtt.health_manager import HealthManager
910

@@ -32,6 +33,14 @@ class MqttParams:
3233
timeout: float = DEFAULT_TIMEOUT
3334
"""Timeout for communications with the broker in seconds."""
3435

36+
diagnostics: Diagnostics = Diagnostics()
37+
"""Diagnostics object for tracking MQTT session stats.
38+
39+
This defaults to a new Diagnostics object, but the common case is the
40+
caller will provide their own (e.g., from a DeviceManager) so that the
41+
shared MQTT session diagnostics are included in the overall diagnostics.
42+
"""
43+
3544

3645
class MqttSession(ABC):
3746
"""An MQTT session for sending and receiving messages."""

tests/devices/test_device_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,17 @@ async def test_start_connect_unexpected_error(home_data: HomeData, channel_failu
348348
"""Test that some unexpected errors from start_connect are propagated."""
349349
with pytest.raises(Exception, match="Unexpected error"):
350350
await create_device_manager(USER_PARAMS)
351+
352+
353+
async def test_diagnostics_collection(home_data: HomeData) -> None:
    """Test that diagnostics are collected correctly in the DeviceManager."""
    device_manager = await create_device_manager(USER_PARAMS)
    try:
        devices = await device_manager.get_devices()
        assert len(devices) == 1

        diagnostics = device_manager.diagnostic_data()
        assert diagnostics is not None
        assert diagnostics.get("discover_devices") == 1
        assert diagnostics.get("fetch_home_data") == 1
    finally:
        # Close the manager even when an assertion above fails, so a failing
        # test does not leave it open.
        await device_manager.close()

0 commit comments

Comments
 (0)