Skip to content

Commit 9fbbf72

Browse files
feat: propagate trace context and log correlation (#397)
1 parent ccf6d80 commit 9fbbf72

16 files changed

Lines changed: 534 additions & 76 deletions

docs/guide.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ Key variables to understand protocol behavior:
3737
- `A2A_PUBLIC_URL`: public base URL advertised by the Agent Card. Default: `http://127.0.0.1:8000`.
3838
- `A2A_LOG_LEVEL`: runtime log level. Default: `WARNING`.
3939
- `A2A_LOG_PAYLOADS` / `A2A_LOG_BODY_LIMIT`: payload logging behavior and truncation. When `A2A_LOG_LEVEL=DEBUG`, upstream OpenCode stream events are also logged with preview truncation controlled by `A2A_LOG_BODY_LIMIT`.
40+
- The runtime accepts W3C `traceparent` / `tracestate` headers on inbound requests. When `traceparent` is missing or invalid, the runtime generates a fresh valid value and exposes it on the HTTP response header.
41+
- The active `traceparent` / `tracestate` pair is propagated across inbound A2A handling, OpenCode upstream requests, and outbound peer A2A calls triggered through the embedded client or `a2a_call` tool path.
42+
- Logs derive a stable `trace_id` from the active `traceparent` so request-scoped log lines can be correlated without introducing high-cardinality metric labels.
4043
- `A2A_HTTP_GZIP_MINIMUM_SIZE`: minimum eligible response-body size in bytes for global non-streaming HTTP gzip compression. Default: `8192`.
4144
- `A2A_MAX_REQUEST_BODY_BYTES`: runtime request-body limit. Oversized requests return HTTP `413`.
4245
- `A2A_PENDING_SESSION_CLAIM_TTL_SECONDS`: lease duration for pending preferred session claims before they expire and stop blocking other identities.

src/opencode_a2a/client/agent_card.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PREV_AGENT_CARD_WELL_KNOWN_PATH,
1414
)
1515

16+
from ..trace_context import current_trace_headers
1617
from .request_context import build_default_headers
1718

1819

@@ -70,6 +71,9 @@ def build_resolver_http_kwargs(
7071
) -> dict[str, Any]:
7172
http_kwargs: dict[str, Any] = {"timeout": timeout}
7273
default_headers = build_default_headers(bearer_token, basic_auth)
74+
trace_headers = current_trace_headers()
75+
if trace_headers:
76+
default_headers.update(trace_headers)
7377
if default_headers:
7478
http_kwargs["headers"] = default_headers
7579
return http_kwargs

src/opencode_a2a/client/client.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .errors import A2ATimeoutError, A2AUnsupportedBindingError
3636
from .payload_text import extract_text as extract_text_from_payload
3737
from .polling import PollingFallbackPolicy
38-
from .request_context import build_call_context, build_client_interceptors, split_request_metadata
38+
from .request_context import build_call_context, split_request_metadata
3939

4040

4141
class A2AClient:
@@ -302,14 +302,7 @@ async def _build_client(self) -> Client:
302302
)
303303
try:
304304
factory = ClientFactory(config, consumers=None)
305-
client = factory.create(
306-
card,
307-
interceptors=build_client_interceptors(
308-
self._settings.bearer_token,
309-
self._settings.basic_auth,
310-
self._settings.protocol_version,
311-
),
312-
)
305+
client = factory.create(card)
313306
except ValueError as exc:
314307
raise A2AUnsupportedBindingError(
315308
f"No supported transport found for {self.agent_url}"

src/opencode_a2a/client/request_context.py

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,13 @@
55
from collections.abc import Mapping
66
from typing import Any
77

8-
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
8+
from a2a.client.middleware import ClientCallContext
99

1010
from ..protocol_versions import normalize_protocol_version
11+
from ..trace_context import current_trace_headers
1112
from .auth import encode_basic_auth
1213

1314

14-
class HeaderInterceptor(ClientCallInterceptor):
15-
def __init__(self, default_headers: Mapping[str, str] | None = None) -> None:
16-
self._default_headers = {
17-
key: value for key, value in dict(default_headers or {}).items() if value is not None
18-
}
19-
20-
async def intercept(
21-
self,
22-
method_name: str,
23-
request_payload: dict[str, Any],
24-
http_kwargs: dict[str, Any],
25-
agent_card: object | None,
26-
context: ClientCallContext | None,
27-
) -> tuple[dict[str, Any], dict[str, Any]]:
28-
del method_name, agent_card
29-
headers = dict(http_kwargs.get("headers") or {})
30-
headers.update(self._default_headers)
31-
if context is not None:
32-
dynamic_headers = context.state.get("headers")
33-
if isinstance(dynamic_headers, Mapping):
34-
for key, value in dynamic_headers.items():
35-
if isinstance(key, str) and value is not None:
36-
headers[key] = str(value)
37-
if headers:
38-
http_kwargs["headers"] = headers
39-
return request_payload, http_kwargs
40-
41-
4215
def build_default_headers(
4316
bearer_token: str | None,
4417
basic_auth: str | None = None,
@@ -68,6 +41,14 @@ def split_request_metadata(
6841
if value is not None:
6942
extra_headers["A2A-Version"] = normalize_protocol_version(str(value))
7043
continue
44+
if isinstance(key, str) and key.lower() == "traceparent":
45+
if value is not None:
46+
extra_headers["traceparent"] = str(value)
47+
continue
48+
if isinstance(key, str) and key.lower() == "tracestate":
49+
if value is not None:
50+
extra_headers["tracestate"] = str(value)
51+
continue
7152
request_metadata[key] = value
7253
return request_metadata or None, extra_headers or None
7354

@@ -79,6 +60,7 @@ def build_call_context(
7960
protocol_version: str | None = None,
8061
) -> ClientCallContext | None:
8162
merged_headers = build_default_headers(bearer_token, basic_auth, protocol_version)
63+
merged_headers.update(current_trace_headers())
8264
if extra_headers:
8365
merged_headers.update(extra_headers)
8466
if not merged_headers:
@@ -91,18 +73,8 @@ def build_call_context(
9173
)
9274

9375

94-
def build_client_interceptors(
95-
bearer_token: str | None,
96-
basic_auth: str | None = None,
97-
protocol_version: str | None = None,
98-
) -> list[ClientCallInterceptor]:
99-
return [HeaderInterceptor(build_default_headers(bearer_token, basic_auth, protocol_version))]
100-
101-
10276
__all__ = [
103-
"HeaderInterceptor",
10477
"build_call_context",
105-
"build_client_interceptors",
10678
"build_default_headers",
10779
"split_request_metadata",
10880
]

src/opencode_a2a/execution/executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
214214
identity = (call_context.state.get("identity") if call_context else None) or "anonymous"
215215
credential_id = call_context.state.get("credential_id") if call_context else None
216216
auth_scheme = call_context.state.get("auth_scheme") if call_context else None
217+
trace_id = call_context.state.get("trace_id") if call_context else None
217218

218219
streaming_request = self._should_stream(context)
219220
accepted_output_modes = normalize_accepted_output_modes(context.configuration)
@@ -313,13 +314,14 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
313314

314315
logger.debug(
315316
(
316-
"Received message identity=%s credential_id=%s auth_scheme=%s "
317+
"Received message identity=%s credential_id=%s auth_scheme=%s trace_id=%s "
317318
"task_id=%s context_id=%s "
318319
"streaming=%s text=%s part_count=%s"
319320
),
320321
identity,
321322
credential_id,
322323
auth_scheme,
324+
trace_id,
323325
task_id,
324326
context_id,
325327
streaming_request,

src/opencode_a2a/opencode_upstream_client.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .parts.text import extract_text_from_parts
1616
from .runtime_state import InterruptRequestBinding
1717
from .server.state_store import InterruptRequestRepository, MemoryInterruptRequestRepository
18+
from .trace_context import current_trace_headers
1819

1920
_UNSET = object()
2021
logger = logging.getLogger(__name__)
@@ -143,6 +144,13 @@ def _build_http_client(self, base_url: str) -> httpx.AsyncClient:
143144
headers={"Accept": "application/json"},
144145
)
145146

147+
@staticmethod
148+
def _request_headers(headers: Mapping[str, str] | None = None) -> dict[str, str] | None:
149+
merged_headers = current_trace_headers()
150+
if headers:
151+
merged_headers.update(dict(headers))
152+
return merged_headers or None
153+
146154
async def close(self) -> None:
147155
await self._client.aclose()
148156

@@ -185,7 +193,11 @@ async def _get_json(
185193
params: Mapping[str, Any] | None = None,
186194
) -> Any:
187195
async with self._request_budget.reserve(operation=endpoint):
188-
response = await self._client.get(path, params=params)
196+
response = await self._client.get(
197+
path,
198+
params=params,
199+
headers=self._request_headers(),
200+
)
189201
response.raise_for_status()
190202
return self._decode_json_response(response, endpoint=endpoint)
191203

@@ -207,6 +219,7 @@ async def _post_json(
207219
response = await self._client.post(
208220
path,
209221
params=params,
222+
headers=self._request_headers(),
210223
**request_kwargs,
211224
)
212225
response.raise_for_status()
@@ -238,7 +251,11 @@ async def _delete_json(
238251
params: Mapping[str, Any] | None = None,
239252
) -> Any:
240253
async with self._request_budget.reserve(operation=endpoint):
241-
response = await self._client.delete(path, params=params)
254+
response = await self._client.delete(
255+
path,
256+
params=params,
257+
headers=self._request_headers(),
258+
)
242259
response.raise_for_status()
243260
return self._decode_json_response(response, endpoint=endpoint)
244261

@@ -402,7 +419,7 @@ async def stream_events(
402419
"/event",
403420
params=params,
404421
timeout=None,
405-
headers={"Accept": "text/event-stream"},
422+
headers=self._request_headers({"Accept": "text/event-stream"}),
406423
) as response:
407424
response.raise_for_status()
408425
data_lines: list[str] = []
@@ -555,6 +572,7 @@ async def list_messages(
555572
response = await self._client.get(
556573
f"/session/{session_id}/message",
557574
params=self._merge_params(params, workspace_id=workspace_id),
575+
headers=self._request_headers(),
558576
)
559577
response.raise_for_status()
560578
payload = self._decode_json_response(response, endpoint=endpoint)
@@ -594,6 +612,7 @@ async def session_prompt_async(
594612
f"/session/{session_id}/prompt_async",
595613
params=self._query_params(directory=directory, workspace_id=workspace_id),
596614
json=request,
615+
headers=self._request_headers(),
597616
)
598617
response.raise_for_status()
599618
if response.status_code != 204:
@@ -748,7 +767,10 @@ async def create_workspace(self, request: dict[str, Any]) -> Any:
748767

749768
async def remove_workspace(self, workspace_id: str) -> Any:
750769
async with self._request_budget.reserve(operation="/experimental/workspace/{id}"):
751-
response = await self._client.delete(f"/experimental/workspace/{workspace_id}")
770+
response = await self._client.delete(
771+
f"/experimental/workspace/{workspace_id}",
772+
headers=self._request_headers(),
773+
)
752774
response.raise_for_status()
753775
return self._decode_json_response(
754776
response,
@@ -774,6 +796,7 @@ async def remove_worktree(self, request: dict[str, Any]) -> bool:
774796
"DELETE",
775797
"/experimental/worktree",
776798
json=request,
799+
headers=self._request_headers(),
777800
)
778801
response.raise_for_status()
779802
payload = self._decode_json_response(response, endpoint="/experimental/worktree")

src/opencode_a2a/server/application.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from ..opencode_upstream_client import OpencodeUpstreamClient
6767
from ..output_modes import normalize_accepted_output_modes
6868
from ..profile.runtime import build_runtime_profile
69+
from ..trace_context import install_log_record_factory
6970
from .agent_card import (
7071
_CHAT_OUTPUT_MODES,
7172
_build_agent_card_description,
@@ -514,6 +515,15 @@ def build(self, request: Request) -> ServerCallContext:
514515
credential_id = getattr(request.state, "user_credential_id", None)
515516
if credential_id:
516517
context.state["credential_id"] = credential_id
518+
traceparent = getattr(request.state, "traceparent", None)
519+
if traceparent:
520+
context.state["traceparent"] = traceparent
521+
tracestate = getattr(request.state, "tracestate", None)
522+
if tracestate:
523+
context.state["tracestate"] = tracestate
524+
trace_id = getattr(request.state, "trace_id", None)
525+
if trace_id:
526+
context.state["trace_id"] = trace_id
517527
negotiated_protocol_version = getattr(request.state, "a2a_protocol_version", None)
518528
if negotiated_protocol_version:
519529
context.state["a2a_protocol_version"] = negotiated_protocol_version
@@ -525,6 +535,7 @@ def build(self, request: Request) -> ServerCallContext:
525535

526536

527537
def create_app(settings: Settings) -> FastAPI:
538+
install_log_record_factory()
528539
database_engine = (
529540
build_database_engine(settings) if settings.a2a_task_store_backend == "database" else None
530541
)
@@ -663,9 +674,10 @@ def _normalize_log_level(value: str) -> str:
663674

664675

665676
def _configure_logging(level: str) -> None:
677+
install_log_record_factory()
666678
logging.basicConfig(
667679
level=getattr(logging, level, logging.INFO),
668-
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
680+
format="%(asctime)s %(levelname)s %(name)s [trace_id=%(trace_id)s]: %(message)s",
669681
)
670682
logging.getLogger("uvicorn.error").setLevel(level)
671683
logging.getLogger("uvicorn.access").setLevel(level)

src/opencode_a2a/server/middleware.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@
3030
negotiate_protocol_version,
3131
normalize_protocol_version,
3232
)
33+
from ..trace_context import (
34+
TRACEPARENT_HEADER,
35+
TRACESTATE_HEADER,
36+
reset_current_trace_context,
37+
resolve_trace_context,
38+
set_current_trace_context,
39+
)
3340
from .request_parsing import (
3441
_decode_payload_preview,
3542
_detect_sensitive_extension_method,
@@ -158,6 +165,24 @@ def _uses_v1_jsonrpc_aliases(request: Request) -> bool:
158165
except ValueError:
159166
return False
160167

168+
@app.middleware("http")
169+
async def bind_trace_context(request: Request, call_next):
170+
trace_context = resolve_trace_context(
171+
request.headers.get(TRACEPARENT_HEADER),
172+
request.headers.get(TRACESTATE_HEADER),
173+
)
174+
request.state.traceparent = trace_context.traceparent
175+
request.state.trace_id = trace_context.trace_id
176+
if trace_context.tracestate:
177+
request.state.tracestate = trace_context.tracestate
178+
token = set_current_trace_context(trace_context)
179+
try:
180+
response = await call_next(request)
181+
finally:
182+
reset_current_trace_context(token)
183+
response.headers[TRACEPARENT_HEADER] = trace_context.traceparent
184+
return response
185+
161186
@app.middleware("http")
162187
async def negotiate_a2a_protocol_version(request: Request, call_next):
163188
token: Token | None = None

0 commit comments

Comments
 (0)