33import logging
44import threading
55import time
6- from collections .abc import AsyncGenerator , AsyncIterable , Mapping
6+ from collections .abc import AsyncGenerator , AsyncIterable
77from typing import Any
88
9+ from braintrust .bt_json import bt_safe_deep_copy
910from braintrust .logger import start_span
1011from braintrust .span_types import SpanTypeAttribute
1112from braintrust .wrappers ._anthropic_utils import Wrapper , extract_anthropic_usage , finalize_anthropic_tokens
@@ -165,32 +166,8 @@ def _serialize_system_message(message: Any) -> dict[str, Any]:
165166
166167 return serialized
167168
168-
169- def _serialize_hook_value (value : Any ) -> Any :
170- if value is None or isinstance (value , (bool , int , float , str )):
171- return value
172-
173- if dataclasses .is_dataclass (value ):
174- return _serialize_hook_value (dataclasses .asdict (value ))
175-
176- if isinstance (value , Mapping ):
177- return {str (key ): _serialize_hook_value (item ) for key , item in value .items ()}
178-
179- if isinstance (value , (list , tuple )):
180- return [_serialize_hook_value (item ) for item in value ]
181-
182- if hasattr (value , "__dict__" ):
183- return {
184- key : _serialize_hook_value (item )
185- for key , item in vars (value ).items ()
186- if not key .startswith ("_" ) and not callable (item )
187- }
188-
189- return str (value )
190-
191-
192169def _serialize_hook_context (context : Any ) -> dict [str , Any ] | None :
193- serialized = _serialize_hook_value (context )
170+ serialized = bt_safe_deep_copy (context )
194171 if not isinstance (serialized , dict ):
195172 return None
196173
@@ -224,7 +201,7 @@ async def wrapped_hook(*args: Any, **kwargs: Any) -> Any:
224201 tool_use_id = kwargs .get ("tool_use_id" ) if "tool_use_id" in kwargs else (args [1 ] if len (args ) > 1 else None )
225202 context = kwargs .get ("context" ) if "context" in kwargs else (args [2 ] if len (args ) > 2 else None )
226203
227- span_input = {"input" : _serialize_hook_value (hook_input )}
204+ span_input = {"input" : bt_safe_deep_copy (hook_input )}
228205 if tool_use_id is not None :
229206 span_input ["tool_use_id" ] = str (tool_use_id )
230207
@@ -249,7 +226,7 @@ async def wrapped_hook(*args: Any, **kwargs: Any) -> Any:
249226 result = callback (* args , ** kwargs )
250227 if inspect .isawaitable (result ):
251228 result = await result
252- span .log (output = _serialize_hook_value (result ))
229+ span .log (output = bt_safe_deep_copy (result ))
253230 return result
254231
255232 wrapped_hook ._braintrust_wrapped_claude_hook = True # type: ignore[attr-defined]
@@ -281,6 +258,8 @@ def _wrap_options_hooks(options: Any) -> None:
281258 hooks_by_event = getattr (options , "hooks" , None )
282259 if not isinstance (hooks_by_event , dict ):
283260 return
261+ if getattr (options , "_braintrust_wrapped_claude_hooks_ref" , None ) is hooks_by_event :
262+ return
284263
285264 for event_name , matchers in list (hooks_by_event .items ()):
286265 if not isinstance (matchers , list ):
@@ -295,6 +274,11 @@ def _wrap_options_hooks(options: Any) -> None:
295274 except Exception :
296275 continue
297276
277+ try :
278+ setattr (options , "_braintrust_wrapped_claude_hooks_ref" , hooks_by_event )
279+ except Exception :
280+ pass
281+
298282
299283def _wrap_client_hooks (client : Any ) -> None :
300284 _wrap_options_hooks (getattr (client , "options" , None ))
@@ -406,9 +390,7 @@ async def _trace_message_stream(
406390 task_events .append (_serialize_system_message (message ))
407391
408392 yield message
409- except Exception :
410- raise
411- else :
393+
412394 if final_results :
413395 span .log (output = final_results [- 1 ])
414396 finally :
@@ -892,6 +874,7 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
892874 """Wrap query to capture the prompt and start time for tracing."""
893875 # Capture the time when query is called (when LLM call starts)
894876 self .__query_start_time = time .time ()
877+ # Re-wrap if client.options.hooks was replaced after client construction.
895878 _wrap_client_hooks (self .__client )
896879
897880 # Capture the prompt for use in receive_response
@@ -936,16 +919,13 @@ def _wrap_query_function(query_fn: Any, client_class: Any) -> Any:
936919 return query_fn
937920
938921 def wrapped_query (* args : Any , ** kwargs : Any ) -> AsyncGenerator [Any , None ]:
939- prompt = kwargs .get ("prompt" )
940- if prompt is None and args :
941- prompt = args [0 ]
942-
943- options = kwargs .get ("options" )
944- transport = kwargs .get ("transport" )
922+ query_kwargs = dict (kwargs )
923+ options = query_kwargs .pop ("options" , None )
924+ transport = query_kwargs .pop ("transport" , None )
945925
946926 async def traced_generator () -> AsyncGenerator [Any , None ]:
947927 async with client_class (options = options , transport = transport ) as client :
948- await client .query (prompt )
928+ await client .query (* args , ** query_kwargs )
949929 async for message in client .receive_response ():
950930 yield message
951931
0 commit comments