Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion app/schemas/tool/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from app.exceptions.exception import ValidateFailedError
from app.schemas.tool.authentication import Authentication, AuthenticationType
from app.services.tool.url_security import validate_openapi_server_url


# This function code from the Open Source Project TaskingAI.
Expand All @@ -31,6 +32,11 @@ def validate_openapi_schema(schema: Dict):
if len(schema["servers"]) != 1:
raise ValidateFailedError("Exactly one server is allowed in action schema")

server_url = schema["servers"][0].get("url") if isinstance(schema["servers"][0], dict) else None
if not server_url or not isinstance(server_url, str):
raise ValidateFailedError("Action schema server URL is required")
validate_openapi_server_url(server_url)

# check each path method has a valid description and operationId
for path, methods in schema["paths"].items():
for method, details in methods.items():
Expand All @@ -56,7 +62,7 @@ def validate_openapi_schema(schema: Dict):

if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", details["operationId"]):
raise ValidateFailedError(
f'Invalid operationId {details["operationId"]} in {method} {path} in action schema'
f"Invalid operationId {details['operationId']} in {method} {path} in action schema"
)

return schema
Expand Down
42 changes: 30 additions & 12 deletions app/services/tool/openapi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from app.schemas.tool.authentication import Authentication, AuthenticationType
from app.schemas.tool.action import ActionMethod, ActionBodyType, ActionParam
from app.services.tool.url_security import UnsafeActionURLError, validate_action_url


# This function code from the Open Source Project TaskingAI.
Expand Down Expand Up @@ -139,18 +140,35 @@ def call_action_api(

logging.info(f"call_action_api url={url} request kwargs: {request_kwargs}")

with requests.request(method.value, url, **request_kwargs) as response:
response_content_type = response.headers.get("Content-Type", "").lower()
if "application/json" in response_content_type:
data = response.json()
else:
data = response.text
if response.status_code == 500:
error_message = f"API call failed with status {response.status_code}"
if data:
error_message += f": {data}"
return {"status": response.status_code, "error": error_message}
return {"status": response.status_code, "data": data}
timeout = float(os.environ.get("ACTION_HTTP_TIMEOUT", "10"))
request_kwargs["timeout"] = timeout
request_kwargs["allow_redirects"] = False

current_url = url
for _ in range(5):
validate_action_url(current_url)
with requests.request(method.value, current_url, **request_kwargs) as response:
if response.is_redirect:
redirect_url = response.headers.get("Location")
if not redirect_url:
return {"status": response.status_code, "error": "Redirect response missing Location header"}
current_url = urllib.parse.urljoin(current_url, redirect_url)
continue

response_content_type = response.headers.get("Content-Type", "").lower()
if "application/json" in response_content_type:
data = response.json()
else:
data = response.text
if response.status_code == 500:
error_message = f"API call failed with status {response.status_code}"
if data:
error_message += f": {data}"
return {"status": response.status_code, "error": error_message}
return {"status": response.status_code, "data": data}
return {"status": 500, "error": "Too many redirects while making the API call"}
except UnsafeActionURLError as e:
return {"status": 400, "error": f"Blocked unsafe action URL: {e}"}
except requests.exceptions.RequestException as e:
return {"status": 500, "error": f"Failed to make the API call: {e}"}
except Exception:
Expand Down
76 changes: 76 additions & 0 deletions app/services/tool/url_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import ipaddress
import socket
import urllib.parse
from typing import Iterable

from app.exceptions.exception import ValidateFailedError


_ALLOWED_ACTION_SCHEMES = {"http", "https"}
_BLOCKED_HOSTNAMES = {"localhost"}


class UnsafeActionURLError(ValueError):
pass


def _is_blocked_ip(address: str) -> bool:
ip = ipaddress.ip_address(address)
return any(
[
ip.is_private,
ip.is_loopback,
ip.is_link_local,
ip.is_multicast,
ip.is_reserved,
ip.is_unspecified,
]
)


def _resolved_addresses(hostname: str, port: int | None) -> Iterable[str]:
try:
addrinfos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise UnsafeActionURLError(f"Unable to resolve action URL host: {hostname}") from exc

addresses = set()
for addrinfo in addrinfos:
sockaddr = addrinfo[4]
if sockaddr:
addresses.add(sockaddr[0])
return addresses


def validate_action_url(url: str) -> None:
"""Validate that an action URL does not target local or private network addresses."""
parsed = urllib.parse.urlparse(url)
if parsed.scheme.lower() not in _ALLOWED_ACTION_SCHEMES:
raise UnsafeActionURLError("Action URL scheme must be http or https")

if not parsed.hostname:
raise UnsafeActionURLError("Action URL must include a hostname")

if parsed.username or parsed.password:
raise UnsafeActionURLError("Action URL must not include user credentials")

hostname = parsed.hostname.rstrip(".").lower()
if hostname in _BLOCKED_HOSTNAMES or hostname.endswith(".localhost"):
raise UnsafeActionURLError("Action URL host is not allowed")

try:
ip = ipaddress.ip_address(hostname)
except ValueError:
for address in _resolved_addresses(hostname, parsed.port):
if _is_blocked_ip(address):
raise UnsafeActionURLError("Action URL host resolves to a disallowed address")
else:
if _is_blocked_ip(str(ip)):
raise UnsafeActionURLError("Action URL host is not allowed")


def validate_openapi_server_url(url: str) -> None:
try:
validate_action_url(url)
except UnsafeActionURLError as exc:
raise ValidateFailedError(f"Invalid action server URL: {exc}") from exc
101 changes: 101 additions & 0 deletions tests/unit/test_action_url_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import socket

from app.exceptions.exception import ValidateFailedError
from app.schemas.tool.action import ActionBodyType, ActionMethod, validate_openapi_schema
from app.schemas.tool.authentication import Authentication, AuthenticationType
from app.services.tool.openapi_call import call_action_api
from app.services.tool.url_security import UnsafeActionURLError, validate_action_url


def _addrinfo(address):
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (address, 80))]


def _schema(server_url):
return {
"openapi": "3.0.0",
"info": {"title": "test", "version": "1.0"},
"servers": [{"url": server_url}],
"paths": {
"/status": {
"get": {
"operationId": "get_status",
"description": "read status",
"responses": {"200": {"description": "ok"}},
}
}
},
}


class FakeResponse:
def __init__(self, status_code=200, text="ok", headers=None):
self.status_code = status_code
self.text = text
self.headers = headers or {"Content-Type": "text/plain"}
self.is_redirect = status_code in {301, 302, 303, 307, 308}

def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
return False

def json(self):
return {"ok": True}


def test_validate_action_url_blocks_loopback_ip():
try:
validate_action_url("http://127.0.0.1:8080/secret")
except UnsafeActionURLError:
return
raise AssertionError("loopback action URL was not blocked")


def test_validate_action_url_blocks_dns_to_private_address(monkeypatch):
monkeypatch.setattr(socket, "getaddrinfo", lambda *args, **kwargs: _addrinfo("10.0.0.5"))

try:
validate_action_url("https://api.example.test/status")
except UnsafeActionURLError:
return
raise AssertionError("private DNS target was not blocked")


def test_openapi_schema_rejects_private_server_url():
try:
validate_openapi_schema(_schema("http://169.254.169.254/latest"))
except ValidateFailedError:
return
raise AssertionError("OpenAPI server URL pointing at metadata service was not rejected")


def test_call_action_api_blocks_private_redirect(monkeypatch):
monkeypatch.setattr(socket, "getaddrinfo", lambda host, *args, **kwargs: _addrinfo("93.184.216.34"))

calls = []

def fake_request(method, url, **kwargs):
calls.append((method, url, kwargs))
return FakeResponse(302, headers={"Location": "http://127.0.0.1:8080/admin"})

monkeypatch.setattr("app.services.tool.openapi_call.requests.request", fake_request)

result = call_action_api(
url="https://api.example.test/status",
method=ActionMethod.GET,
path_param_schema={},
query_param_schema={},
body_type=ActionBodyType.NONE,
body_param_schema={},
parameters={},
headers={},
authentication=Authentication(type=AuthenticationType.none),
)

assert calls == [("GET", "https://api.example.test/status", calls[0][2])]
assert result["status"] == 400
assert "Blocked unsafe action URL" in result["error"]
assert calls[0][2]["allow_redirects"] is False
assert calls[0][2]["timeout"] == 10