Skip to content
1 change: 1 addition & 0 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,7 @@ def _scan_file_internal(path: str, config: dict[str, Any] | None = None) -> Scan
"safetensors": "safetensors",
"tensorflow_directory": "tf_savedmodel",
"protobuf": "tf_savedmodel",
"tar": "tar",
"zip": "zip",
"onnx": "onnx",
"gguf": "gguf",
Expand Down
224 changes: 197 additions & 27 deletions modelaudit/scanners/tar_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
]

DEFAULT_MAX_TAR_ENTRY_SIZE = 1024 * 1024 * 1024
DEFAULT_MAX_DECOMPRESSED_BYTES = 512 * 1024 * 1024
DEFAULT_MAX_DECOMPRESSION_RATIO = 250.0

_GZIP_MAGIC = b"\x1f\x8b"
_BZIP2_MAGIC = b"BZh"
_XZ_MAGIC = b"\xfd7zXZ\x00"


class TarScanner(BaseScanner):
Expand All @@ -49,17 +55,18 @@ def __init__(self, config: dict[str, Any] | None = None) -> None:
super().__init__(config)
self.max_depth = self.config.get("max_tar_depth", 5)
self.max_entries = self.config.get("max_tar_entries", 10000)
self.max_decompressed_bytes = int(
self.config.get("compressed_max_decompressed_bytes", DEFAULT_MAX_DECOMPRESSED_BYTES),
)
self.max_decompression_ratio = float(
self.config.get("compressed_max_decompression_ratio", DEFAULT_MAX_DECOMPRESSION_RATIO),
)

@classmethod
def can_handle(cls, path: str) -> bool:
if not os.path.isfile(path):
return False

# Check for compound extensions like .tar.gz
path_lower = path.lower()
if not any(path_lower.endswith(ext) for ext in cls.supported_extensions):
return False

try:
return tarfile.is_tarfile(path)
except Exception:
Expand Down Expand Up @@ -170,6 +177,181 @@ def _extract_member_to_tempfile(
assert tmp_path is not None
return tmp_path, total_size

@staticmethod
def _detect_compressed_tar_wrapper(path: str) -> str | None:
    """Sniff a compression wrapper around a TAR archive from magic bytes.

    Detection is content-based (never filename-based) so mislabeled files
    are still classified correctly. Returns "gzip", "bzip2", "xz", or None
    when no known signature matches.
    """
    with open(path, "rb") as stream:
        magic = stream.read(6)  # longest signature (_XZ_MAGIC) is 6 bytes

    codec_signatures = (
        ("gzip", _GZIP_MAGIC),
        ("bzip2", _BZIP2_MAGIC),
        ("xz", _XZ_MAGIC),
    )
    for codec, signature in codec_signatures:
        if magic.startswith(signature):
            return codec
    return None

@staticmethod
def _finalize_tar_stream_size(consumed_size: int) -> int:
"""Return the minimum TAR stream size after EOF blocks and record padding."""
total_size = max(consumed_size + (2 * tarfile.BLOCKSIZE), tarfile.RECORDSIZE)
return ((total_size + tarfile.RECORDSIZE - 1) // tarfile.RECORDSIZE) * tarfile.RECORDSIZE

def _add_compressed_wrapper_limit_check(
    self,
    result: ScanResult,
    *,
    passed: bool,
    path: str,
    message: str,
    decompressed_size: int,
    compressed_size: int,
    compression_codec: str,
    actual_ratio: float | None = None,
) -> None:
    """Record a compressed-wrapper policy check with a uniform details payload.

    Failing checks are reported as S902 warnings; passing checks carry no
    severity or rule code. ``actual_ratio`` is included only when supplied.
    """
    details: dict[str, Any] = {
        "decompressed_size": decompressed_size,
        "compressed_size": compressed_size,
        "max_decompressed_size": self.max_decompressed_bytes,
        "max_ratio": self.max_decompression_ratio,
        "compression": compression_codec,
    }
    if actual_ratio is not None:
        details["actual_ratio"] = actual_ratio

    if passed:
        severity = None
        rule_code = None
    else:
        severity = IssueSeverity.WARNING
        rule_code = "S902"

    result.add_check(
        name="Compressed Wrapper Decompression Limits",
        passed=passed,
        message=message,
        severity=severity,
        location=path,
        details=details,
        rule_code=rule_code,
    )

def _preflight_tar_archive(self, path: str, result: ScanResult) -> bool:
    """Stream TAR headers once to enforce entry-count and wrapper-size limits before extraction.

    Walks the archive member-by-member (without extracting payloads) so that
    both the entry-count ceiling and, for gzip/bzip2/xz-wrapped archives, the
    decompressed-size and decompression-ratio ceilings are enforced before
    any member is written to disk.

    Returns True when the archive passes every preflight check; returns
    False after recording a failing check on *result* otherwise.
    """
    compressed_size = os.path.getsize(path)
    compression_codec = self._detect_compressed_tar_wrapper(path)

    def enforce_wrapper_limits(stream_size: int, *, record_pass: bool) -> bool:
        # Shared byte-ceiling and ratio-ceiling enforcement, used both
        # mid-walk (early abort, no passing check recorded) and after the
        # walk (final verdict, passing check recorded). Keeping this in one
        # place prevents the two call sites from drifting apart.
        ratio = (stream_size / compressed_size) if compressed_size > 0 else 0.0

        if stream_size > self.max_decompressed_bytes:
            self._add_compressed_wrapper_limit_check(
                result,
                passed=False,
                path=path,
                message=(
                    f"Decompressed size exceeded limit "
                    f"({stream_size} > {self.max_decompressed_bytes})"
                ),
                decompressed_size=stream_size,
                compressed_size=compressed_size,
                compression_codec=compression_codec,
                actual_ratio=ratio,
            )
            return False

        if compressed_size > 0 and ratio > self.max_decompression_ratio:
            self._add_compressed_wrapper_limit_check(
                result,
                passed=False,
                path=path,
                message=(
                    "Decompression ratio exceeded limit "
                    f"({ratio:.1f}x > {self.max_decompression_ratio:.1f}x)"
                ),
                decompressed_size=stream_size,
                compressed_size=compressed_size,
                compression_codec=compression_codec,
                actual_ratio=ratio,
            )
            return False

        if record_pass:
            self._add_compressed_wrapper_limit_check(
                result,
                passed=True,
                path=path,
                message=(
                    f"Decompressed size/ratio are within limits ({stream_size} bytes, {ratio:.1f}x)"
                ),
                decompressed_size=stream_size,
                compressed_size=compressed_size,
                compression_codec=compression_codec,
                actual_ratio=ratio,
            )
        return True

    entry_count = 0
    consumed_size = 0

    with tarfile.open(path, "r:*") as tar:
        while tar.next() is not None:
            entry_count += 1
            if entry_count > self.max_entries:
                result.add_check(
                    name="Entry Count Limit Check",
                    passed=False,
                    message=f"TAR file contains too many entries ({entry_count} > {self.max_entries})",
                    rule_code="S902",
                    severity=IssueSeverity.WARNING,
                    location=path,
                    details={"entries": entry_count, "max_entries": self.max_entries},
                )
                return False

            if compression_codec is not None:
                # tar.offset is the position within the *decompressed* stream,
                # so it approximates how many bytes the wrapper has yielded.
                consumed_size = max(consumed_size, tar.offset)
                if not enforce_wrapper_limits(
                    self._finalize_tar_stream_size(consumed_size),
                    record_pass=False,
                ):
                    return False

        result.add_check(
            name="Entry Count Limit Check",
            passed=True,
            message=f"Entry count ({entry_count}) is within limits",
            location=path,
            details={"entries": entry_count, "max_entries": self.max_entries},
            rule_code=None,
        )

        if compression_codec is not None:
            consumed_size = max(consumed_size, tar.offset)
            if not enforce_wrapper_limits(
                self._finalize_tar_stream_size(consumed_size),
                record_pass=True,
            ):
                return False

    return True

def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
result = ScanResult(scanner_name=self.name)
contents: list[dict[str, Any]] = []
Expand All @@ -195,30 +377,18 @@ def _scan_tar_file(self, path: str, depth: int = 0) -> ScanResult:
rule_code=None, # Passing check
)

if not self._preflight_tar_archive(path, result):
result.metadata["contents"] = contents
result.metadata["file_size"] = os.path.getsize(path)
result.finish(success=not result.has_errors)
return result

with tarfile.open(path, "r:*") as tar:
members = tar.getmembers()
if len(members) > self.max_entries:
result.add_check(
name="Entry Count Limit Check",
passed=False,
message=f"TAR file contains too many entries ({len(members)} > {self.max_entries})",
rule_code="S902",
severity=IssueSeverity.WARNING,
location=path,
details={"entries": len(members), "max_entries": self.max_entries},
)
return result
else:
result.add_check(
name="Entry Count Limit Check",
passed=True,
message=f"Entry count ({len(members)}) is within limits",
location=path,
details={"entries": len(members), "max_entries": self.max_entries},
rule_code=None, # Passing check
)
while True:
member = tar.next()
if member is None:
break

for member in members:
name = member.name
temp_base = os.path.join(tempfile.gettempdir(), "extract_tar")
resolved_name, is_safe = sanitize_archive_path(name, temp_base)
Expand Down
13 changes: 13 additions & 0 deletions modelaudit/utils/file/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickletools
import re
import struct
import tarfile
import zipfile
from pathlib import Path, PurePosixPath

Expand Down Expand Up @@ -215,6 +216,14 @@ def is_torchserve_mar_archive(path: str) -> bool:
return False


def _is_tar_archive(path: str) -> bool:
"""Return whether a path is a TAR archive, including compressed wrappers."""
try:
return tarfile.is_tarfile(path)
except Exception:
return False


def is_zipfile(path: str) -> bool:
"""Check if file is a ZIP by reading the signature."""
file_path = Path(path)
Expand Down Expand Up @@ -598,10 +607,14 @@ def detect_file_format(path: str) -> str:

compression_format = _detect_compression_format(header)
if ext in _COMPRESSED_EXTENSION_CODECS:
if _is_tar_archive(path):
return "tar"
expected_codec = _COMPRESSED_EXTENSION_CODECS[ext]
if compression_format == expected_codec:
return "compressed"
return "unknown"
if _is_tar_archive(path):
return "tar"
# Check ZIP magic first (for .pt/.pth files that are actually zips)
if magic4.startswith(b"PK"):
if ext == ".mar" and is_torchserve_mar_archive(path):
Expand Down
Loading
Loading