Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions src/adcp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,43 @@ def load_payload(payload_arg: str | None) -> dict[str, Any]:
sys.exit(1)


def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None:
def merge_headers(saved: dict[str, str] | None, runtime: dict[str, str]) -> dict[str, str]:
"""Merge runtime --header flags over saved-config headers; runtime wins."""
merged = dict(saved or {})
merged.update(runtime)
return merged


def parse_header_args(header_args: list[str] | None) -> dict[str, str]:
"""Parse repeated ``--header KEY=VALUE`` flags into a dict.

Exits with an error message on malformed input.
"""
if not header_args:
return {}
headers: dict[str, str] = {}
for raw in header_args:
if "=" not in raw:
print(
f"Error: --header expects KEY=VALUE, got: {raw!r}",
file=sys.stderr,
)
sys.exit(2)
key, _, value = raw.partition("=")
key = key.strip()
if not key:
print(f"Error: --header has empty key: {raw!r}", file=sys.stderr)
sys.exit(2)
headers[key] = value
return headers


def handle_save_auth(
alias: str,
url: str | None,
protocol: str | None,
extra_headers: dict[str, str] | None = None,
) -> None:
"""Handle --save-auth command."""
if not url:
# Interactive mode
Expand All @@ -438,7 +474,7 @@ def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None:

auth_token = input("Auth token (optional): ").strip() or None

save_agent(alias, url, protocol, auth_token)
save_agent(alias, url, protocol, auth_token, extra_headers=extra_headers or None)
print(f"✓ Saved agent '{alias}'")


Expand All @@ -457,6 +493,10 @@ def handle_list_agents() -> None:
print(f" URL: {config.get('agent_uri')}")
print(f" Protocol: {config.get('protocol', 'mcp').upper()}")
print(f" Auth: {auth}")
extra_headers = config.get("extra_headers") or {}
if extra_headers:
keys = ", ".join(sorted(extra_headers))
print(f" Headers: {keys}")


def handle_remove_agent(alias: str) -> None:
Expand Down Expand Up @@ -612,6 +652,15 @@ def main() -> None:
# Execution options
parser.add_argument("--protocol", choices=["mcp", "a2a"], help="Force protocol type")
parser.add_argument("--auth", help="Authentication token")
parser.add_argument(
"--header",
"-H",
action="append",
metavar="KEY=VALUE",
help="Additional HTTP header sent on every request (repeatable). "
"Example: -H x-adcp-tenant=acme. With --save-auth, persists into the "
"saved agent config.",
)
parser.add_argument("--json", action="store_true", help="Output as JSON")
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
parser.add_argument("--help", "-h", action="store_true", help="Show help")
Expand Down Expand Up @@ -671,7 +720,8 @@ def main() -> None:
if args.save_auth:
url = args.agent if args.agent else None
protocol = args.tool if args.tool else None
handle_save_auth(args.save_auth, url, protocol)
cli_headers = parse_header_args(args.header)
handle_save_auth(args.save_auth, url, protocol, extra_headers=cli_headers)
sys.exit(0)

if args.list_agents:
Expand Down Expand Up @@ -709,6 +759,12 @@ def main() -> None:
if args.auth:
agent_config["auth_token"] = args.auth

cli_headers = parse_header_args(args.header)
if cli_headers:
agent_config["extra_headers"] = merge_headers(
agent_config.get("extra_headers"), cli_headers
)

if args.debug:
agent_config["debug"] = True

Expand Down
9 changes: 8 additions & 1 deletion src/adcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def save_config(config: dict[str, Any]) -> None:


def save_agent(
alias: str, url: str, protocol: str | None = None, auth_token: str | None = None
alias: str,
url: str,
protocol: str | None = None,
auth_token: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> None:
"""Save agent configuration."""
config = load_config()
Expand All @@ -54,6 +58,9 @@ def save_agent(
if auth_token:
config["agents"][alias]["auth_token"] = auth_token

if extra_headers:
config["agents"][alias]["extra_headers"] = dict(extra_headers)

save_config(config)


Expand Down
3 changes: 3 additions & 0 deletions src/adcp/protocols/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ async def _get_httpx_client(self) -> httpx.AsyncClient:
else:
headers[self.agent_config.auth_header] = self.agent_config.auth_token

if self.agent_config.extra_headers:
headers.update(self.agent_config.extra_headers)

# When ADCPClient installed a signing_request_hook, register it as
# an httpx request event hook so RFC 9421 signature headers are
# attached transparently to every outgoing request. The hook is
Expand Down
3 changes: 3 additions & 0 deletions src/adcp/protocols/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ async def _get_session(self) -> ClientSession:
else:
headers[self.agent_config.auth_header] = self.agent_config.auth_token

if self.agent_config.extra_headers:
headers.update(self.agent_config.extra_headers)

# Try the user's exact URL first
urls_to_try = [self.agent_config.agent_uri]

Expand Down
40 changes: 39 additions & 1 deletion src/adcp/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Any, Generic, Literal, TypeVar

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


class Protocol(str, Enum):
Expand All @@ -30,6 +30,23 @@ class AgentConfig(BaseModel):
"streamable_http" # "streamable_http" (default, modern) or "sse" (legacy fallback)
)
debug: bool = False # Enable debug mode to capture request/response details
extra_headers: dict[str, str] = Field(default_factory=dict)
"""Additional HTTP headers sent on every request to this agent.

This is a **transport-layer escape hatch**, not an AdCP protocol
extension point — protocol-defined fields belong in the request
envelope or ``RequestContext.metadata``. Use this for vendor or
deployment-specific routing headers (e.g. tenant routing on a
multi-tenant server).

Reserved: the configured ``auth_header`` (default ``x-adcp-auth``)
and the standard ``Authorization`` header — set credentials via
``auth_token``/``auth_header`` instead. Header names are rejected
if they contain CR/LF or other control characters.

Persisted plaintext at ``~/.adcp/config.json`` when saved via the
CLI — do not store credentials here.
"""

@field_validator("agent_uri")
@classmethod
Expand Down Expand Up @@ -86,6 +103,27 @@ def validate_auth_type(cls, v: str) -> str:
)
return v

@model_validator(mode="after")
def _validate_extra_headers(self) -> AgentConfig:
if not self.extra_headers:
return self
reserved = {self.auth_header.lower(), "authorization"}
for key, value in self.extra_headers.items():
if not key:
raise ValueError("extra_headers contains an empty header name")
if any(c in key for c in ("\r", "\n", "\x00")) or any(ord(c) < 0x20 for c in key):
raise ValueError(f"extra_headers key contains control character: {key!r}")
if any(c in value for c in ("\r", "\n", "\x00")):
raise ValueError(f"extra_headers value for {key!r} contains CR/LF/NUL")
if key.lower() in reserved:
raise ValueError(
f"extra_headers may not override reserved auth header "
f"{key!r} (collides with auth_header={self.auth_header!r} "
f"or 'Authorization'); set credentials via auth_token + "
f"auth_header instead"
)
return self


class TaskStatus(str, Enum):
"""Task execution status."""
Expand Down
84 changes: 83 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

import pytest

from adcp.__main__ import load_payload, resolve_agent_config
from adcp.__main__ import (
load_payload,
merge_headers,
parse_header_args,
resolve_agent_config,
)
from adcp.config import save_agent


Expand Down Expand Up @@ -190,6 +195,26 @@ def test_save_agent_command(self, tmp_path, monkeypatch):
assert config["agents"]["test_agent"]["agent_uri"] == "https://test.com"
assert config["agents"]["test_agent"]["auth_token"] == "secret_token"

def test_save_agent_persists_extra_headers(self, tmp_path, monkeypatch):
"""save_agent writes extra_headers into the saved config."""
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps({"agents": {}}))

import adcp.config

monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file)

save_agent(
"tenant_agent",
"https://test.com",
"mcp",
"secret_token",
extra_headers={"x-adcp-tenant": "acme"},
)

config = json.loads(config_file.read_text())
assert config["agents"]["tenant_agent"]["extra_headers"] == {"x-adcp-tenant": "acme"}

def test_list_agents_command(self, tmp_path, monkeypatch):
"""Test --list-agents shows saved agents."""
config_file = tmp_path / "config.json"
Expand Down Expand Up @@ -240,6 +265,63 @@ def test_show_config_command(self):
assert ".adcp" in result.stdout or "config.json" in result.stdout


class TestHeaderArgParsing:
"""Test --header KEY=VALUE flag parsing."""

def test_returns_empty_for_none(self):
assert parse_header_args(None) == {}

def test_returns_empty_for_empty_list(self):
assert parse_header_args([]) == {}

def test_parses_single_header(self):
assert parse_header_args(["x-adcp-tenant=acme"]) == {"x-adcp-tenant": "acme"}

def test_parses_multiple_headers(self):
result = parse_header_args(["x-adcp-tenant=acme", "x-correlation-id=req-1"])
assert result == {"x-adcp-tenant": "acme", "x-correlation-id": "req-1"}

def test_value_may_contain_equals(self):
result = parse_header_args(["x-token=a=b=c"])
assert result == {"x-token": "a=b=c"}

def test_strips_key_whitespace(self):
result = parse_header_args([" x-adcp-tenant =acme"])
assert result == {"x-adcp-tenant": "acme"}

def test_missing_equals_exits(self):
with pytest.raises(SystemExit) as exc_info:
parse_header_args(["no-equals-here"])
assert exc_info.value.code == 2

def test_empty_key_exits(self):
with pytest.raises(SystemExit) as exc_info:
parse_header_args(["=value"])
assert exc_info.value.code == 2


class TestHeaderMerge:
"""Test merge precedence between saved-config and runtime --header flags."""

def test_runtime_wins_on_collision(self):
result = merge_headers(
{"x-adcp-tenant": "old", "x-trace-id": "abc"},
{"x-adcp-tenant": "new"},
)
assert result == {"x-adcp-tenant": "new", "x-trace-id": "abc"}

def test_saved_only(self):
result = merge_headers({"x-adcp-tenant": "acme"}, {})
assert result == {"x-adcp-tenant": "acme"}

def test_runtime_only(self):
result = merge_headers(None, {"x-adcp-tenant": "acme"})
assert result == {"x-adcp-tenant": "acme"}

def test_both_empty(self):
assert merge_headers(None, {}) == {}


class TestCLIErrorHandling:
"""Test error handling in CLI."""

Expand Down
73 changes: 73 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,79 @@ def test_agent_config_creation():
assert config.protocol == Protocol.A2A


def test_agent_config_extra_headers_default_empty():
config = AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
)
assert config.extra_headers == {}


def test_agent_config_extra_headers_accepted():
config = AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
extra_headers={"x-adcp-tenant": "acme", "x-correlation-id": "req-1"},
)
assert config.extra_headers == {
"x-adcp-tenant": "acme",
"x-correlation-id": "req-1",
}


def test_agent_config_extra_headers_rejects_auth_header_collision():
with pytest.raises(ValueError, match="reserved auth header"):
AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
auth_header="x-custom-auth",
extra_headers={"X-Custom-Auth": "tok"}, # case-insensitive collision
)


def test_agent_config_extra_headers_rejects_authorization_collision():
with pytest.raises(ValueError, match="reserved auth header"):
AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
extra_headers={"Authorization": "Bearer foo"},
)


def test_agent_config_extra_headers_rejects_empty_key():
with pytest.raises(ValueError, match="empty header name"):
AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
extra_headers={"": "value"},
)


def test_agent_config_extra_headers_rejects_crlf_in_value():
with pytest.raises(ValueError, match="CR/LF/NUL"):
AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
extra_headers={"x-trace": "a\r\nAuthorization: Bearer evil"},
)


def test_agent_config_extra_headers_rejects_crlf_in_key():
with pytest.raises(ValueError, match="control character"):
AgentConfig(
id="test_agent",
agent_uri="https://test.example.com",
protocol=Protocol.MCP,
extra_headers={"x-trace\nInjected": "value"},
)


def test_client_creation():
"""Test creating ADCP client."""
config = AgentConfig(
Expand Down
Loading
Loading