Skip to content

Commit ac2f20b

Browse files
committed
clean up
1 parent dfe5e36 commit ac2f20b

3 files changed

Lines changed: 149 additions & 40 deletions

File tree

py/src/braintrust/wrappers/claude_agent_sdk/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ def setup_claude_agent_sdk(
7171

7272
import claude_agent_sdk
7373

74+
def _module_references_any(module: object, references: list[tuple[str, object | None]]) -> bool:
75+
for attr_name, original_value in references:
76+
if original_value is not None and getattr(module, attr_name, None) is original_value:
77+
return True
78+
return False
79+
7480
original_client = claude_agent_sdk.ClaudeSDKClient if hasattr(claude_agent_sdk, "ClaudeSDKClient") else None
7581
original_query_fn = claude_agent_sdk.query if hasattr(claude_agent_sdk, "query") else None
7682
original_tool_class = claude_agent_sdk.SdkMcpTool if hasattr(claude_agent_sdk, "SdkMcpTool") else None
7783
original_tool_fn = claude_agent_sdk.tool if hasattr(claude_agent_sdk, "tool") else None
84+
original_options_class = claude_agent_sdk.ClaudeAgentOptions if hasattr(claude_agent_sdk, "ClaudeAgentOptions") else None
7885

7986
wrapped_client = None
8087
if original_client:
@@ -90,9 +97,15 @@ def setup_claude_agent_sdk(
9097
wrapped_query_fn = _wrap_query_function(original_query_fn, wrapped_client)
9198
claude_agent_sdk.query = wrapped_query_fn
9299

100+
query_patch_anchors = [
101+
("ClaudeSDKClient", original_client),
102+
("ClaudeAgentOptions", original_options_class),
103+
("SdkMcpTool", original_tool_class),
104+
("tool", original_tool_fn),
105+
]
93106
for module in list(sys.modules.values()):
94-
if module and hasattr(module, "query"):
95-
if getattr(module, "query", None) is original_query_fn:
107+
if module and getattr(module, "query", None) is original_query_fn:
108+
if _module_references_any(module, query_patch_anchors):
96109
setattr(module, "query", wrapped_query_fn)
97110

98111
if original_tool_class:

py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import logging
44
import threading
55
import time
6-
from collections.abc import AsyncGenerator, AsyncIterable, Mapping
6+
from collections.abc import AsyncGenerator, AsyncIterable
77
from typing import Any
88

9+
from braintrust.bt_json import bt_safe_deep_copy
910
from braintrust.logger import start_span
1011
from braintrust.span_types import SpanTypeAttribute
1112
from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
@@ -165,32 +166,8 @@ def _serialize_system_message(message: Any) -> dict[str, Any]:
165166

166167
return serialized
167168

168-
169-
def _serialize_hook_value(value: Any) -> Any:
170-
if value is None or isinstance(value, (bool, int, float, str)):
171-
return value
172-
173-
if dataclasses.is_dataclass(value):
174-
return _serialize_hook_value(dataclasses.asdict(value))
175-
176-
if isinstance(value, Mapping):
177-
return {str(key): _serialize_hook_value(item) for key, item in value.items()}
178-
179-
if isinstance(value, (list, tuple)):
180-
return [_serialize_hook_value(item) for item in value]
181-
182-
if hasattr(value, "__dict__"):
183-
return {
184-
key: _serialize_hook_value(item)
185-
for key, item in vars(value).items()
186-
if not key.startswith("_") and not callable(item)
187-
}
188-
189-
return str(value)
190-
191-
192169
def _serialize_hook_context(context: Any) -> dict[str, Any] | None:
193-
serialized = _serialize_hook_value(context)
170+
serialized = bt_safe_deep_copy(context)
194171
if not isinstance(serialized, dict):
195172
return None
196173

@@ -224,7 +201,7 @@ async def wrapped_hook(*args: Any, **kwargs: Any) -> Any:
224201
tool_use_id = kwargs.get("tool_use_id") if "tool_use_id" in kwargs else (args[1] if len(args) > 1 else None)
225202
context = kwargs.get("context") if "context" in kwargs else (args[2] if len(args) > 2 else None)
226203

227-
span_input = {"input": _serialize_hook_value(hook_input)}
204+
span_input = {"input": bt_safe_deep_copy(hook_input)}
228205
if tool_use_id is not None:
229206
span_input["tool_use_id"] = str(tool_use_id)
230207

@@ -249,7 +226,7 @@ async def wrapped_hook(*args: Any, **kwargs: Any) -> Any:
249226
result = callback(*args, **kwargs)
250227
if inspect.isawaitable(result):
251228
result = await result
252-
span.log(output=_serialize_hook_value(result))
229+
span.log(output=bt_safe_deep_copy(result))
253230
return result
254231

255232
wrapped_hook._braintrust_wrapped_claude_hook = True # type: ignore[attr-defined]
@@ -281,6 +258,8 @@ def _wrap_options_hooks(options: Any) -> None:
281258
hooks_by_event = getattr(options, "hooks", None)
282259
if not isinstance(hooks_by_event, dict):
283260
return
261+
if getattr(options, "_braintrust_wrapped_claude_hooks_ref", None) is hooks_by_event:
262+
return
284263

285264
for event_name, matchers in list(hooks_by_event.items()):
286265
if not isinstance(matchers, list):
@@ -295,6 +274,11 @@ def _wrap_options_hooks(options: Any) -> None:
295274
except Exception:
296275
continue
297276

277+
try:
278+
setattr(options, "_braintrust_wrapped_claude_hooks_ref", hooks_by_event)
279+
except Exception:
280+
pass
281+
298282

299283
def _wrap_client_hooks(client: Any) -> None:
300284
_wrap_options_hooks(getattr(client, "options", None))
@@ -406,9 +390,7 @@ async def _trace_message_stream(
406390
task_events.append(_serialize_system_message(message))
407391

408392
yield message
409-
except Exception:
410-
raise
411-
else:
393+
412394
if final_results:
413395
span.log(output=final_results[-1])
414396
finally:
@@ -892,6 +874,7 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
892874
"""Wrap query to capture the prompt and start time for tracing."""
893875
# Capture the time when query is called (when LLM call starts)
894876
self.__query_start_time = time.time()
877+
# Re-wrap if client.options.hooks was replaced after client construction.
895878
_wrap_client_hooks(self.__client)
896879

897880
# Capture the prompt for use in receive_response
@@ -936,16 +919,13 @@ def _wrap_query_function(query_fn: Any, client_class: Any) -> Any:
936919
return query_fn
937920

938921
def wrapped_query(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
939-
prompt = kwargs.get("prompt")
940-
if prompt is None and args:
941-
prompt = args[0]
942-
943-
options = kwargs.get("options")
944-
transport = kwargs.get("transport")
922+
query_kwargs = dict(kwargs)
923+
options = query_kwargs.pop("options", None)
924+
transport = query_kwargs.pop("transport", None)
945925

946926
async def traced_generator() -> AsyncGenerator[Any, None]:
947927
async with client_class(options=options, transport=transport) as client:
948-
await client.query(prompt)
928+
await client.query(*args, **query_kwargs)
949929
async for message in client.receive_response():
950930
yield message
951931

py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636
_extract_usage_from_result_message,
3737
_parse_tool_name,
3838
_serialize_content_blocks,
39+
_serialize_hook_context,
3940
_serialize_system_message,
4041
_serialize_tool_result_output,
4142
_thread_local,
43+
_wrap_client_hooks,
4244
_wrap_query_function,
4345
)
4446
from braintrust.wrappers.test_utils import verify_autoinstrument_script
@@ -81,6 +83,120 @@ def _patched_claude_sdk(*, wrap_client: bool = False, wrap_tool_class: bool = Fa
8183
claude_agent_sdk.tool = original_tool_fn
8284

8385

86+
def test_wrap_client_hooks_fast_exits_for_same_hooks_object():
87+
class FakeOptions:
88+
def __init__(self, hooks: dict[str, list[Any]]):
89+
self.hooks = hooks
90+
91+
class FakeClient:
92+
def __init__(self, options: FakeOptions):
93+
self.options = options
94+
95+
def callback(value: Any) -> Any:
96+
return value
97+
98+
matcher = types.SimpleNamespace(hooks=[callback], matcher="tool")
99+
hooks = {"PreToolUse": [matcher]}
100+
client = FakeClient(FakeOptions(hooks))
101+
102+
_wrap_client_hooks(client)
103+
wrapped_callback = matcher.hooks[0]
104+
wrapped_ref = client.options._braintrust_wrapped_claude_hooks_ref
105+
106+
_wrap_client_hooks(client)
107+
108+
assert matcher.hooks[0] is wrapped_callback
109+
assert client.options._braintrust_wrapped_claude_hooks_ref is wrapped_ref
110+
111+
112+
def test_wrap_client_hooks_rewraps_when_hooks_container_is_replaced():
113+
class FakeOptions:
114+
def __init__(self, hooks: dict[str, list[Any]]):
115+
self.hooks = hooks
116+
117+
class FakeClient:
118+
def __init__(self, options: FakeOptions):
119+
self.options = options
120+
121+
def first_callback(value: Any) -> Any:
122+
return value
123+
124+
def second_callback(value: Any) -> Any:
125+
return value
126+
127+
first_matcher = types.SimpleNamespace(hooks=[first_callback], matcher="tool")
128+
client = FakeClient(FakeOptions({"PreToolUse": [first_matcher]}))
129+
130+
_wrap_client_hooks(client)
131+
assert hasattr(first_matcher.hooks[0], "_braintrust_wrapped_claude_hook")
132+
133+
second_matcher = types.SimpleNamespace(hooks=[second_callback], matcher="tool")
134+
client.options.hooks = {"PreToolUse": [second_matcher]}
135+
136+
_wrap_client_hooks(client)
137+
138+
assert hasattr(second_matcher.hooks[0], "_braintrust_wrapped_claude_hook")
139+
assert client.options._braintrust_wrapped_claude_hooks_ref is client.options.hooks
140+
141+
142+
@pytest.mark.asyncio
143+
async def test_wrap_query_function_forwards_unknown_kwargs():
144+
call_log: dict[str, Any] = {}
145+
146+
async def original_query(*args: Any, **kwargs: Any) -> AsyncIterable[str]:
147+
del args, kwargs
148+
if False:
149+
yield ""
150+
151+
class FakeWrappedClient:
152+
def __init__(self, *args: Any, **kwargs: Any):
153+
call_log["client_init_kwargs"] = kwargs
154+
155+
async def __aenter__(self) -> "FakeWrappedClient":
156+
return self
157+
158+
async def __aexit__(self, *args: Any) -> None:
159+
del args
160+
return None
161+
162+
async def query(self, *args: Any, **kwargs: Any) -> None:
163+
call_log["query_args"] = args
164+
call_log["query_kwargs"] = kwargs
165+
166+
async def receive_response(self) -> AsyncIterable[str]:
167+
yield "ok"
168+
169+
wrapped_query = _wrap_query_function(original_query, FakeWrappedClient)
170+
171+
messages = [
172+
message
173+
async for message in wrapped_query(
174+
"Say hi",
175+
options="opts",
176+
transport="transport",
177+
session_id="session-123",
178+
custom_flag=True,
179+
)
180+
]
181+
182+
assert messages == ["ok"]
183+
assert call_log["client_init_kwargs"] == {"options": "opts", "transport": "transport"}
184+
assert call_log["query_args"] == ("Say hi",)
185+
assert call_log["query_kwargs"] == {"session_id": "session-123", "custom_flag": True}
186+
187+
188+
def test_serialize_hook_context_handles_cyclic_containers():
189+
data: dict[str, Any] = {"name": "root", "signal": "omit-me"}
190+
data["child"] = {"parent": data}
191+
192+
serialized = _serialize_hook_context(data)
193+
194+
assert serialized is not None
195+
assert serialized["name"] == "root"
196+
assert serialized["child"]["parent"] == "<circular reference>"
197+
assert "signal" not in serialized
198+
199+
84200
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
85201
@pytest.mark.asyncio
86202
async def test_calculator_with_multiple_operations(memory_logger):

0 commit comments

Comments
 (0)