45 changes: 45 additions & 0 deletions src/agents/run_config.py
@@ -33,6 +33,9 @@
DEFAULT_MAX_TURNS = 10
DEFAULT_MAX_MANIFEST_ENTRY_CONCURRENCY = 4
DEFAULT_MAX_LOCAL_DIR_FILE_CONCURRENCY = 4
DEFAULT_MAX_ARCHIVE_INPUT_BYTES = 1024 * 1024 * 1024
DEFAULT_MAX_ARCHIVE_EXTRACTED_BYTES = 4 * 1024 * 1024 * 1024
DEFAULT_MAX_ARCHIVE_MEMBERS = 100_000


def _default_trace_include_sensitive_data() -> bool:
@@ -129,6 +132,40 @@ def validate(self) -> None:
raise ValueError("concurrency_limits.local_dir_files must be at least 1")


@dataclass
class SandboxArchiveLimits:
"""Resource limits for sandbox archive extraction."""

max_input_bytes: int | None = DEFAULT_MAX_ARCHIVE_INPUT_BYTES
"""Maximum archive input bytes accepted by `BaseSandboxSession.extract()`.

Set to `None` to disable this input-size limit.
"""

max_extracted_bytes: int | None = DEFAULT_MAX_ARCHIVE_EXTRACTED_BYTES
"""Maximum declared bytes that an archive may extract.

Set to `None` to disable this extracted-size limit.
"""

max_members: int | None = DEFAULT_MAX_ARCHIVE_MEMBERS
"""Maximum number of extractable archive members.

Set to `None` to disable this member-count limit.
"""

def __post_init__(self) -> None:
self.validate()

def validate(self) -> None:
if self.max_input_bytes is not None and self.max_input_bytes < 1:
raise ValueError("archive_limits.max_input_bytes must be at least 1")
if self.max_extracted_bytes is not None and self.max_extracted_bytes < 1:
raise ValueError("archive_limits.max_extracted_bytes must be at least 1")
if self.max_members is not None and self.max_members < 1:
raise ValueError("archive_limits.max_members must be at least 1")


@dataclass
class SandboxRunConfig:
"""Grouped sandbox runtime configuration for `Runner`."""
@@ -154,6 +191,13 @@
concurrency_limits: SandboxConcurrencyLimits = field(default_factory=SandboxConcurrencyLimits)
"""Concurrency limits for sandbox materialization work."""

archive_limits: SandboxArchiveLimits | None = None
"""Resource limits for sandbox archive extraction.

Defaults to `None`, which preserves the existing behavior: no SDK archive resource limits.
Pass `SandboxArchiveLimits()` to opt in to the SDK default limits.
"""


@dataclass
class RunConfig:
@@ -316,6 +360,7 @@ class RunOptions(TypedDict, Generic[TContext]):
"ReasoningItemIdPolicy",
"RunConfig",
"RunOptions",
"SandboxArchiveLimits",
"SandboxConcurrencyLimits",
"SandboxRunConfig",
"ToolExecutionConfig",
3 changes: 2 additions & 1 deletion src/agents/sandbox/__init__.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from ..run_config import SandboxConcurrencyLimits, SandboxRunConfig
from ..run_config import SandboxArchiveLimits, SandboxConcurrencyLimits, SandboxRunConfig
from .capabilities import Capability
from .config import MemoryGenerateConfig, MemoryLayoutConfig, MemoryReadConfig
from .entries import Dir, LocalFile
@@ -50,6 +50,7 @@
"RemoteSnapshotSpec",
"Permissions",
"SandboxAgent",
"SandboxArchiveLimits",
"SandboxPathGrant",
"SandboxConcurrencyLimits",
"SandboxError",
23 changes: 18 additions & 5 deletions src/agents/sandbox/runtime_session_manager.py
@@ -9,7 +9,7 @@
from typing import Any, Generic, cast

from ..agent import Agent
from ..run_config import SandboxConcurrencyLimits, SandboxRunConfig
from ..run_config import SandboxArchiveLimits, SandboxConcurrencyLimits, SandboxRunConfig
from ..run_context import TContext
from ..run_state import (
RunState,
@@ -286,10 +286,12 @@ async def _create_resources(
) -> _SandboxSessionResources:
sandbox_config = self._require_sandbox_config()
concurrency_limits = self._resolve_concurrency_limits()
archive_limits = self._resolve_archive_limits()
if sandbox_config.session is not None:
self._configure_session_materialization(
self._configure_session(
sandbox_config.session,
concurrency_limits=concurrency_limits,
archive_limits=archive_limits,
)
running = await sandbox_config.session.running()
manifest_update = self._process_live_session_manifest(
@@ -341,9 +343,10 @@
)
with span_cm:
resumed_session = await client.resume(explicit_state)
self._configure_session_materialization(
self._configure_session(
resumed_session,
concurrency_limits=concurrency_limits,
archive_limits=archive_limits,
)
return _SandboxSessionResources(
session=resumed_session,
@@ -383,9 +386,10 @@ async def _create_resources(
manifest=effective_manifest,
options=options,
)
self._configure_session_materialization(
self._configure_session(
session,
concurrency_limits=concurrency_limits,
archive_limits=archive_limits,
)
self._ensure_session_manifest_has_run_as_user(session=session, agent=agent)
return _SandboxSessionResources(session=session, client=client, owns_session=True)
@@ -396,13 +400,22 @@ def _resolve_concurrency_limits(self) -> SandboxConcurrencyLimits:
limits.validate()
return limits

def _configure_session_materialization(
def _resolve_archive_limits(self) -> SandboxArchiveLimits | None:
sandbox_config = self._require_sandbox_config()
limits = sandbox_config.archive_limits
if limits is not None:
limits.validate()
return limits

def _configure_session(
self,
session: BaseSandboxSession,
*,
concurrency_limits: SandboxConcurrencyLimits,
archive_limits: SandboxArchiveLimits | None,
) -> None:
session._set_concurrency_limits(concurrency_limits)
session._set_archive_limits(archive_limits)
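
The session-side `_set_archive_limits` hook is not shown in this diff; presumably it mirrors `_set_concurrency_limits` and simply stores the limits for later use by `extract()`. A hypothetical sketch:

```python
# Hypothetical session-side counterpart; not part of this diff.
class BaseSandboxSession:
    _archive_limits: SandboxArchiveLimits | None = None

    def _set_archive_limits(self, limits: SandboxArchiveLimits | None) -> None:
        # Stored here, then threaded into extract_tar_archive()/extract_zip_archive().
        self._archive_limits = limits
```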

def _resume_state_payload_for_agent(
self,
173 changes: 167 additions & 6 deletions src/agents/sandbox/session/archive_extraction.py
@@ -10,9 +10,10 @@
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Literal, cast

from ...run_config import SandboxArchiveLimits
from ..errors import ExecNonZeroError, WorkspaceArchiveWriteError
from ..files import EntryKind, FileEntry
from ..util.tar_utils import UnsafeTarMemberError, safe_tar_member_rel_path, validate_tarfile
from ..util.tar_utils import UnsafeTarMemberError, safe_tar_member_rel_path


class UnsafeZipMemberError(ValueError):
@@ -24,6 +25,24 @@ def __init__(self, *, member: str, reason: str) -> None:
self.reason = reason


class ArchiveResourceLimitError(ValueError):
"""Raised when an archive exceeds extraction resource limits."""

def __init__(
self,
*,
reason: str,
limit: int,
actual: int,
member: str | None = None,
) -> None:
super().__init__(reason)
self.reason = reason
self.limit = limit
self.actual = actual
self.member = member


class WorkspaceArchiveExtractor:
def __init__(
self,
@@ -42,12 +61,16 @@ async def extract_tar_archive(
archive_path: Path,
destination_root: Path,
data: io.IOBase,
archive_limits: SandboxArchiveLimits | None = None,
) -> None:
child_entry_cache: dict[Path, dict[str, EntryKind]] = {}
try:
with tarfile.open(fileobj=data, mode="r:*") as archive:
validate_tarfile(archive, allow_symlinks=False)
for member in archive.getmembers():
with tarfile.open(fileobj=data, mode="r|*") as archive:
validate_tar_archive_for_extraction(archive, archive_limits=archive_limits)

data.seek(0)
with tarfile.open(fileobj=data, mode="r|*") as archive:
for member in archive:
rel_path = safe_tar_member_rel_path(member)
if rel_path is None:
continue
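
The switch from `r:*` plus `getmembers()` to two `r|*` passes is worth calling out: stream mode walks the tar sequentially instead of materializing the full member list, so the whole archive is validated before the second pass writes anything, at the cost of one rewind. A standalone sketch of the pattern:

```python
import io
import tarfile

data = io.BytesIO(raw_archive_bytes)  # raw_archive_bytes is a stand-in

# Pass 1: sequential scan with no extraction side effects.
with tarfile.open(fileobj=data, mode="r|*") as archive:
    for member in archive:
        pass  # limit and safety checks go here

# Stream mode cannot rewind the archive, so rewind the underlying file object.
data.seek(0)

# Pass 2: extract, knowing the whole archive already passed validation.
with tarfile.open(fileobj=data, mode="r|*") as archive:
    for member in archive:
        pass  # extraction goes here
```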
@@ -99,6 +122,12 @@ async def extract_tar_archive(
context={"member": e.member, "reason": e.reason},
cause=e,
) from e
except ArchiveResourceLimitError as e:
raise WorkspaceArchiveWriteError(
path=archive_path,
context=_archive_resource_limit_context(e),
cause=e,
) from e
except (tarfile.TarError, OSError) as e:
raise WorkspaceArchiveWriteError(path=archive_path, cause=e) from e
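
Callers never see `ArchiveResourceLimitError` directly; it surfaces as the cause of a `WorkspaceArchiveWriteError`. A hedged caller-side sketch, inside an async context (`extractor`, the paths, and `data` are stand-ins; import paths assume the package layout in this diff):

```python
from agents.run_config import SandboxArchiveLimits
from agents.sandbox.errors import WorkspaceArchiveWriteError
from agents.sandbox.session.archive_extraction import ArchiveResourceLimitError

try:
    await extractor.extract_tar_archive(
        archive_path=archive_path,
        destination_root=destination_root,
        data=data,
        archive_limits=SandboxArchiveLimits(max_extracted_bytes=1024 * 1024),
    )
except WorkspaceArchiveWriteError as err:
    cause = err.__cause__
    if isinstance(cause, ArchiveResourceLimitError):
        print(cause.reason, cause.limit, cause.actual, cause.member)
```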

@@ -108,12 +137,13 @@ async def extract_zip_archive(
archive_path: Path,
destination_root: Path,
data: io.IOBase,
archive_limits: SandboxArchiveLimits | None = None,
) -> None:
child_entry_cache: dict[Path, dict[str, EntryKind]] = {}
try:
with zipfile_compatible_stream(data) as zip_data:
with zipfile.ZipFile(zip_data) as archive:
validate_zipfile(archive)
validate_zipfile(archive, archive_limits=archive_limits)
for member in archive.infolist():
rel_path = safe_zip_member_rel_path(member)
if rel_path is None:
@@ -158,6 +188,12 @@ async def extract_zip_archive(
context={"member": e.member, "reason": e.reason},
cause=e,
) from e
except ArchiveResourceLimitError as e:
raise WorkspaceArchiveWriteError(
path=archive_path,
context=_archive_resource_limit_context(e),
cause=e,
) from e
except ValueError as e:
raise WorkspaceArchiveWriteError(path=archive_path, cause=e) from e
except (zipfile.BadZipFile, OSError) as e:
@@ -302,15 +338,140 @@ def safe_zip_member_rel_path(member: zipfile.ZipInfo) -> Path | None:
return Path(*rel.parts)


def validate_zipfile(archive: zipfile.ZipFile) -> None:
def _archive_resource_limit_context(error: ArchiveResourceLimitError) -> dict[str, object]:
context: dict[str, object] = {
"reason": error.reason,
"limit": error.limit,
"actual": error.actual,
}
if error.member is not None:
context["member"] = error.member
return context
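
The resulting context payload on the wrapped error then looks like this (illustrative values):

```python
{
    "reason": "archive member count exceeds limit",
    "limit": 100_000,
    "actual": 100_001,
    "member": "some/deeply/nested/file.txt",
}
```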


def _check_archive_member_count(
*,
count: int,
member: str,
archive_limits: SandboxArchiveLimits | None,
) -> None:
if archive_limits is None or archive_limits.max_members is None:
return

if count > archive_limits.max_members:
raise ArchiveResourceLimitError(
reason="archive member count exceeds limit",
limit=archive_limits.max_members,
actual=count,
member=member,
)


def _check_archive_extracted_bytes(
*,
total: int,
member: str,
archive_limits: SandboxArchiveLimits | None,
) -> None:
if archive_limits is None or archive_limits.max_extracted_bytes is None:
return

if total > archive_limits.max_extracted_bytes:
raise ArchiveResourceLimitError(
reason="archive extracted size exceeds limit",
limit=archive_limits.max_extracted_bytes,
actual=total,
member=member,
)


def validate_tar_archive_for_extraction(
archive: tarfile.TarFile,
*,
archive_limits: SandboxArchiveLimits | None = None,
) -> None:
members_by_rel_path: dict[Path, tarfile.TarInfo] = {}
descendant_by_parent_path: dict[Path, tarfile.TarInfo] = {}
member_count = 0
extracted_bytes = 0

for member in archive:
rel_path = safe_tar_member_rel_path(member)
if rel_path is None:
continue

member_count += 1
_check_archive_member_count(
count=member_count,
member=member.name,
archive_limits=archive_limits,
)
if member.isreg():
extracted_bytes += max(member.size, 0)
_check_archive_extracted_bytes(
total=extracted_bytes,
member=member.name,
archive_limits=archive_limits,
)

previous = members_by_rel_path.get(rel_path)
if previous is not None and not (previous.isdir() and member.isdir()):
raise UnsafeTarMemberError(
member=member.name,
reason=f"duplicate archive path: {rel_path.as_posix()}",
)

for parent in rel_path.parents:
if parent == Path():
break
parent_member = members_by_rel_path.get(parent)
if parent_member is not None and not parent_member.isdir():
raise UnsafeTarMemberError(
member=member.name,
reason=f"archive path descends through non-directory: {parent.as_posix()}",
)

if not member.isdir():
descendant = descendant_by_parent_path.get(rel_path)
if descendant is not None:
raise UnsafeTarMemberError(
member=descendant.name,
reason=f"archive path descends through non-directory: {rel_path.as_posix()}",
)

members_by_rel_path[rel_path] = member
for parent in rel_path.parents:
if parent == Path():
break
descendant_by_parent_path.setdefault(parent, member)
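
A self-contained sketch exercising the validator above (assuming it is importable from `agents.sandbox.session.archive_extraction`):

```python
import io
import tarfile

from agents.run_config import SandboxArchiveLimits
from agents.sandbox.session.archive_extraction import (
    ArchiveResourceLimitError,
    validate_tar_archive_for_extraction,
)

# Build a three-member tar in memory.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    for name in ("a.txt", "b.txt", "c.txt"):
        payload = b"x"
        info = tarfile.TarInfo(name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
buf.seek(0)

# A two-member cap should trip on the third member.
with tarfile.open(fileobj=buf, mode="r|*") as tar:
    try:
        validate_tar_archive_for_extraction(
            tar, archive_limits=SandboxArchiveLimits(max_members=2)
        )
    except ArchiveResourceLimitError as err:
        print(err.reason, err.limit, err.actual)  # ... 2 3
```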


def validate_zipfile(
archive: zipfile.ZipFile,
*,
archive_limits: SandboxArchiveLimits | None = None,
) -> None:
members_by_rel_path: dict[Path, zipfile.ZipInfo] = {}
members: list[tuple[zipfile.ZipInfo, Path]] = []
extracted_bytes = 0

for member in archive.infolist():
rel_path = safe_zip_member_rel_path(member)
if rel_path is None:
continue

_check_archive_member_count(
count=len(members) + 1,
member=member.filename,
archive_limits=archive_limits,
)
extracted_bytes += max(member.file_size, 0)
_check_archive_extracted_bytes(
total=extracted_bytes,
member=member.filename,
archive_limits=archive_limits,
)

previous = members_by_rel_path.get(rel_path)
if previous is not None and not (previous.is_dir() and member.is_dir()):
raise UnsafeZipMemberError(