Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ classifiers = [
dependencies = [
"smp>=4.0.2",
"intelhex>=2.3.0",
"async-timeout>=4.0.3; python_version < '3.11'",
"async-timeout>=5.0.1; python_version < '3.11'",
]

[project.optional-dependencies]
serial = ["pyserial>=3.5"]
ble = ["bleak>=2.0.0"]
ble = ["bleak>=3.0.2,<4"]
udp = []
all = ["smpclient[serial,ble,udp]"]

Expand Down
51 changes: 49 additions & 2 deletions src/smpclient/transport/ble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import logging
import re
import sys
from typing import Final, Protocol, TypeGuard
from collections.abc import Coroutine
from typing import Any, Final, Protocol, TypeGuard, TypeVar
from uuid import UUID

try:
Expand Down Expand Up @@ -66,6 +67,8 @@ class SMPBLETransportNotSMPServer(SMPBLETransportException):

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


class SMPBLETransport(SMPTransport):
"""A Bluetooth Low Energy (BLE) SMPTransport."""
Expand All @@ -84,6 +87,13 @@ def __init__(self, winrt: WinRTClientArgs = {}) -> None:

@override
async def connect(self, address: str, timeout_s: float) -> None:
    """Run `_connect()` under an overall `timeout_s` deadline, cleaning up on failure.

    `address` and `timeout_s` are forwarded to `_connect()`, and the same
    `timeout_s` is applied to the whole attempt via `asyncio.wait_for()`.

    Any exception, including cancellation, is re-raised unchanged after
    `_best_effort_disconnect()` has been awaited.
    """
    try:
        await asyncio.wait_for(self._connect(address, timeout_s), timeout=timeout_s)
    except (Exception, asyncio.CancelledError):
        # CancelledError is listed explicitly so that cancellation also triggers cleanup.
        await self._best_effort_disconnect()
        raise

async def _connect(self, address: str, timeout_s: float) -> None:
logger.debug(f"Scanning for {address=}")
device: BLEDevice | None = (
await BleakScanner.find_device_by_address(address, timeout=timeout_s)
Expand All @@ -96,6 +106,7 @@ async def connect(self, address: str, timeout_s: float) -> None:
device,
services=(str(SMP_SERVICE_UUID),),
winrt=self._winrt,
timeout=timeout_s,
disconnected_callback=self._set_disconnected_event,
)
else:
Expand Down Expand Up @@ -139,7 +150,9 @@ async def connect(self, address: str, timeout_s: float) -> None:
self._smp_characteristic = smp_characteristic

logger.debug(f"Starting notify on {SMP_CHARACTERISTIC_UUID=}")
await self._client.start_notify(SMP_CHARACTERISTIC_UUID, self._notify_callback)
await self._await_or_disconnect(
self._client.start_notify(SMP_CHARACTERISTIC_UUID, self._notify_callback)
)
logger.debug(f"Started notify on {SMP_CHARACTERISTIC_UUID=}")

@override
Expand Down Expand Up @@ -246,3 +259,37 @@ async def _notify_or_disconnect(self) -> None:
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._client.address}"
)

async def _await_or_disconnect(self, coro: Coroutine[Any, Any, _T]) -> _T:
    """Race `coro` against the disconnected event and return the result of `coro`.

    `coro` and a waiter on `self._disconnected_event` are run as two tasks;
    whichever finishes first ends the wait.  If the disconnect waiter is
    among the finished tasks, `SMPTransportDisconnected` is raised instead of
    returning; otherwise the result (or exception) of `coro` is propagated.

    Protects GATT operations that can hang indefinitely when the peer
    disconnects mid-flow (e.g. failed pairing) — see
    https://github.com/intercreate/smpmgr/issues/97.
    """
    operation: Final = asyncio.create_task(coro)
    disconnect_watch: Final = asyncio.create_task(self._disconnected_event.wait())
    racers: Final = (operation, disconnect_watch)
    try:
        finished, _ = await asyncio.wait(racers, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Runs on every exit path (including cancellation of this coroutine):
        # cancel whichever task is still pending, then wait for both so that
        # neither outlives this call.
        for racer in racers:
            if not racer.done():
                racer.cancel()
        await asyncio.gather(*racers, return_exceptions=True)
    if disconnect_watch not in finished:
        return operation.result()
    raise SMPTransportDisconnected(
        f"{self.__class__.__name__} disconnected from {self._client.address}"
    )
Comment thread
JPHutchins marked this conversation as resolved.

async def _best_effort_disconnect(self) -> None:
    """Try to disconnect the client after a failed `connect()`; never raises.

    Does nothing when no `_client` attribute has been set yet.  An
    `Exception` raised by `disconnect()` is logged as a warning and
    suppressed.
    """
    stale_client: Final = getattr(self, "_client", None)
    if stale_client is not None:
        try:
            await stale_client.disconnect()
        except Exception:
            logger.warning("Best-effort disconnect after failed connect raised", exc_info=True)
108 changes: 108 additions & 0 deletions tests/test_smp_ble_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from bleak.backends.device import BLEDevice

from smpclient.requests.os_management import EchoWrite
from smpclient.transport import SMPTransportDisconnected
from smpclient.transport.ble import (
MAC_ADDRESS_PATTERN,
SMP_CHARACTERISTIC_UUID,
Expand Down Expand Up @@ -209,3 +210,110 @@ def test_max_unencoded_size_mcumgr_param() -> None:
t._client = MagicMock(spec=BleakClient)
t._smp_server_transport_buffer_size = 9001
assert t.max_unencoded_size == 9001


class _HangingBleakClient:
    """Test double for `BleakClient` whose `start_notify` blocks forever.

    Models the failure reported in intercreate/smpmgr#97, where the BlueZ
    `StartNotify` D-Bus call hangs indefinitely when the peer disconnects
    mid-pairing.  Instantiating this class returns a configured `MagicMock`
    rather than an instance of the class itself.
    """

    def __new__(cls, *args: object, **kwargs: object) -> "_HangingBleakClient":  # type: ignore[misc]  # noqa: E501
        async def _never_returns(*_args: object, **_kwargs: object) -> None:
            await asyncio.Event().wait()  # never fires

        fake = MagicMock(spec=BleakClient, name="HangingBleakClient")
        fake._backend = type("Backend", (), {})()
        fake.address = "00:00:00:00:00:00"
        fake.connect = AsyncMock(name="connect")
        fake.start_notify = AsyncMock(side_effect=_never_returns)
        fake.disconnect = AsyncMock(name="disconnect")
        # Keep the callback the transport passed in so a test can reach it later.
        fake._captured_disconnected_callback = kwargs.get(  # type: ignore[attr-defined]
            "disconnected_callback"
        )
        return fake


@patch(
    "smpclient.transport.ble.BleakScanner.find_device_by_address",
    return_value=BLEDevice("00:00:00:00:00:00", "name", None),
)
@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient)
@pytest.mark.asyncio
async def test_connect_raises_on_peer_disconnect_during_start_notify(
    _mock_find_device_by_address: MagicMock,
) -> None:
    """Regression test for intercreate/smpmgr#97.

    When the peer disconnects mid-`start_notify` (e.g. failed pairing), `connect()`
    must surface `SMPTransportDisconnected` rather than hang.
    """
    t = SMPBLETransport()

    async def _trip_disconnect_callback() -> None:
        # Wait until the transport reaches start_notify and clears the event,
        # then simulate the bleak `disconnected_callback` firing.
        while t._disconnected_event.is_set():
            await asyncio.sleep(0)
        await asyncio.sleep(0)  # let start_notify await begin
        t._set_disconnected_event(t._client)

    # Run connect() and the simulated disconnect concurrently on the same loop.
    # The 5.0 s timeout is only a backstop; the disconnect is expected to end
    # connect() long before it elapses.
    connect_task = asyncio.create_task(t.connect("00:00:00:00:00:00", 5.0))
    trip_task = asyncio.create_task(_trip_disconnect_callback())

    with pytest.raises(SMPTransportDisconnected):
        await connect_task
    # Surface any exception raised inside the helper task.
    await trip_task

    # `_best_effort_disconnect` should have been called to release the client.
    t._client.disconnect.assert_awaited()  # type: ignore[attr-defined]


@patch(
    "smpclient.transport.ble.BleakScanner.find_device_by_address",
    return_value=BLEDevice("00:00:00:00:00:00", "name", None),
)
@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient)
@pytest.mark.asyncio
async def test_connect_raises_on_timeout_during_start_notify(
    _mock_find_device_by_address: MagicMock,
) -> None:
    """`connect()` must honor `timeout_s` even when `start_notify` hangs."""
    t = SMPBLETransport()
    # The patched client's `start_notify` never returns, so only the timeout
    # can end this call; 0.05 s keeps the test fast.
    with pytest.raises(asyncio.TimeoutError):
        await t.connect("00:00:00:00:00:00", 0.05)
    # The failed connect() is expected to have awaited disconnect() during cleanup.
    t._client.disconnect.assert_awaited()  # type: ignore[attr-defined]


@patch(
    "smpclient.transport.ble.BleakScanner.find_device_by_address",
    return_value=BLEDevice("00:00:00:00:00:00", "name", None),
)
@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient)
@pytest.mark.asyncio
async def test_connect_does_not_leak_tasks_on_external_cancel(
    _mock_find_device_by_address: MagicMock,
) -> None:
    """Caller-driven cancellation must not leave `_await_or_disconnect` sub-tasks running."""
    t = SMPBLETransport()
    # Snapshot the Task objects themselves rather than their `id()`s: an id
    # can be reused once its object is freed, which could hide a leaked task
    # behind the id of a finished one.
    tasks_before = asyncio.all_tasks()

    connect_task = asyncio.create_task(t.connect("00:00:00:00:00:00", 60.0))
    while t._disconnected_event.is_set():
        await asyncio.sleep(0)  # wait until BleakClient.connect() returned
    await asyncio.sleep(0)  # let start_notify await begin

    connect_task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await connect_task

    # Let any cancellations propagate to the spawned sub-tasks.
    for _ in range(5):
        await asyncio.sleep(0)

    # Anything still pending that did not exist before connect() is a leak.
    leaked = [task for task in asyncio.all_tasks() - tasks_before if not task.done()]
    assert not leaked, f"sub-tasks leaked after external cancel: {leaked}"
15 changes: 9 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading