Skip to content

Commit 3f9cb37

Browse files
committed
chore: update container parsing using native typing and dataclass
1 parent 636268d commit 3f9cb37

File tree

2 files changed

+136
-79
lines changed

2 files changed

+136
-79
lines changed

roborock/containers.py

Lines changed: 54 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from __future__ import annotations
2-
1+
import dataclasses
32
import datetime
43
import json
54
import logging
65
import re
6+
import types
77
from dataclasses import asdict, dataclass, field
88
from datetime import timezone
99
from enum import Enum
1010
from functools import cached_property
11-
from typing import Any, NamedTuple, get_args, get_origin
11+
from typing import Any, NamedTuple, get_args, get_origin, Self
1212

1313
from .code_mappings import (
1414
SHORT_MODEL_TO_ENUM,
@@ -95,105 +95,73 @@
9595
_LOGGER = logging.getLogger(__name__)
9696

9797

98-
def camelize(s: str):
98+
def _camelize(s: str):
9999
first, *others = s.split("_")
100100
if len(others) == 0:
101101
return s
102102
return "".join([first.lower(), *map(str.title, others)])
103103

104104

105-
def decamelize(s: str):
105+
def _decamelize(s: str):
106106
return re.sub("([A-Z]+)", "_\\1", s).lower()
107107

108108

109-
def decamelize_obj(d: dict | list, ignore_keys: list[str]):
110-
if isinstance(d, RoborockBase):
111-
d = d.as_dict()
112-
if isinstance(d, list):
113-
return [decamelize_obj(i, ignore_keys) if isinstance(i, dict | list) else i for i in d]
114-
return {
115-
(decamelize(a) if a not in ignore_keys else a): decamelize_obj(b, ignore_keys)
116-
if isinstance(b, dict | list)
117-
else b
118-
for a, b in d.items()
119-
}
120-
121-
122109
@dataclass
123110
class RoborockBase:
124111
_ignore_keys = [] # type: ignore
125-
is_cached = False
126112

127113
@staticmethod
128-
def convert_to_class_obj(type, value):
129-
try:
130-
class_type = eval(type)
131-
if get_origin(class_type) is list:
132-
return_list = []
133-
cls_type = get_args(class_type)[0]
134-
for obj in value:
135-
if issubclass(cls_type, RoborockBase):
136-
return_list.append(cls_type.from_dict(obj))
137-
elif cls_type in {str, int, float}:
138-
return_list.append(cls_type(obj))
139-
else:
140-
return_list.append(cls_type(**obj))
141-
return return_list
142-
if issubclass(class_type, RoborockBase):
143-
converted_value = class_type.from_dict(value)
144-
else:
145-
converted_value = class_type(value)
146-
return converted_value
147-
except NameError as err:
148-
_LOGGER.exception(err)
149-
except ValueError as err:
150-
_LOGGER.exception(err)
151-
except Exception as err:
152-
_LOGGER.exception(err)
153-
raise Exception("Fail")
114+
def _convert_to_class_obj(class_type: type, value):
115+
if get_origin(class_type) is list:
116+
sub_type = get_args(class_type)[0]
117+
return [RoborockBase._convert_to_class_obj(sub_type, obj) for obj in value]
118+
if get_origin(class_type) is dict:
119+
_, value_type = get_args(class_type) # assume keys are only basic types
120+
return {k: RoborockBase._convert_to_class_obj(value_type, v) for k, v in value.items()}
121+
if issubclass(class_type, RoborockBase):
122+
return class_type.from_dict(value)
123+
if class_type is Any:
124+
return value
125+
return class_type(value) # type: ignore[call-arg]
154126

155127
@classmethod
156128
def from_dict(cls, data: dict[str, Any]):
157-
if isinstance(data, dict):
158-
ignore_keys = cls._ignore_keys
159-
data = decamelize_obj(data, ignore_keys)
160-
cls_annotations: dict[str, str] = {}
161-
for base in reversed(cls.__mro__):
162-
cls_annotations.update(getattr(base, "__annotations__", {}))
163-
remove_keys = []
164-
for key, value in data.items():
165-
if key not in cls_annotations:
166-
remove_keys.append(key)
167-
continue
168-
if value == "None" or value is None:
169-
data[key] = None
170-
continue
171-
field_type: str = cls_annotations[key]
172-
if "|" in field_type:
173-
# It's a union
174-
types = field_type.split("|")
175-
for type in types:
176-
if "None" in type or "Any" in type:
177-
continue
178-
try:
179-
data[key] = RoborockBase.convert_to_class_obj(type, value)
180-
break
181-
except Exception:
182-
...
183-
else:
129+
"""Create an instance of the class from a dictionary."""
130+
if not isinstance(data, dict):
131+
return None
132+
field_types = {field.name: field.type for field in dataclasses.fields(cls)}
133+
result: Self = {}
134+
for key, value in data.items():
135+
key = _decamelize(key)
136+
if (field_type := field_types.get(key)) is None:
137+
continue
138+
if value == "None" or value is None:
139+
result[key] = None
140+
continue
141+
if isinstance(field_type, types.UnionType):
142+
for subtype in get_args(field_type):
143+
if subtype is types.NoneType:
144+
continue
184145
try:
185-
data[key] = RoborockBase.convert_to_class_obj(field_type, value)
146+
result[key] = RoborockBase._convert_to_class_obj(subtype, value)
147+
break
186148
except Exception:
187-
...
188-
for key in remove_keys:
189-
del data[key]
190-
return cls(**data)
149+
_LOGGER.exception(f"Failed to convert {key} with value {value} to type {subtype}")
150+
continue
151+
else:
152+
try:
153+
result[key] = RoborockBase._convert_to_class_obj(field_type, value)
154+
except Exception:
155+
_LOGGER.exception(f"Failed to convert {key} with value {value} to type {field_type}")
156+
continue
157+
158+
return cls(**result)
191159

192160
def as_dict(self) -> dict:
193161
return asdict(
194162
self,
195163
dict_factory=lambda _fields: {
196-
camelize(key): value.value if isinstance(value, Enum) else value
164+
_camelize(key): value.value if isinstance(value, Enum) else value
197165
for (key, value) in _fields
198166
if value is not None
199167
},
@@ -891,3 +859,11 @@ class DyadSndState(RoborockBase):
891859
@dataclass
892860
class DyadOtaNfo(RoborockBase):
893861
mqttOtaData: dict
862+
863+
864+
@dataclass
865+
class SimpleObject(RoborockBase):
866+
"""Simple object for testing serialization."""
867+
868+
name: str | None = None
869+
value: int | None = None

tests/test_containers.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from roborock import CleanRecord, CleanSummary, Consumable, DnDTimer, HomeData, S7MaxVStatus, UserData
1+
"""Test cases for the containers module."""
2+
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
from roborock import CleanRecord, CleanSummary, Consumable, DnDTimer, HomeData, S7MaxVStatus, SimpleObject, UserData
27
from roborock.code_mappings import (
38
RoborockCategory,
49
RoborockDockErrorCode,
@@ -9,6 +14,7 @@
914
RoborockMopModeS7,
1015
RoborockStateCode,
1116
)
17+
from roborock.containers import RoborockBase
1218

1319
from .mock_data import (
1420
CLEAN_RECORD,
@@ -23,6 +29,80 @@
2329
)
2430

2531

32+
@dataclass
33+
class HomeDataRoom(RoborockBase):
34+
id: int
35+
name: str
36+
37+
38+
@dataclass
39+
class ComplexObject(RoborockBase):
40+
"""Complex object for testing serialization."""
41+
42+
simple: SimpleObject | None = None
43+
items: list[str] | None = None
44+
value: int | None = None
45+
nested_dict: dict[str, SimpleObject] | None = None
46+
nested_list: list[SimpleObject] | None = None
47+
any: Any | None = None
48+
49+
50+
def test_simple_object() -> None:
51+
"""Test serialization and deserialization of a simple object."""
52+
53+
obj = SimpleObject(name="Test", value=42)
54+
serialized = obj.as_dict()
55+
assert serialized == {"name": "Test", "value": 42}
56+
deserialized = SimpleObject.from_dict(serialized)
57+
assert deserialized.name == "Test"
58+
assert deserialized.value == 42
59+
60+
61+
def test_complex_object() -> None:
62+
"""Test serialization and deserialization of a complex object."""
63+
simple = SimpleObject(name="Nested", value=100)
64+
obj = ComplexObject(
65+
simple=simple,
66+
items=["item1", "item2"],
67+
value=200,
68+
nested_dict={
69+
"nested1": SimpleObject(name="Nested1", value=1),
70+
"nested2": SimpleObject(name="Nested2", value=2),
71+
},
72+
nested_list=[SimpleObject(name="Nested3", value=3), SimpleObject(name="Nested4", value=4)],
73+
any="This can be anything",
74+
)
75+
serialized = obj.as_dict()
76+
assert serialized == {
77+
"simple": {"name": "Nested", "value": 100},
78+
"items": ["item1", "item2"],
79+
"value": 200,
80+
"nestedDict": {
81+
"nested1": {"name": "Nested1", "value": 1},
82+
"nested2": {"name": "Nested2", "value": 2},
83+
},
84+
"nestedList": [
85+
{"name": "Nested3", "value": 3},
86+
{"name": "Nested4", "value": 4},
87+
],
88+
"any": "This can be anything",
89+
}
90+
deserialized = ComplexObject.from_dict(serialized)
91+
assert deserialized.simple.name == "Nested"
92+
assert deserialized.simple.value == 100
93+
assert deserialized.items == ["item1", "item2"]
94+
assert deserialized.value == 200
95+
assert deserialized.nested_dict == {
96+
"nested1": SimpleObject(name="Nested1", value=1),
97+
"nested2": SimpleObject(name="Nested2", value=2),
98+
}
99+
assert deserialized.nested_list == [
100+
SimpleObject(name="Nested3", value=3),
101+
SimpleObject(name="Nested4", value=4),
102+
]
103+
assert deserialized.any == "This can be anything"
104+
105+
26106
def test_user_data():
27107
ud = UserData.from_dict(USER_DATA)
28108
assert ud.uid == 123456
@@ -184,6 +264,7 @@ def test_clean_summary():
184264
assert cs.square_meter_clean_area == 1159.2
185265
assert cs.clean_count == 31
186266
assert cs.dust_collection_count == 25
267+
assert cs.records
187268
assert len(cs.records) == 2
188269
assert cs.records[1] == 1672458041
189270

0 commit comments

Comments
 (0)