Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions plugins/another/another_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def write(ctx: typer.Context, message: str) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smpclient.request(AnotherWrite(d=message))
print(r)
Expand All @@ -90,7 +90,7 @@ def read(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smpclient.request(AnotherRead())
print(r)
Expand Down
4 changes: 2 additions & 2 deletions plugins/example_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def write(ctx: typer.Context, message: str) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smpclient.request(ExampleWrite(d=message))
print(r)
Expand All @@ -90,7 +90,7 @@ def read(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smpclient.request(ExampleRead())
print(r)
Expand Down
15 changes: 8 additions & 7 deletions smpmgr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,22 @@ def get_custom_smpclient(options: Options, smp_client_cls: Type[TSMPClient]) ->
kwargs['line_buffers'] = 1
if options.baudrate is not None:
kwargs['baudrate'] = options.baudrate
return smp_client_cls(SMPSerialTransport(**kwargs), options.transport.port)
return smp_client_cls(SMPSerialTransport(**kwargs), options.transport.port, options.timeout)
elif options.transport.ble is not None:
logger.info(f"Initializing SMPClient with the SMPBLETransport, {options.transport.ble=}")
return smp_client_cls(
SMPBLETransport(),
options.transport.ble,
options.timeout,
)
elif options.transport.ip is not None:
logger.info(f"Initializing SMPClient with the SMPUDPTransport, {options.transport.ip=}")
if options.mtu is not None:
return smp_client_cls(SMPUDPTransport(mtu=options.mtu), options.transport.ip)
return smp_client_cls(
SMPUDPTransport(mtu=options.mtu), options.transport.ip, options.timeout
)
else:
return smp_client_cls(SMPUDPTransport(), options.transport.ip)
return smp_client_cls(SMPUDPTransport(), options.transport.ip, options.timeout)
else:
typer.echo(
f"A transport option is required; "
Expand All @@ -85,15 +88,15 @@ def get_smpclient(options: Options) -> SMPClient:
return get_custom_smpclient(options, SMPClient)


async def connect_with_spinner(smpclient: SMPClient, timeout_s: float) -> None:
async def connect_with_spinner(smpclient: SMPClient) -> None:
"""Spin while connecting to the SMP Server; raises `typer.Exit` if connection fails."""
with Progress(
SpinnerColumn(), TextColumn("[progress.description]{task.description}")
) as progress:
connect_task_description = f"Connecting to {smpclient._address}..."
connect_task = progress.add_task(description=connect_task_description, total=None)
try:
await smpclient.connect(timeout_s)
await smpclient.connect()
progress.update(
connect_task, description=f"{connect_task_description} OK", completed=True
)
Expand All @@ -111,7 +114,6 @@ async def connect_with_spinner(smpclient: SMPClient, timeout_s: float) -> None:

async def smp_request(
smpclient: SMPClient,
options: Options,
request: SMPRequest[TRep, TEr1, TEr2],
description: str | None = None,
timeout_s: float | None = None,
Expand All @@ -120,7 +122,6 @@ async def smp_request(
SpinnerColumn(), TextColumn("[progress.description]{task.description}")
) as progress:
description = description or f"Waiting for response to {request.__class__.__name__}..."
timeout_s = timeout_s if timeout_s is not None else options.timeout
task = progress.add_task(description=description, total=None)
try:
r = await smpclient.request(request, timeout_s)
Expand Down
8 changes: 4 additions & 4 deletions smpmgr/enumeration_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def get_supported_groups(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, ListSupportedGroups(), "Waiting for supported groups...") # type: ignore # noqa
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, ListSupportedGroups(), "Waiting for supported groups...") # type: ignore # noqa
print(r)

asyncio.run(f())
Expand All @@ -41,8 +41,8 @@ def get_group_details(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, GroupDetails(groups=groups), "Waiting for group details...") # type: ignore # noqa
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, GroupDetails(groups=groups), "Waiting for group details...") # type: ignore # noqa
print(r)

asyncio.run(f())
16 changes: 8 additions & 8 deletions smpmgr/file_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def get_supported_hash_types(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(smpclient, options, SupportedFileHashChecksumTypes(), "Waiting for supported hash types...") # type: ignore # noqa
r = await smp_request(smpclient, SupportedFileHashChecksumTypes(), "Waiting for supported hash types...") # type: ignore # noqa

if error(r):
print(r)
Expand All @@ -65,9 +65,9 @@ def get_hash(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(smpclient, options, FileHashChecksum(name=file), "Waiting for hash...") # type: ignore # noqa
r = await smp_request(smpclient, FileHashChecksum(name=file), "Waiting for hash...") # type: ignore # noqa

if error(r) or success(r):
print(r)
Expand All @@ -87,9 +87,9 @@ def read_size(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(smpclient, options, FileStatus(name=file), "Waiting for file size...") # type: ignore # noqa
r = await smp_request(smpclient, FileStatus(name=file), "Waiting for file size...") # type: ignore # noqa

if error(r):
print(r)
Expand Down Expand Up @@ -146,7 +146,7 @@ def upload(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)
with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, destination)

Expand All @@ -173,7 +173,7 @@ def download(
destination = Path(Path(file).name) if destination is None else destination

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

with Progress(
SpinnerColumn(), TextColumn("[progress.description]{task.description}")
Expand Down
15 changes: 6 additions & 9 deletions smpmgr/image_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def state_read(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(smpclient, options, ImageStatesRead(), "Waiting for image states...")
r = await smp_request(smpclient, ImageStatesRead(), "Waiting for image states...")

if error(r):
print(r)
Expand Down Expand Up @@ -90,11 +90,10 @@ def state_write(
hash_bytes = bytes.fromhex(hash) if hash is not None else None

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(
smpclient,
options,
ImageStatesWrite(hash=hash_bytes, confirm=confirm),
"Waiting for image state write...",
)
Expand Down Expand Up @@ -123,11 +122,9 @@ def erase(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

r = await smp_request(
smpclient, options, ImageErase(slot=slot), "Waiting for image erase..."
)
r = await smp_request(smpclient, ImageErase(slot=slot), "Waiting for image erase...")

if error(r):
print(r)
Expand Down Expand Up @@ -191,7 +188,7 @@ def upload(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)
with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, slot)

Expand Down
5 changes: 2 additions & 3 deletions smpmgr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def upgrade(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, slot)
Expand All @@ -177,7 +177,6 @@ async def f() -> None:
# mark the new image for testing (swap)
image_states_response = await smp_request(
smpclient,
options,
ImageStatesWrite(hash=image_tlv_sha256.value),
"Marking uploaded image for test upgrade...",
)
Expand All @@ -189,7 +188,7 @@ async def f() -> None:
else:
assert_never(image_states_response)

reset_response = await smp_request(smpclient, options, ResetWrite())
reset_response = await smp_request(smpclient, ResetWrite())
if success(reset_response):
pass
elif error(reset_response):
Expand Down
8 changes: 4 additions & 4 deletions smpmgr/os_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def echo(ctx: typer.Context, message: str) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, EchoWrite(d=message)) # type: ignore
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, EchoWrite(d=message)) # type: ignore
print(r)

asyncio.run(f())
Expand All @@ -33,8 +33,8 @@ def reset(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, ResetWrite()) # type: ignore
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, ResetWrite()) # type: ignore
print(r)

asyncio.run(f())
6 changes: 3 additions & 3 deletions smpmgr/shell_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def shell(
command: str = typer.Argument(
help="Command string to run, e.g. \"gpio conf gpio@49000000 0 i\""
),
timeout: float = typer.Option(2.0, help="Timeout in seconds for the command to complete"),
timeout: float
| None = typer.Option(None, help="Timeout in seconds for the command to complete"),
verbose: A[
bool, typer.Option("--verbose", help="Print the raw success response") # noqa: F821,F722
] = False,
Expand All @@ -28,11 +29,10 @@ def shell(
smpclient: Final = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

response: Final = await smp_request(
smpclient,
options,
Execute(argv=shlex.split(command)),
f"Waiting response to {command}...",
timeout_s=timeout,
Expand Down
18 changes: 9 additions & 9 deletions smpmgr/stat_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def list_stats(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, ListOfGroups()) # type: ignore
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, ListOfGroups()) # type: ignore

if verbose:
print(r)
Expand Down Expand Up @@ -51,8 +51,8 @@ def smp_svr_stats(ctx: typer.Context) -> None:
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, GroupData(name="smp_svr_stats"))
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, GroupData(name="smp_svr_stats"))
print(r)

asyncio.run(f())
Expand All @@ -68,8 +68,8 @@ def get_group(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
r = await smp_request(smpclient, options, GroupData(name=group_id))
await connect_with_spinner(smpclient)
r = await smp_request(smpclient, GroupData(name=group_id))
print(r)

asyncio.run(f())
Expand All @@ -86,9 +86,9 @@ def fetch_all_groups(
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)

list_response = await smp_request(smpclient, options, ListOfGroups()) # type: ignore
list_response = await smp_request(smpclient, ListOfGroups()) # type: ignore

if not hasattr(list_response, 'stat_list') or not list_response.stat_list:
print("No statistics groups available")
Expand All @@ -97,7 +97,7 @@ async def f() -> None:
groups_data = []

for group_name in list_response.stat_list:
group_data = await smp_request(smpclient, options, GroupData(name=group_name))
group_data = await smp_request(smpclient, GroupData(name=group_name))
groups_data.append({'name': group_name, 'data': group_data})

if verbose:
Expand Down
2 changes: 1 addition & 1 deletion smpmgr/user/intercreate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def upload(
smpclient = get_custom_smpclient(options, ic.ICUploadClient)

async def f() -> None:
await connect_with_spinner(smpclient, options.timeout)
await connect_with_spinner(smpclient)
with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, image)

Expand Down
Loading