Skip to content

Commit b2118dd

Browse files
committed
chore: Only allow a single trait
1 parent ca12a41 commit b2118dd

File tree

9 files changed

+50
-81
lines changed

9 files changed

+50
-81
lines changed

roborock/devices/device.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ def __init__(
3838
self,
3939
device_info: HomeDataDevice,
4040
channel: Channel,
41-
traits: list[Trait],
41+
trait: Trait,
4242
) -> None:
4343
"""Initialize the RoborockDevice.
4444
4545
The device takes ownership of the channel for communication with the device.
4646
Use `connect()` to establish the connection, which will set up the appropriate
4747
protocol channel. Use `close()` to clean up all connections.
4848
"""
49-
TraitsMixin.__init__(self, traits)
49+
TraitsMixin.__init__(self, trait)
5050
self._duid = device_info.duid
5151
self._name = device_info.name
5252
self._channel = channel

roborock/devices/device_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,23 @@ async def create_device_manager(
145145

146146
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
147147
channel: Channel
148-
traits: list[Trait] = []
148+
trait: Trait
149149
match device.pv:
150150
case DeviceVersion.V1:
151151
v1_channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
152152
channel = v1_channel
153-
traits.extend(v1.create_v1_traits(product, v1_channel.rpc_channel))
153+
trait = v1.create(product, v1_channel.rpc_channel)
154154
case DeviceVersion.A01:
155155
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
156156
channel = mqtt_channel
157-
traits.extend(a01.create_a01_traits(product, mqtt_channel))
157+
trait = a01.create(product, mqtt_channel)
158158
case DeviceVersion.B01:
159159
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
160160
channel = mqtt_channel
161-
traits.extend(b01.create_b01_traits(mqtt_channel))
161+
trait = b01.create(mqtt_channel)
162162
case _:
163163
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
164-
return RoborockDevice(device, channel, traits)
164+
return RoborockDevice(device, channel, trait)
165165

166166
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
167167
await manager.discover_devices()

roborock/devices/traits/a01/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[Rob
5050
return await send_decoded_command(self._channel, params)
5151

5252

53-
def create_a01_traits(product: HomeDataProduct, mqtt_channel: MqttChannel) -> list[Trait]:
53+
def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi:
5454
"""Create traits for A01 devices."""
5555
match product.category:
5656
case RoborockCategory.WET_DRY_VAC:
57-
return [DyadApi(mqtt_channel)]
57+
return DyadApi(mqtt_channel)
5858
case RoborockCategory.WASHING_MACHINE:
59-
return [ZeoApi(mqtt_channel)]
59+
return ZeoApi(mqtt_channel)
6060
case _:
6161
raise NotImplementedError(f"Unsupported category {product.category}")
Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,31 @@
11
"""Traits for B01 devices."""
22

3+
from roborock import RoborockB01Methods
34
from roborock.devices.b01_channel import send_decoded_command
45
from roborock.devices.mqtt_channel import MqttChannel
56
from roborock.devices.traits import Trait
6-
7-
from .props import B01PropsApi
7+
from roborock.roborock_message import RoborockB01Props
88

99
__init__ = [
1010
"create_b01_traits",
1111
"B01PropsApi",
1212
]
1313

1414

15-
def create_b01_traits(channel: MqttChannel) -> list[Trait]:
15+
class B01PropsApi(Trait):
16+
"""API for interacting with B01 devices."""
17+
18+
def __init__(self, channel: MqttChannel) -> None:
19+
"""Initialize the B01Props API."""
20+
self._channel = channel
21+
22+
async def query_values(self, props: list[RoborockB01Props]) -> None:
23+
"""Query the device for the values of the given Dyad protocols."""
24+
await send_decoded_command(
25+
self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params={"property": props}
26+
)
27+
28+
29+
def create(channel: MqttChannel) -> B01PropsApi:
1630
"""Create traits for B01 devices."""
17-
return [B01PropsApi(channel)]
31+
return B01PropsApi(channel)

roborock/devices/traits/b01/props.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

roborock/devices/traits/traits_mixin.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class TraitsMixin:
2323
"""Mixin to provide trait accessors."""
2424

25-
v1_properties: v1.Properties | None = None
25+
v1_properties: v1.PropertiesApi | None = None
2626
"""V1 properties trait, if supported."""
2727

2828
dyad: a01.DyadApi | None = None
@@ -31,20 +31,21 @@ class TraitsMixin:
3131
zeo: a01.ZeoApi | None = None
3232
"""Zeo API, if supported."""
3333

34-
b01_properties: b01.B01PropsApi | None = None
34+
b01_props_api: b01.B01PropsApi | None = None
3535
"""B01 properties trait, if supported."""
3636

37-
def __init__(self, traits: list[Trait]) -> None:
38-
"""Initialize the TraitsMixin with the given traits list.
37+
def __init__(self, trait: Trait) -> None:
38+
"""Initialize the TraitsMixin with the given trait.
3939
4040
This will populate the appropriate trait attributes based on the types
4141
of the traits provided.
4242
"""
43-
trait_map: dict[type[Trait], Trait] = {type(item): item for item in traits}
43+
# trait_map: dict[type[Trait], Trait] = {type(item): item for item in traits}
4444
for item in fields(self):
4545
trait_type = _get_trait_type(item)
46-
if (trait := trait_map.get(trait_type, None)) is not None:
46+
if trait_type == type(trait):
4747
setattr(self, item.name, trait)
48+
break
4849

4950

5051
def _get_trait_type(item) -> type[Trait]:

roborock/devices/traits/v1/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
from .properties import CleanSummaryTrait, DoNotDisturbTrait, SoundVolumeTrait, StatusTrait
1010

1111
__all__ = [
12-
"create_v1_traits",
13-
"Properties",
12+
"create",
13+
"PropertiesApi",
1414
"properties",
1515
]
1616

1717

1818
@dataclass
19-
class Properties(Trait):
19+
class PropertiesApi(Trait):
2020
"""Common properties for V1 devices.
2121
2222
This class holds all the traits that are common across all V1 devices.
@@ -34,16 +34,14 @@ def __init__(self, product: HomeDataProduct, rpc_channel: V1RpcChannel) -> None:
3434
"""Initialize the V1TraitProps with None values."""
3535
self.status = StatusTrait(product)
3636

37+
# Create traits and set the RPC channel
3738
for item in fields(self):
3839
if (trait := getattr(self, item.name, None)) is None:
3940
trait = item.type()
4041
setattr(self, item.name, trait)
4142
trait._rpc_channel = rpc_channel
4243

4344

44-
def create_v1_traits(product: HomeDataProduct, rpc_channel: V1RpcChannel) -> list[Trait]:
45+
def create(product: HomeDataProduct, rpc_channel: V1RpcChannel) -> PropertiesApi:
4546
"""Create traits for V1 devices."""
46-
return [
47-
Properties(product, rpc_channel)
48-
# Add optional traits here as needed in the future
49-
]
47+
return PropertiesApi(product, rpc_channel)

tests/devices/test_v1_device.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
from roborock.containers import HomeData, S7MaxVStatus, UserData
1111
from roborock.devices.device import RoborockDevice
12-
from roborock.devices.traits import Trait
13-
from roborock.devices.traits.v1 import create_v1_traits
12+
from roborock.devices.traits import v1
1413
from roborock.devices.traits.v1.common import V1TraitMixin
1514
from roborock.devices.v1_rpc_channel import decode_rpc_response
1615
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
@@ -37,21 +36,15 @@ def rpc_channel_fixture() -> AsyncMock:
3736

3837

3938
@pytest.fixture(autouse=True, name="device")
40-
def device_fixture(channel: AsyncMock, traits: list[Trait]) -> RoborockDevice:
39+
def device_fixture(channel: AsyncMock, rpc_channel: AsyncMock) -> RoborockDevice:
4140
"""Fixture to set up the device for tests."""
4241
return RoborockDevice(
4342
device_info=HOME_DATA.devices[0],
4443
channel=channel,
45-
traits=traits,
44+
trait=v1.create(HOME_DATA.products[0], rpc_channel),
4645
)
4746

4847

49-
@pytest.fixture(autouse=True, name="traits")
50-
def traits_fixture(rpc_channel: AsyncMock) -> list[Trait]:
51-
"""Fixture to set up the V1 API for tests."""
52-
return create_v1_traits(HOME_DATA.products[0], rpc_channel)
53-
54-
5548
async def test_device_connection(device: RoborockDevice, channel: AsyncMock) -> None:
5649
"""Test the Device connection setup."""
5750

@@ -91,7 +84,7 @@ def setup_rpc_channel_fixture(rpc_channel: AsyncMock, payload: pathlib.Path) ->
9184

9285

9386
@pytest.mark.parametrize(
94-
("payload", "trait_method"),
87+
("payload", "property_method"),
9588
[
9689
(TESTDATA / "get_status.json", lambda x: x.status),
9790
(TESTDATA / "get_dnd.json", lambda x: x.dnd),
@@ -103,11 +96,11 @@ async def test_device_trait_command_parsing(
10396
device: RoborockDevice,
10497
setup_rpc_channel: AsyncMock,
10598
snapshot: SnapshotAssertion,
106-
trait_method: Callable[..., V1TraitMixin],
99+
property_method: Callable[..., V1TraitMixin],
107100
payload: str,
108101
) -> None:
109102
"""Test the device trait command."""
110-
trait = trait_method(device.v1_properties)
103+
trait = property_method(device.v1_properties)
111104
assert trait
112105
assert isinstance(trait, V1TraitMixin)
113106
await trait.refresh()

tests/devices/traits/v1/fixtures.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
from roborock.containers import HomeData, S7MaxVStatus, UserData
88
from roborock.devices.device import RoborockDevice
9-
from roborock.devices.traits import Trait
10-
from roborock.devices.traits.v1 import create_v1_traits
9+
from roborock.devices.traits import v1
1110

1211
from .... import mock_data
1312

@@ -29,16 +28,10 @@ def rpc_channel_fixture() -> AsyncMock:
2928

3029

3130
@pytest.fixture(autouse=True, name="device")
32-
def device_fixture(channel: AsyncMock, traits: list[Trait]) -> RoborockDevice:
31+
def device_fixture(channel: AsyncMock, mock_rpc_channel: AsyncMock) -> RoborockDevice:
3332
"""Fixture to set up the device for tests."""
3433
return RoborockDevice(
3534
device_info=HOME_DATA.devices[0],
3635
channel=channel,
37-
traits=traits,
36+
trait=v1.create(HOME_DATA.products[0], mock_rpc_channel),
3837
)
39-
40-
41-
@pytest.fixture(autouse=True, name="traits")
42-
def traits_fixture(mock_rpc_channel: AsyncMock) -> list[Trait]:
43-
"""Fixture to set up the V1 API for tests."""
44-
return create_v1_traits(HOME_DATA.products[0], mock_rpc_channel)

0 commit comments

Comments
 (0)