Skip to content

Commit 4a0b709

Browse files
feat: adding type cast for send_command
1 parent 56cba8c commit 4a0b709

File tree

5 files changed

+66
-72
lines changed

5 files changed

+66
-72
lines changed

roborock/api.py

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import struct
1414
import time
1515
from random import randint
16-
from typing import Any, Callable, Coroutine, Optional, Type
16+
from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
1717

1818
import aiohttp
1919

@@ -30,6 +30,7 @@
3030
ModelStatus,
3131
MultiMapsList,
3232
NetworkInfo,
33+
RoborockBase,
3334
RoomMapping,
3435
S7MaxVStatus,
3536
SmartWashParams,
@@ -63,6 +64,7 @@
6364
RoborockCommand.GET_MAP_V1,
6465
RoborockCommand.GET_MULTI_MAP,
6566
]
67+
RT = TypeVar("RT", bound=RoborockBase)
6668

6769

6870
def md5hex(message: str) -> str:
@@ -228,36 +230,32 @@ def _get_payload(
228230
)
229231
return request_id, timestamp, payload
230232

231-
async def send_command(self, method: RoborockCommand, params: Optional[list | dict] = None):
233+
async def send_command(
234+
self,
235+
method: RoborockCommand,
236+
params: Optional[list | dict] = None,
237+
return_type: Optional[Type[RT]] = None,
238+
) -> RT:
232239
raise NotImplementedError
233240

234241
@fallback_cache
235242
async def get_status(self) -> Status | None:
236-
status = await self.send_command(RoborockCommand.GET_STATUS)
237-
if isinstance(status, dict):
238-
_cls: Type[Status] = ModelStatus.get(
239-
self.device_info.model, S7MaxVStatus
240-
) # Default to S7 MAXV if we don't have the data
241-
return _cls.from_dict(status)
242-
return None
243+
_cls: Type[Status] = ModelStatus.get(
244+
self.device_info.model, S7MaxVStatus
245+
) # Default to S7 MAXV if we don't have the data
246+
return await self.send_command(RoborockCommand.GET_STATUS, return_type=_cls)
243247

244248
@fallback_cache
245249
async def get_dnd_timer(self) -> DnDTimer | None:
246-
dnd_timer = await self.send_command(RoborockCommand.GET_DND_TIMER)
247-
if isinstance(dnd_timer, dict):
248-
return DnDTimer.from_dict(dnd_timer)
249-
return None
250+
return await self.send_command(RoborockCommand.GET_DND_TIMER, return_type=DnDTimer)
250251

251252
@fallback_cache
252253
async def get_valley_electricity_timer(self) -> ValleyElectricityTimer | None:
253-
valley_electricity_timer = await self.send_command(RoborockCommand.GET_VALLEY_ELECTRICITY_TIMER)
254-
if isinstance(valley_electricity_timer, dict):
255-
return ValleyElectricityTimer.from_dict(valley_electricity_timer)
256-
return None
254+
return await self.send_command(RoborockCommand.GET_VALLEY_ELECTRICITY_TIMER, return_type=ValleyElectricityTimer)
257255

258256
@fallback_cache
259257
async def get_clean_summary(self) -> CleanSummary | None:
260-
clean_summary = await self.send_command(RoborockCommand.GET_CLEAN_SUMMARY)
258+
clean_summary: dict | list | int = await self.send_command(RoborockCommand.GET_CLEAN_SUMMARY)
261259
if isinstance(clean_summary, dict):
262260
return CleanSummary.from_dict(clean_summary)
263261
elif isinstance(clean_summary, list):
@@ -274,38 +272,23 @@ async def get_clean_summary(self) -> CleanSummary | None:
274272

275273
@fallback_cache
276274
async def get_clean_record(self, record_id: int) -> CleanRecord | None:
277-
clean_record = await self.send_command(RoborockCommand.GET_CLEAN_RECORD, [record_id])
278-
if isinstance(clean_record, dict):
279-
return CleanRecord.from_dict(clean_record)
280-
return None
275+
return await self.send_command(RoborockCommand.GET_CLEAN_RECORD, [record_id], return_type=CleanRecord)
281276

282277
@fallback_cache
283278
async def get_consumable(self) -> Consumable | None:
284-
consumable = await self.send_command(RoborockCommand.GET_CONSUMABLE)
285-
if isinstance(consumable, dict):
286-
return Consumable.from_dict(consumable)
287-
return None
279+
return await self.send_command(RoborockCommand.GET_CONSUMABLE, return_type=Consumable)
288280

289281
@fallback_cache
290282
async def get_wash_towel_mode(self) -> WashTowelMode | None:
291-
washing_mode = await self.send_command(RoborockCommand.GET_WASH_TOWEL_MODE)
292-
if isinstance(washing_mode, dict):
293-
return WashTowelMode.from_dict(washing_mode)
294-
return None
283+
return await self.send_command(RoborockCommand.GET_WASH_TOWEL_MODE, return_type=WashTowelMode)
295284

296285
@fallback_cache
297286
async def get_dust_collection_mode(self) -> DustCollectionMode | None:
298-
dust_collection = await self.send_command(RoborockCommand.GET_DUST_COLLECTION_MODE)
299-
if isinstance(dust_collection, dict):
300-
return DustCollectionMode.from_dict(dust_collection)
301-
return None
287+
return await self.send_command(RoborockCommand.GET_DUST_COLLECTION_MODE, return_type=DustCollectionMode)
302288

303289
@fallback_cache
304290
async def get_smart_wash_params(self) -> SmartWashParams | None:
305-
mop_wash_mode = await self.send_command(RoborockCommand.GET_SMART_WASH_PARAMS)
306-
if isinstance(mop_wash_mode, dict):
307-
return SmartWashParams.from_dict(mop_wash_mode)
308-
return None
291+
return await self.send_command(RoborockCommand.GET_SMART_WASH_PARAMS, return_type=SmartWashParams)
309292

310293
@fallback_cache
311294
async def get_dock_summary(self, dock_type: RoborockDockTypeCode) -> DockSummary | None:
@@ -359,22 +342,16 @@ async def get_prop(self) -> DeviceProp | None:
359342

360343
@fallback_cache
361344
async def get_multi_maps_list(self) -> MultiMapsList | None:
362-
multi_maps_list = await self.send_command(RoborockCommand.GET_MULTI_MAPS_LIST)
363-
if isinstance(multi_maps_list, dict):
364-
return MultiMapsList.from_dict(multi_maps_list)
365-
return None
345+
return await self.send_command(RoborockCommand.GET_MULTI_MAPS_LIST, return_type=MultiMapsList)
366346

367347
@fallback_cache
368348
async def get_networking(self) -> NetworkInfo | None:
369-
networking_info = await self.send_command(RoborockCommand.GET_NETWORK_INFO)
370-
if isinstance(networking_info, dict):
371-
return NetworkInfo.from_dict(networking_info)
372-
return None
349+
return await self.send_command(RoborockCommand.GET_NETWORK_INFO, return_type=NetworkInfo)
373350

374351
@fallback_cache
375352
async def get_room_mapping(self) -> list[RoomMapping] | None:
376353
"""Gets the mapping from segment id -> iot id. Only works on local api."""
377-
mapping = await self.send_command(RoborockCommand.GET_ROOM_MAPPING)
354+
mapping: list = await self.send_command(RoborockCommand.GET_ROOM_MAPPING)
378355
if isinstance(mapping, list):
379356
return [
380357
RoomMapping(segment_id=segment_id, iot_id=iot_id) # type: ignore
@@ -385,10 +362,7 @@ async def get_room_mapping(self) -> list[RoomMapping] | None:
385362
@fallback_cache
386363
async def get_child_lock_status(self) -> ChildLockStatus | None:
387364
"""Gets current child lock status."""
388-
child_lock_status = await self.send_command(RoborockCommand.GET_CHILD_LOCK_STATUS)
389-
if isinstance(child_lock_status, dict):
390-
return ChildLockStatus.from_dict(child_lock_status)
391-
return None
365+
return await self.send_command(RoborockCommand.GET_CHILD_LOCK_STATUS, return_type=ChildLockStatus)
392366

393367
@fallback_cache
394368
async def get_sound_volume(self) -> int | None:

roborock/cloud_api.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import threading
66
import uuid
77
from asyncio import Lock
8-
from typing import Optional
8+
from typing import Optional, Type
99
from urllib.parse import urlparse
1010

1111
import paho.mqtt.client as mqtt
1212

13-
from .api import COMMANDS_SECURED, KEEPALIVE, RoborockClient, md5hex
13+
from .api import COMMANDS_SECURED, KEEPALIVE, RT, RoborockClient, md5hex
1414
from .containers import DeviceData, UserData
1515
from .exceptions import CommandVacuumError, RoborockException, VacuumError
1616
from .protocol import MessageParser, Utils
@@ -149,7 +149,12 @@ def _send_msg_raw(self, msg: bytes) -> None:
149149
if info.rc != mqtt.MQTT_ERR_SUCCESS:
150150
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")
151151

152-
async def send_command(self, method: RoborockCommand, params: Optional[list | dict] = None):
152+
async def send_command(
153+
self,
154+
method: RoborockCommand,
155+
params: Optional[list | dict] = None,
156+
return_type: Optional[Type[RT]] = None,
157+
):
153158
await self.validate_connection()
154159
request_id, timestamp, payload = super()._get_payload(method, params, True)
155160
_LOGGER.debug(f"id={request_id} Requesting method {method} with {params}")
@@ -166,11 +171,9 @@ async def send_command(self, method: RoborockCommand, params: Optional[list | di
166171
_LOGGER.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
167172
else:
168173
_LOGGER.debug(f"id={request_id} Response from {method}: {response}")
174+
if return_type:
175+
return return_type.from_dict(response)
169176
return response
170177

171178
async def get_map_v1(self):
172-
try:
173-
return await self.send_command(RoborockCommand.GET_MAP_V1)
174-
except RoborockException as e:
175-
_LOGGER.error(e)
176-
return None
179+
return await self.send_command(RoborockCommand.GET_MAP_V1)

roborock/containers.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ class RoborockBase:
8080

8181
@classmethod
8282
def from_dict(cls, data: dict[str, Any]):
83-
ignore_keys = cls._ignore_keys
84-
return from_dict(cls, decamelize_obj(data, ignore_keys), config=Config(cast=[Enum]))
83+
if isinstance(data, dict):
84+
ignore_keys = cls._ignore_keys
85+
return from_dict(cls, decamelize_obj(data, ignore_keys), config=Config(cast=[Enum]))
8586

8687
def as_dict(self) -> dict:
8788
return asdict(
@@ -348,7 +349,9 @@ class DnDTimer(RoborockBase):
348349

349350
def __post_init__(self) -> None:
350351
self.start_time = (
351-
time(hour=self.start_hour, minute=self.start_minute) if self.start_hour and self.start_minute else None
352+
time(hour=self.start_hour, minute=self.start_minute)
353+
if self.start_hour is not None and self.start_minute is not None
354+
else None
352355
)
353356
self.end_time = (
354357
time(hour=self.end_hour, minute=self.end_minute)
@@ -369,7 +372,9 @@ class ValleyElectricityTimer(RoborockBase):
369372

370373
def __post_init__(self) -> None:
371374
self.start_time = (
372-
time(hour=self.start_hour, minute=self.start_minute) if self.start_hour and self.start_minute else None
375+
time(hour=self.start_hour, minute=self.start_minute)
376+
if self.start_hour is not None and self.start_minute is not None
377+
else None
373378
)
374379
self.end_time = (
375380
time(hour=self.end_hour, minute=self.end_minute)

roborock/local_api.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
import asyncio
44
import logging
55
from asyncio import Lock, TimerHandle, Transport
6-
from typing import Optional
6+
from typing import Optional, Type
77

88
import async_timeout
99

1010
from . import DeviceData
11-
from .api import COMMANDS_SECURED, QUEUE_TIMEOUT, RoborockClient
12-
from .exceptions import CommandVacuumError, RoborockConnectionException, RoborockException
11+
from .api import COMMANDS_SECURED, QUEUE_TIMEOUT, RT, RoborockClient
12+
from .exceptions import (
13+
CommandVacuumError,
14+
RoborockConnectionException,
15+
RoborockException,
16+
)
1317
from .protocol import MessageParser
1418
from .roborock_message import RoborockMessage, RoborockMessageProtocol
1519
from .roborock_typing import RoborockCommand
@@ -123,9 +127,17 @@ async def ping(self):
123127
)
124128
)
125129

126-
async def send_command(self, method: RoborockCommand, params: Optional[list | dict] = None):
130+
async def send_command(
131+
self,
132+
method: RoborockCommand,
133+
params: Optional[list | dict] = None,
134+
return_type: Optional[Type[RT]] = None,
135+
):
127136
roborock_message = self.build_roborock_message(method, params)
128-
return (await self.send_message(roborock_message))[0]
137+
response = (await self.send_message(roborock_message))[0]
138+
if return_type:
139+
return return_type.from_dict(response)
140+
return response
129141

130142
async def async_local_response(self, roborock_message: RoborockMessage):
131143
method = roborock_message.get_method()

tests/test_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from roborock.api import PreparedRequest, RoborockApiClient
1414
from roborock.cloud_api import RoborockMqttClient
15-
from roborock.containers import DeviceData, S7MaxVStatus
15+
from roborock.containers import DeviceData, DustCollectionMode, S7MaxVStatus, SmartWashParams, WashTowelMode
1616
from tests.mock_data import BASE_URL_REQUEST, GET_CODE_RESPONSE, HOME_DATA_RAW, STATUS, USER_DATA
1717

1818

@@ -78,7 +78,7 @@ async def test_get_dust_collection_mode():
7878
device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model)
7979
rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info)
8080
with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command:
81-
command.return_value = {"mode": 1}
81+
command.return_value = DustCollectionMode.from_dict({"mode": 1})
8282
dust = await rmc.get_dust_collection_mode()
8383
assert dust is not None
8484
assert dust.mode == RoborockDockDustCollectionModeCode.light
@@ -90,7 +90,7 @@ async def test_get_mop_wash_mode():
9090
device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model)
9191
rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info)
9292
with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command:
93-
command.return_value = {"smart_wash": 0, "wash_interval": 1500}
93+
command.return_value = SmartWashParams.from_dict({"smart_wash": 0, "wash_interval": 1500})
9494
mop_wash = await rmc.get_smart_wash_params()
9595
assert mop_wash is not None
9696
assert mop_wash.smart_wash == 0
@@ -103,7 +103,7 @@ async def test_get_washing_mode():
103103
device_info = DeviceData(device=home_data.devices[0], model=home_data.products[0].model)
104104
rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info)
105105
with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command:
106-
command.return_value = {"wash_mode": 2}
106+
command.return_value = WashTowelMode.from_dict({"wash_mode": 2})
107107
washing_mode = await rmc.get_wash_towel_mode()
108108
assert washing_mode is not None
109109
assert washing_mode.wash_mode == RoborockDockWashTowelModeCode.deep

0 commit comments

Comments
 (0)