Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ export LLM_API_KEY="your-api-key"
export LLM_API_BASE="your-api-base-url" # if using a local model, e.g. Ollama, LMStudio
export PERPLEXITY_API_KEY="your-api-key" # for search capabilities
export STRIX_REASONING_EFFORT="high" # control thinking effort (default: high, quick scan: medium)
export STRIX_SANDBOX_EXTRA_HOSTS="test.internal.lan=host-gateway" # optional Docker hosts entries
```

> [!NOTE]
Expand Down
1 change: 1 addition & 0 deletions strix/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Config:
strix_runtime_backend = "docker"
strix_sandbox_execution_timeout = "120"
strix_sandbox_connect_timeout = "10"
strix_sandbox_extra_hosts = None

# Telemetry
strix_telemetry = "1"
Expand Down
49 changes: 39 additions & 10 deletions strix/runtime/docker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import os
import secrets
import socket
import subprocess
import tarfile
import time
from io import BytesIO
from pathlib import Path
from typing import cast
from urllib.parse import urlparse

import docker
import httpx
Expand Down Expand Up @@ -47,14 +51,14 @@ def _find_available_port(self) -> int:

def _get_scan_id(self, agent_id: str) -> str:
try:
from strix.telemetry.tracer import get_global_tracer
from strix.telemetry.tracer import get_global_tracer # noqa: PLC0415

tracer = get_global_tracer()
if tracer and tracer.scan_config:
return str(tracer.scan_config.get("scan_id", "default-scan"))
except (ImportError, AttributeError):
pass
return f"scan-{agent_id.split('-')[0]}"
return f"scan-{agent_id.split('-', maxsplit=1)[0]}"

def _verify_image_available(self, image_name: str, max_retries: int = 3) -> None:
for attempt in range(max_retries):
Expand Down Expand Up @@ -108,6 +112,33 @@ def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None
"Container initialization timed out. Please try again.",
)

def _get_extra_hosts(self) -> dict[str, str]:
extra_hosts = {HOST_GATEWAY_HOSTNAME: "host-gateway"}
configured_hosts = Config.get("strix_sandbox_extra_hosts")
if not configured_hosts:
return extra_hosts

for raw_host_entry in configured_hosts.split(","):
host_entry = raw_host_entry.strip()
if not host_entry:
continue

parts = [part.strip() for part in host_entry.split("=")]
if len(parts) != 2:
raise ValueError(
"STRIX_SANDBOX_EXTRA_HOSTS entries must use hostname=address format"
)

hostname, address = parts
if not hostname or not address:
raise ValueError(
"STRIX_SANDBOX_EXTRA_HOSTS entries must include both hostname and address"
)

extra_hosts[hostname] = address

return extra_hosts
Comment thread
Dawn-Fighter marked this conversation as resolved.

def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
container_name = f"strix-scan-{scan_id}"
image_name = Config.get("strix_image")
Expand Down Expand Up @@ -150,7 +181,7 @@ def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
"STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout),
"HOST_GATEWAY": HOST_GATEWAY_HOSTNAME,
},
extra_hosts={HOST_GATEWAY_HOSTNAME: "host-gateway"},
extra_hosts=self._get_extra_hosts(),
tty=True,
)

Expand All @@ -164,6 +195,11 @@ def _create_container(self, scan_id: str, max_retries: int = 2) -> Container:
self._tool_server_token = None
self._caido_port = None
time.sleep(2**attempt)
except ValueError as e:
raise SandboxInitializationError(
"Invalid Docker sandbox host mapping",
str(e),
) from e
else:
return container

Expand Down Expand Up @@ -222,9 +258,6 @@ def _get_or_create_container(self, scan_id: str) -> Container:
def _copy_local_directory_to_container(
self, container: Container, local_path: str, target_name: str | None = None
) -> None:
import tarfile
from io import BytesIO

try:
local_path_obj = Path(local_path).resolve()
if not local_path_obj.exists() or not local_path_obj.is_dir():
Expand Down Expand Up @@ -312,8 +345,6 @@ async def get_sandbox_url(self, container_id: str, port: int) -> str:
def _resolve_docker_host(self) -> str:
docker_host = os.getenv("DOCKER_HOST", "")
if docker_host:
from urllib.parse import urlparse

parsed = urlparse(docker_host)
if parsed.scheme in ("tcp", "http", "https") and parsed.hostname:
return parsed.hostname
Expand Down Expand Up @@ -342,8 +373,6 @@ def cleanup(self) -> None:
if container_name is None:
return

import subprocess

subprocess.Popen( # noqa: S603
["docker", "rm", "-f", container_name], # noqa: S607
stdout=subprocess.DEVNULL,
Expand Down
87 changes: 87 additions & 0 deletions tests/runtime/test_docker_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from docker.errors import NotFound

from strix.runtime import SandboxInitializationError
from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME, DockerRuntime


def test_get_extra_hosts_includes_host_gateway(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("STRIX_SANDBOX_EXTRA_HOSTS", raising=False)

runtime = DockerRuntime.__new__(DockerRuntime)

assert runtime._get_extra_hosts() == {HOST_GATEWAY_HOSTNAME: "host-gateway"}


def test_get_extra_hosts_merges_configured_entries(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv(
"STRIX_SANDBOX_EXTRA_HOSTS",
"test.internal.lan=host-gateway, api.local = 192.168.1.20",
)

runtime = DockerRuntime.__new__(DockerRuntime)

assert runtime._get_extra_hosts() == {
HOST_GATEWAY_HOSTNAME: "host-gateway",
"test.internal.lan": "host-gateway",
"api.local": "192.168.1.20",
}


def test_get_extra_hosts_rejects_invalid_entries(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("STRIX_SANDBOX_EXTRA_HOSTS", "test.internal.lan")

runtime = DockerRuntime.__new__(DockerRuntime)

with pytest.raises(ValueError, match="hostname=address"):
runtime._get_extra_hosts()


def test_get_extra_hosts_rejects_multiple_equals(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("STRIX_SANDBOX_EXTRA_HOSTS", "test.internal.lan==host-gateway")

runtime = DockerRuntime.__new__(DockerRuntime)

with pytest.raises(ValueError, match="hostname=address"):
runtime._get_extra_hosts()


def test_create_container_passes_configured_extra_hosts(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("STRIX_SANDBOX_EXTRA_HOSTS", "test.internal.lan=host-gateway")

run = MagicMock(return_value=object())
containers = SimpleNamespace(get=MagicMock(side_effect=NotFound("missing")), run=run)
runtime = DockerRuntime.__new__(DockerRuntime)
runtime.client = SimpleNamespace(containers=containers)
runtime._verify_image_available = MagicMock()
runtime._find_available_port = MagicMock(side_effect=[12345, 12346])
runtime._wait_for_tool_server = MagicMock()
runtime._scan_container = None

runtime._create_container("scan-id")

assert run.call_args.kwargs["extra_hosts"] == {
HOST_GATEWAY_HOSTNAME: "host-gateway",
"test.internal.lan": "host-gateway",
}


def test_create_container_wraps_invalid_extra_hosts(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("STRIX_SANDBOX_EXTRA_HOSTS", "test.internal.lan")

run = MagicMock()
containers = SimpleNamespace(get=MagicMock(side_effect=NotFound("missing")), run=run)
runtime = DockerRuntime.__new__(DockerRuntime)
runtime.client = SimpleNamespace(containers=containers)
runtime._verify_image_available = MagicMock()
runtime._find_available_port = MagicMock(side_effect=[12345, 12346])
runtime._wait_for_tool_server = MagicMock()
runtime._scan_container = None

with pytest.raises(SandboxInitializationError, match="Invalid Docker sandbox host mapping"):
runtime._create_container("scan-id")

run.assert_not_called()