|
| 1 | +import asyncio |
1 | 2 | import io |
2 | 3 | import logging |
3 | 4 | import re |
| 5 | +from asyncio import Protocol |
4 | 6 | from collections.abc import Callable, Generator |
5 | 7 | from queue import Queue |
6 | 8 | from typing import Any |
|
11 | 13 |
|
12 | 14 | from roborock import HomeData, UserData |
13 | 15 | from roborock.containers import DeviceData |
| 16 | +from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 |
14 | 17 | from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 |
15 | | -from tests.mock_data import HOME_DATA_RAW, USER_DATA |
| 18 | +from tests.mock_data import HOME_DATA_RAW, TEST_LOCAL_API_HOST, USER_DATA |
16 | 19 |
|
17 | 20 | _LOGGER = logging.getLogger(__name__) |
18 | 21 |
|
@@ -191,3 +194,43 @@ def mock_rest() -> aioresponses: |
191 | 194 | payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True}, |
192 | 195 | ) |
193 | 196 | yield mocked |
| 197 | + |
| 198 | + |
| 199 | +@pytest.fixture(name="mock_create_local_connection") |
| 200 | +def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]: |
| 201 | + """Fixture that overrides the transport creation to wire it up to the mock socket.""" |
| 202 | + |
| 203 | + async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]: |
| 204 | + protocol = protocol_factory() |
| 205 | + |
| 206 | + def handle_write(data: bytes) -> None: |
| 207 | + _LOGGER.debug("Received: %s", data) |
| 208 | + response = request_handler(data) |
| 209 | + if response is not None: |
| 210 | + _LOGGER.debug("Replying with %s", response) |
| 211 | + loop = asyncio.get_running_loop() |
| 212 | + loop.call_soon(protocol.data_received, response) |
| 213 | + |
| 214 | + closed = asyncio.Event() |
| 215 | + |
| 216 | + mock_transport = Mock() |
| 217 | + mock_transport.write = handle_write |
| 218 | + mock_transport.close = closed.set |
| 219 | + mock_transport.is_reading = lambda: not closed.is_set() |
| 220 | + |
| 221 | + return (mock_transport, "proto") |
| 222 | + |
| 223 | + with patch("roborock.api.get_running_loop_or_create_one") as mock_loop: |
| 224 | + mock_loop.return_value.create_connection.side_effect = create_connection |
| 225 | + yield |
| 226 | + |
| 227 | + |
| 228 | +@pytest.fixture(name="local_client") |
| 229 | +def local_client_fixture(mock_create_local_connection: None) -> Generator[RoborockLocalClientV1, None, None]: |
| 230 | + home_data = HomeData.from_dict(HOME_DATA_RAW) |
| 231 | + device_info = DeviceData( |
| 232 | + device=home_data.devices[0], |
| 233 | + model=home_data.products[0].model, |
| 234 | + host=TEST_LOCAL_API_HOST, |
| 235 | + ) |
| 236 | + yield RoborockLocalClientV1(device_info) |
0 commit comments