Skip to content

Commit fbf1434

Browse files
allenporterLash-L
andauthored
feat: Add a trait for sending commands (#539)
* feat: Add a trait for sending commands * feat: Update cli to use new command interface * chore: fix tests * Update tests/devices/traits/v1/test_command.py Co-authored-by: Luke Lashley <conway220@gmail.com> * Update tests/devices/traits/v1/test_command.py Co-authored-by: Luke Lashley <conway220@gmail.com> * chore: fix syntax and lint errors * chore: fix lint errors --------- Co-authored-by: Luke Lashley <conway220@gmail.com>
1 parent 24a0660 commit fbf1434

File tree

4 files changed

+93
-15
lines changed

4 files changed

+93
-15
lines changed

roborock/cli.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -714,21 +714,14 @@ async def home(ctx, device_id: str, refresh: bool):
714714
@async_command
715715
async def command(ctx, cmd, device_id, params):
716716
context: RoborockContext = ctx.obj
717-
cache_data = await context.get_devices()
718-
719-
home_data = cache_data.home_data
720-
devices = home_data.get_all_devices()
721-
device = next(device for device in devices if device.duid == device_id)
722-
model = next(
723-
(product.model for product in home_data.products if device is not None and product.id == device.product_id),
724-
None,
725-
)
726-
if model is None:
727-
raise RoborockException(f"Could not find model for device {device.name}")
728-
device_info = DeviceData(device=device, model=model)
729-
mqtt_client = RoborockMqttClientV1(cache_data.user_data, device_info)
730-
await mqtt_client.send_command(cmd, json.loads(params) if params is not None else None)
731-
await mqtt_client.async_release()
717+
device_manager = await context.get_device_manager()
718+
device = await device_manager.get_device(device_id)
719+
if device.v1_properties is None:
720+
raise RoborockException(f"Device {device.name} does not support V1 protocol")
721+
command_trait: Trait = device.v1_properties.command
722+
result = await command_trait.send(cmd, json.loads(params) if params is not None else None)
723+
if result:
724+
click.echo(dump_json(result))
732725

733726

734727
@click.command()

roborock/devices/traits/v1/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from .child_lock import ChildLockTrait
1414
from .clean_summary import CleanSummaryTrait
15+
from .command import CommandTrait
1516
from .common import V1TraitMixin
1617
from .consumeable import ConsumableTrait
1718
from .device_features import DeviceFeaturesTrait
@@ -39,6 +40,7 @@
3940
"ConsumableTrait",
4041
"HomeTrait",
4142
"DeviceFeaturesTrait",
43+
"CommandTrait",
4244
"ChildLockTrait",
4345
"FlowLedStatusTrait",
4446
"LedStatusTrait",
@@ -54,6 +56,7 @@ class PropertiesApi(Trait):
5456

5557
# All v1 devices have these traits
5658
status: StatusTrait
59+
command: CommandTrait
5760
dnd: DoNotDisturbTrait
5861
clean_summary: CleanSummaryTrait
5962
sound_volume: SoundVolumeTrait
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any
2+
3+
from roborock import RoborockCommand
4+
5+
6+
class CommandTrait:
7+
"""Trait for sending commands to Roborock devices."""
8+
9+
def __post_init__(self) -> None:
10+
"""Post-initialization to set up the RPC channel.
11+
12+
This is called automatically after the dataclass is initialized by the
13+
device setup code.
14+
"""
15+
self._rpc_channel = None
16+
17+
async def send(self, command: RoborockCommand, params: dict[str, Any] | None = None) -> Any:
18+
"""Send a command to the device."""
19+
if not self._rpc_channel:
20+
raise ValueError("Device trait in invalid state")
21+
return await self._rpc_channel.send_command(command, params=params)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Tests for the CommandTrait class."""
2+
3+
from unittest.mock import AsyncMock
4+
5+
import pytest
6+
7+
from roborock.devices.traits.v1.command import CommandTrait
8+
from roborock.exceptions import RoborockException
9+
from roborock.roborock_typing import RoborockCommand
10+
11+
12+
@pytest.fixture(name="command_trait")
13+
def command_trait_fixture() -> CommandTrait:
14+
"""Create a CommandTrait instance with a mocked RPC channel."""
15+
trait = CommandTrait()
16+
trait._rpc_channel = AsyncMock() # type: ignore[assignment]
17+
return trait
18+
19+
20+
async def test_send_command_success(command_trait: CommandTrait) -> None:
21+
"""Test successfully sending a command."""
22+
mock_rpc_channel = command_trait._rpc_channel
23+
assert mock_rpc_channel is not None
24+
mock_rpc_channel.send_command.return_value = {"result": "ok"}
25+
26+
# Call the method
27+
result = await command_trait.send(RoborockCommand.APP_START)
28+
29+
# Verify the result
30+
assert result == {"result": "ok"}
31+
32+
# Verify the RPC call was made correctly
33+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.APP_START, params=None)
34+
35+
36+
async def test_send_command_with_params(command_trait: CommandTrait) -> None:
37+
"""Test successfully sending a command with parameters."""
38+
mock_rpc_channel = command_trait._rpc_channel
39+
assert mock_rpc_channel is not None
40+
mock_rpc_channel.send_command.return_value = {"result": "ok"}
41+
params = {"segments": [1, 2, 3]}
42+
43+
# Call the method
44+
result = await command_trait.send(RoborockCommand.APP_SEGMENT_CLEAN, params)
45+
46+
# Verify the result
47+
assert result == {"result": "ok"}
48+
49+
# Verify the RPC call was made correctly
50+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.APP_SEGMENT_CLEAN, params=params)
51+
52+
53+
async def test_send_command_propagates_exception(command_trait: CommandTrait) -> None:
54+
"""Test that exceptions from RPC channel are propagated."""
55+
mock_rpc_channel = command_trait._rpc_channel
56+
assert mock_rpc_channel is not None
57+
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")
58+
59+
# Verify the exception is propagated
60+
with pytest.raises(RoborockException, match="Communication error"):
61+
await command_trait.send(RoborockCommand.APP_START)

0 commit comments

Comments
 (0)