Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/smpclient/transport/serial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Serial SMPTransports.

In addition to UART, these transports can be used with USB CDC ACM and CAN.
"""

from smpclient.transport.serial.encoded import SMPSerialTransport as SMPSerialTransport
from smpclient.transport.serial.unencoded import SMPSerialRawTransport as SMPSerialRawTransport
155 changes: 155 additions & 0 deletions src/smpclient/transport/serial/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Shared connection management for the encoded and unencoded serial transports."""

import asyncio
import logging
from contextlib import contextmanager
from time import monotonic
from typing import Final, Generator, final

try:
from serial import Serial, SerialException
except ModuleNotFoundError as e:
if e.name == "serial":
raise ImportError(
"Serial transport requires the 'serial' extra. Use smpclient[serial]"
) from e
raise
from typing_extensions import override

from smpclient.transport import SMPTransport, SMPTransportDisconnected

logger = logging.getLogger(__name__)


class _SerialTransportBase(SMPTransport):
"""Connection-management base class for serial-port-backed SMP transports.

Holds the `pyserial` `Serial` instance, the open/retry connect loop, disconnect,
and the small TX/RX helpers that wrap `SerialException` into
`SMPTransportDisconnected`.

Subclasses implement `send` and `receive` with their framing of choice and may
override `_reset_state` to clear per-connection state on `connect`.
"""

_POLLING_INTERVAL_S: Final = 0.005
_CONNECTION_RETRY_INTERVAL_S: Final = 0.500

def __init__(
self,
baudrate: int = 115200,
bytesize: int = 8,
parity: str = "N",
stopbits: float = 1,
timeout: float | None = None,
xonxoff: bool = False,
rtscts: bool = False,
write_timeout: float | None = None,
dsrdtr: bool = False,
inter_byte_timeout: float | None = None,
exclusive: bool | None = None,
) -> None:
"""Initialize the underlying `pyserial` `Serial` instance.

Args:
baudrate: The baudrate of the serial connection. OK to ignore for
USB CDC ACM.
bytesize: The number of data bits.
parity: The parity setting.
stopbits: The number of stop bits.
timeout: The read timeout.
xonxoff: Enable software flow control.
rtscts: Enable hardware (RTS/CTS) flow control.
write_timeout: The write timeout.
dsrdtr: Enable hardware (DSR/DTR) flow control.
inter_byte_timeout: The inter-byte timeout.
exclusive: Set exclusive access mode (POSIX only). A port cannot be
opened in exclusive access mode if it is already open in
exclusive access mode.
"""
self._conn: Final = Serial(
baudrate=baudrate,
bytesize=bytesize,
parity=parity,
stopbits=stopbits,
timeout=timeout,
xonxoff=xonxoff,
rtscts=rtscts,
write_timeout=write_timeout,
dsrdtr=dsrdtr,
inter_byte_timeout=inter_byte_timeout,
exclusive=exclusive,
)

def _reset_state(self) -> None:
"""Reset any per-connection state. Subclasses override as needed."""

@final
@override
async def connect(self, address: str, timeout_s: float) -> None:
self._reset_state()
self._conn.port = address
logger.debug(f"Connecting to {self._conn.port=}")
start_time: Final = monotonic()
while monotonic() - start_time <= timeout_s:
try:
self._conn.open()
self._conn.reset_input_buffer()
logger.debug(f"Connected to {self._conn.port=}")
return
except SerialException as e:
logger.debug(
f"Failed to connect to {self._conn.port=}: {e}, "
f"retrying in {self._CONNECTION_RETRY_INTERVAL_S} seconds"
)
await asyncio.sleep(self._CONNECTION_RETRY_INTERVAL_S)
Comment on lines +101 to +105
Comment on lines +95 to +105

raise TimeoutError(f"Failed to connect to {address=}")

@final
@override
async def disconnect(self) -> None:
logger.debug(f"Disconnecting from {self._conn.port=}")
self._conn.close()
logger.debug(f"Disconnected from {self._conn.port=}")

@final
@override
async def send_and_receive(self, data: bytes) -> bytes:
await self.send(data)
return await self.receive()

@final
@contextmanager
def _serial_exception_to_disconnected(self) -> Generator[None, None, None]:
"""Translate `SerialException` from `pyserial` to `SMPTransportDisconnected`."""
try:
yield
except SerialException as e:
logger.error(f"Serial exception on {self._conn.port}: {e}")
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._conn.port}"
) from e

@final
async def _drain_tx(self) -> None:
"""Block until the serial TX buffer is empty.

Fake-async polling until `pyserial` is replaced.
"""
while self._conn.out_waiting > 0:
await asyncio.sleep(self._POLLING_INTERVAL_S)

@final
async def _read_all(self) -> bytes:
"""Return all currently-available bytes (or empty bytes).

Wraps `SerialException` into `SMPTransportDisconnected`. `StopIteration` is
caught to keep mocked `read_all` side-effect lists usable in tests.
"""
try:
return self._conn.read_all() or b""
except StopIteration:
return b""
except SerialException as exc:
raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") from exc
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
"""A serial SMPTransport.
"""The base64-encoded serial SMPTransport.

In addition to UART, this transport can be used with USB CDC ACM and CAN.
Wraps each SMP packet in a length+CRC frame, base64-encodes it, and terminates
it with a newline. This is what Zephyr calls "SMP over console" -- the framing
shared by `CONFIG_MCUMGR_TRANSPORT_UART` and `CONFIG_MCUMGR_TRANSPORT_SHELL`,
and the only SMP-over-UART option that existed before Zephyr 4.4. Also the
only framing supported by MCUboot serial recovery (`MCUBOOT_SERIAL`), which
unconditionally selects `BASE64`.

For `CONFIG_MCUMGR_TRANSPORT_RAW_UART` servers, use `SMPSerialRawTransport`
from `smpclient.transport.serial.unencoded`.
"""

import asyncio
import logging
import math
import time
from enum import IntEnum, unique
from functools import cached_property
from typing import Final

try:
from serial import Serial, SerialException
except ModuleNotFoundError as e:
if e.name == "serial":
raise ImportError(
"Serial transport requires the 'serial' extra. Use smpclient[serial]"
) from e
raise
from smp import packet as smppacket
from typing_extensions import override

from smpclient.transport import SMPTransport, SMPTransportDisconnected
from smpclient.transport.serial.common import _SerialTransportBase

logger = logging.getLogger(__name__)

Expand All @@ -43,10 +42,7 @@ def _base64_max(size: int) -> int:
return math.floor(3 / 4 * size) - 2


class SMPSerialTransport(SMPTransport):
_POLLING_INTERVAL_S = 0.005
_CONNECTION_RETRY_INTERVAL_S = 0.500

class SMPSerialTransport(_SerialTransportBase):
@unique
class BufferState(IntEnum):
SMP = 0
Expand Down Expand Up @@ -95,22 +91,12 @@ def __init__( # noqa: DOC301
write_timeout: The write timeout.
dsrdtr: Enable hardware (DSR/DTR) flow control.
inter_byte_timeout: The inter-byte timeout.
exclusive: The exclusive access timeout.
exclusive: Set exclusive access mode (POSIX only). A port cannot be
opened in exclusive access mode if it is already open in
exclusive access mode.

"""
if max_smp_encoded_frame_size < line_length * line_buffers:
logger.error(
f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!"
)
elif max_smp_encoded_frame_size != line_length * line_buffers:
logger.warning(
f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!"
)

self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size
self._line_length: Final = line_length
self._line_buffers: Final = line_buffers
self._conn: Final = Serial(
super().__init__(
baudrate=baudrate,
bytesize=bytesize,
parity=parity,
Expand All @@ -124,6 +110,19 @@ def __init__( # noqa: DOC301
exclusive=exclusive,
)

if max_smp_encoded_frame_size < line_length * line_buffers:
logger.error(
f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!"
)
elif max_smp_encoded_frame_size != line_length * line_buffers:
logger.warning(
f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!"
)

self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size
self._line_length: Final = line_length
self._line_buffers: Final = line_buffers

self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue()
"""Contains full SMP packets."""
self._serial_buffer = bytearray()
Expand All @@ -135,60 +134,27 @@ def __init__( # noqa: DOC301

logger.debug(f"Initialized {self.__class__.__name__}")

@override
def _reset_state(self) -> None:
"""Reset internal state and queues for a fresh connection."""
self._smp_packet_queue = asyncio.Queue()
self._serial_buffer.clear()
self._buffer = bytearray([])
self._buffer_state = SMPSerialTransport.BufferState.SERIAL

@override
async def connect(self, address: str, timeout_s: float) -> None:
self._reset_state()
self._conn.port = address
logger.debug(f"Connecting to {self._conn.port=}")
start_time: Final = time.time()
while time.time() - start_time <= timeout_s:
try:
self._conn.open()
self._conn.reset_input_buffer()
logger.debug(f"Connected to {self._conn.port=}")
return
except SerialException as e:
logger.debug(
f"Failed to connect to {self._conn.port=}: {e}, "
f"retrying in {SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S} seconds"
)
await asyncio.sleep(SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S)

raise TimeoutError(f"Failed to connect to {address=}")

@override
async def disconnect(self) -> None:
logger.debug(f"Disconnecting from {self._conn.port=}")
self._conn.close()
logger.debug(f"Disconnected from {self._conn.port=}")

@override
async def send(self, data: bytes) -> None:
if len(data) > self.max_unencoded_size:
raise ValueError(
f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}"
)
logger.debug(f"Sending {len(data)} bytes")
try:
with self._serial_exception_to_disconnected():
for packet in smppacket.encode(data, line_length=self._line_length):
self._conn.write(packet)
logger.debug(f"Writing encoded packet of size {len(packet)}B; {self._line_length=}")

# fake async until I get around to replacing pyserial
while self._conn.out_waiting > 0:
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
except SerialException as e:
logger.error(f"Failed to send {len(data)} bytes: {e}")
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._conn.port}"
)
await self._drain_tx()

logger.debug(f"Sent {len(data)} bytes")

Expand Down Expand Up @@ -242,18 +208,13 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes:
async def _read_and_process(self, read_until_one_smp_packet: bool) -> None:
"""Reads raw data from serial and processes it into SMP packets and regular serial data."""
while True:
try:
data = self._conn.read_all() or b""
except StopIteration:
data = b""
except SerialException as exc:
raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}")
data = await self._read_all()

if data:
self._buffer.extend(data)
await self._process_buffer()
else:
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
await asyncio.sleep(self._POLLING_INTERVAL_S)

if read_until_one_smp_packet:
if self._smp_packet_queue.qsize():
Expand Down Expand Up @@ -342,11 +303,6 @@ def _could_be_smp_packet_start(self, byte: int) -> bool:
"""Return True if the given byte value matches the start of any SMP packet delimiter."""
return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0]

@override
async def send_and_receive(self, data: bytes) -> bytes:
await self.send(data)
return await self.receive()

@override
@property
def mtu(self) -> int:
Expand Down
Loading
Loading