Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@
from concurrent.futures import CancelledError
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path, PosixPath
from pathlib import Path
from queue import Queue
from urllib.parse import urlparse

import click
import pexpect
import requests
from jumpstarter_driver_composite.client import CompositeClient
from jumpstarter_driver_opendal.client import FlasherClient, OpendalClient, operator_for_path
from jumpstarter_driver_opendal.client import (
FlasherClient,
OpendalClient,
clean_filename,
operator_for_path,
path_with_query,
)
from jumpstarter_driver_opendal.common import PathBuf
from jumpstarter_driver_pyserial.client import Console
from opendal import Metadata, Operator
Expand Down Expand Up @@ -167,10 +173,10 @@ def flash( # noqa: C901
"http", root="/", endpoint=f"{parsed.scheme}://{parsed.netloc}", token=bearer_token
)
operator_scheme = "http"
path = Path(parsed.path)
path = path_with_query(parsed)
else:
path, operator, operator_scheme = operator_for_path(path)
image_url = self.http.get_url() + "/" + path.name
image_url = self.http.get_url() + "/" + self._filename(path)

# start counting time for the flash operation
start_time = time.time()
Expand Down Expand Up @@ -966,9 +972,9 @@ def _transfer_bg_thread(
original_url: Original URL for HTTP fallback
headers: HTTP headers for requests
"""
self.logger.info(f"Writing image to storage in the background: {src_path}")
filename = self._filename(src_path)
self.logger.info(f"Writing image to storage in the background: {filename}")
try:
filename = Path(src_path).name if isinstance(src_path, (str, os.PathLike)) else src_path.name

if src_operator_scheme == "fs":
file_hash = self._sha256_file(src_operator, src_path)
Expand Down Expand Up @@ -1019,7 +1025,7 @@ def _create_metadata_and_json(
) -> tuple[Metadata | None, str]:
"""Create a metadata json string from a metadata object"""
metadata = None
metadata_dict = {"path": str(src_path)}
metadata_dict = {"path": clean_filename(src_path)}

try:
metadata = src_operator.stat(src_path)
Expand Down Expand Up @@ -1088,8 +1094,8 @@ def dump(
raise NotImplementedError("Dump is not implemented for this driver yet")

def _filename(self, path: PathBuf) -> str:
"""Extract filename from url or path"""
if path.startswith("oci://"):
"""Extract filename from url or path, stripping any query parameters"""
if isinstance(path, str) and path.startswith("oci://"):
oci_path = path[6:] # Remove "oci://" prefix
if ":" in oci_path:
repository, tag = oci_path.rsplit(":", 1)
Expand All @@ -1098,10 +1104,8 @@ def _filename(self, path: PathBuf) -> str:
else:
repo_name = oci_path.split("/")[-1] if "/" in oci_path else oci_path
return repo_name
elif path.startswith(("http://", "https://")):
return urlparse(path).path.split("/")[-1]
else:
return Path(path).name
return clean_filename(path)
Comment thread
raballew marked this conversation as resolved.

def _upload_artifact(self, storage, path: PathBuf, operator: Operator):
"""Upload artifact to storage"""
Expand Down Expand Up @@ -1636,17 +1640,12 @@ def _get_decompression_command(filename_or_url) -> str:
Determine the appropriate decompression command based on file extension

Args:
filename (str): Name of the file to check
filename_or_url (str): Name of the file or URL to check

Returns:
str: Decompression command ('zcat', 'xzcat', or 'cat' for uncompressed)
str: Decompression command ('zcat |', 'xzcat |', or '' for uncompressed)
"""
if type(filename_or_url) is PosixPath:
filename = filename_or_url.name
elif filename_or_url.startswith(("http://", "https://")):
filename = urlparse(filename_or_url).path.split("/")[-1]

filename = filename.lower()
filename = clean_filename(filename_or_url).lower()
if filename.endswith((".gz", ".gzip")):
return "zcat |"
elif filename.endswith(".xz"):
Comment on lines +1648 to 1651
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] _get_decompression_command does not handle .zst extension.

The implementation handles .gz/.gzip and .xz but has no branch for .zst. A .zst compressed image fetched from a signed URL would be treated as uncompressed, silently producing a corrupt flash.

Suggested fix: add a .zst branch returning "zstdcat |" and a corresponding test case.

AI-generated, human reviewed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. Adding .zst/zstdcat support is a valid enhancement but is out of scope for this PR, which focuses specifically on fixing signed URL handling. The .zst gap predates this PR and applies equally to non-signed URLs. I'd suggest tracking this as a separate issue.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on the missing .zst extension, but this is a pre-existing gap that was present before this PR. Adding .zst support would change the scope of this fix beyond the original issue (preserving URL query parameters for signed URLs). I would suggest tracking this as a separate issue/PR to keep changes focused.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,19 +242,24 @@ def stop(self):
def get_url(self):
return "http://exporter"

client.http = DummyService()
client.tftp = DummyService()
client.call = lambda *args, **kwargs: None
client.http = DummyService() # ty: ignore[unresolved-attribute]
client.tftp = DummyService() # ty: ignore[unresolved-attribute]
client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment]

captured = {}

def capture_perform(*args):
captured["image_url"] = args[3]
captured["should_download_to_httpd"] = args[4]
captured["oci_username"] = args[14]
captured["oci_password"] = args[15]
def capture_perform(
partition, block_device, path, image_url, should_download_to_httpd,
storage_thread, error_queue, cacert_file, insecure_tls, headers,
bearer_token, method, fls_version, fls_binary_url,
oci_username, oci_password, power_off=True,
):
captured["image_url"] = image_url
captured["should_download_to_httpd"] = should_download_to_httpd
captured["oci_username"] = oci_username
captured["oci_password"] = oci_password

client._perform_flash_operation = capture_perform
client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment]

client.flash(
"https://example.com/image.raw.xz",
Expand Down Expand Up @@ -428,6 +433,144 @@ def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
assert "File not found" in str(result)


def test_filename_strips_query_params_from_url_path():
"""Test _filename strips query parameters from paths with signed URL params"""
client = MockFlasherClient()

# Full HTTP URL
assert client._filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz"

# Full HTTP URL with query parameters (e.g. CloudFront signed URL)
assert (
client._filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz")
== "image.raw.xz"
)

# Path string with query parameters (as returned by operator_for_path after fix)
assert client._filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz"

# Plain path without query parameters
assert client._filename("/images/image.raw.xz") == "image.raw.xz"

# OCI path
assert client._filename("oci://quay.io/org/myimage:latest") == "myimage-latest"


def test_decompression_command_with_query_params():
"""Test _get_decompression_command handles paths with query parameters"""
from pathlib import PosixPath

from .client import _get_decompression_command

# Standard PosixPath
assert _get_decompression_command(PosixPath("/images/image.raw.xz")) == "xzcat |"
assert _get_decompression_command(PosixPath("/images/image.raw.gz")) == "zcat |"
assert _get_decompression_command(PosixPath("/images/image.raw")) == ""

# Full HTTP URL
assert _get_decompression_command("https://cdn.example.com/images/image.raw.xz") == "xzcat |"

# String path with query parameters (as returned by operator_for_path for signed URLs)
assert _get_decompression_command("/images/image.raw.xz?Expires=123&Signature=abc") == "xzcat |"
assert _get_decompression_command("/images/image.raw.gz?Expires=123") == "zcat |"
assert _get_decompression_command("/images/image.raw?Expires=123") == ""


def test_flash_signed_url_preserves_query_params():
"""Test that flash with a signed HTTP URL preserves query parameters for image_url"""
client = MockFlasherClient()

class DummyService:
def __init__(self):
self.storage = object()

def start(self):
pass

def stop(self):
pass

def get_url(self):
return "http://exporter"

client.http = DummyService() # ty: ignore[unresolved-attribute]
client.tftp = DummyService() # ty: ignore[unresolved-attribute]
client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment]

captured = {}

def capture_perform(partition, block_device, path, image_url, should_download_to_httpd, *rest):
captured["image_url"] = image_url
captured["should_download_to_httpd"] = should_download_to_httpd

client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment]

# Direct HTTP URL with query params (no force_exporter_http) should preserve full URL
signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
client.flash(signed_url, method="fls", fls_version="")

assert captured["image_url"] == signed_url
assert captured["should_download_to_httpd"] is False
Comment thread
raballew marked this conversation as resolved.
Comment thread
raballew marked this conversation as resolved.


def test_flash_bearer_token_signed_url_preserves_query_params():
"""Test that flash with force_exporter_http=True and bearer token preserves query params.

When a signed URL is used with a bearer token, the flash() method enters the
bearer token code path (lines 162-174 in client.py) which reconstructs the path
from parsed.path + '?' + parsed.query. This test verifies query params are preserved
and the path passed to the storage thread is correct.
"""
client = MockFlasherClient()

class DummyService:
def __init__(self):
self.storage = object()

def start(self):
pass

def stop(self):
pass

def get_url(self):
return "http://exporter"

def get_host(self):
return "127.0.0.1"

client.http = DummyService() # ty: ignore[unresolved-attribute]
client.tftp = DummyService() # ty: ignore[unresolved-attribute]
client.call = lambda *args, **kwargs: None # ty: ignore[invalid-assignment]

captured = {}

def capture_perform(partition, block_device, path, image_url, should_download_to_httpd, *rest):
captured["path"] = path
captured["image_url"] = image_url
captured["should_download_to_httpd"] = should_download_to_httpd

client._perform_flash_operation = capture_perform # ty: ignore[invalid-assignment]
# Mock the background transfer thread to prevent it from actually running
client._transfer_bg_thread = lambda *args, **kwargs: None # ty: ignore[invalid-assignment]

signed_url = "https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
client.flash(
signed_url,
force_exporter_http=True,
bearer_token="test-token-123",
method="fls",
fls_version="",
)

# With force_exporter_http=True and bearer_token, should download to httpd
assert captured["should_download_to_httpd"] is True
# The path should have query params preserved (reconstructed from parsed.path + '?' + parsed.query)
assert captured["path"] == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
# The image_url should point to the exporter with the clean filename (no query params)
assert captured["image_url"] == "http://exporter/image.raw.xz"


def test_resolve_flash_parameters():
"""Test flash parameter resolution for single file, partitions, and error cases"""
client = MockFlasherClient()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,42 @@ async def aclose(self):
pass


def clean_filename(path: PathBuf) -> str:
"""Extract a clean filename from a path or URL, stripping query parameters.

Handles paths returned by operator_for_path() which may contain
query parameters for signed URLs (e.g. /path/to/image.raw.xz?Expires=...&Signature=...).
"""
path_str = str(path)
if path_str.startswith(("http://", "https://")):
return urlparse(path_str).path.split("/")[-1]
if "?" in path_str:
path_str = path_str.split("?", 1)[0]
return Path(path_str).name


def path_with_query(parsed_url) -> str:
"""Reconstruct path preserving query parameters for signed URL support."""
if parsed_url.query:
return f"{parsed_url.path}?{parsed_url.query}"
return parsed_url.path


def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]:
"""Create an operator for the given path
"""Create an operator for the given path.

For HTTP URLs, query parameters are preserved in the returned path so that
signed URLs (e.g. CloudFront with Expires/Signature/Key-Pair-Id) work correctly.

Return a tuple of:
- the path
- the path (str for HTTP, Path for filesystem)
- the operator for the given path
- the scheme of the operator.
- the scheme of the operator
"""
if type(path) is str and path.startswith(("http://", "https://")):
parsed_url = urlparse(path)
operator = Operator("http", root="/", endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}")
return Path(parsed_url.path), operator, "http"
return path_with_query(parsed_url), operator, "http"
else:
return Path(path).resolve(), Operator("fs", root="/"), "fs"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,66 @@ def test_copy_and_rename_tracking(tmp_path):
assert "copied_dir" in created_paths
assert "renamed_dir" in created_paths
assert len(created_paths) == 4


def test_clean_filename():
"""Test clean_filename extracts filenames and strips query parameters"""
from pathlib import PosixPath

from .client import clean_filename

# Plain filesystem path
assert clean_filename("/images/image.raw.xz") == "image.raw.xz"
assert clean_filename(PosixPath("/images/image.raw.xz")) == "image.raw.xz"

# Filesystem path with query params (as returned by operator_for_path for signed URLs)
assert clean_filename("/images/image.raw.xz?Expires=123&Signature=abc") == "image.raw.xz"

# Full HTTP URL without query params
assert clean_filename("https://cdn.example.com/images/image.raw.xz") == "image.raw.xz"
assert clean_filename("http://cdn.example.com/images/image.raw.xz") == "image.raw.xz"

# Full HTTP URL with query params (e.g. CloudFront signed URL)
assert (
clean_filename("https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz")
== "image.raw.xz"
)

# Edge case: no directory component
assert clean_filename("image.raw.xz") == "image.raw.xz"
assert clean_filename("image.raw.xz?Expires=123") == "image.raw.xz"

# Edge case: compressed extensions
assert clean_filename("/path/to/image.raw.gz?token=abc") == "image.raw.gz"
assert clean_filename("/path/to/image.raw.gzip?token=abc") == "image.raw.gzip"

# Edge case: query params with unencoded slashes (e.g. base64 signatures)
assert clean_filename("/images/image.raw.xz?Expires=123&Signature=abc/def/ghi") == "image.raw.xz"


def test_operator_for_path_preserves_query_params():
"""Test that operator_for_path preserves query parameters for HTTP URLs"""
from .client import operator_for_path

# HTTP URL without query parameters
path, operator, scheme = operator_for_path("https://cdn.example.com/images/image.raw.xz")
assert scheme == "http"
assert path == "/images/image.raw.xz"

# HTTP URL with query parameters (e.g. CloudFront signed URL)
path, operator, scheme = operator_for_path(
"https://cdn.example.com/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
)
assert scheme == "http"
assert path == "/images/image.raw.xz?Expires=123&Signature=abc&Key-Pair-Id=xyz"
assert "Expires=123" in path
assert "Signature=abc" in path
assert "Key-Pair-Id=xyz" in path

# Filesystem path (use resolve() for the expected value since macOS
# resolves /tmp to /private/tmp)
from pathlib import Path

path, operator, scheme = operator_for_path("/tmp/image.raw.xz")
assert scheme == "fs"
assert path == Path("/tmp/image.raw.xz").resolve()
Loading
Loading