Skip to content

Commit f57fe5f

Browse files
feat(closes OPEN-9557): add 'promote' parameter to the trace decorator
1 parent 9d9b283 commit f57fe5f

File tree

1 file changed

+122
-54
lines changed

1 file changed

+122
-54
lines changed

src/openlayer/lib/tracing/tracer.py

Lines changed: 122 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -522,12 +522,26 @@ def trace(
522522
inference_pipeline_id: Optional[str] = None,
523523
context_kwarg: Optional[str] = None,
524524
question_kwarg: Optional[str] = None,
525+
promote: Optional[Union[List[str], Dict[str, str]]] = None,
525526
guardrails: Optional[List[Any]] = None,
526527
on_flush_failure: Optional[OnFlushFailureCallback] = None,
527528
**step_kwargs,
528529
):
529530
"""Decorator to trace a function with optional guardrails.
530531
532+
Parameters
533+
----------
534+
promote : list of str or dict mapping str to str, optional
535+
Kwarg names whose values should be surfaced as top-level columns in the
536+
trace data. Pass a list to use the original kwarg names as column names,
537+
or a dict to alias them::
538+
539+
# List form – uses original kwarg names
540+
@tracer.trace(promote=["tool_call_count", "user_query"])
541+
542+
# Dict form – maps kwarg_name -> column_name
543+
@tracer.trace(promote={"user_query": "agent_input_query"})
544+
531545
Examples
532546
--------
533547
@@ -614,6 +628,7 @@ def __next__(self):
614628
context_kwarg=context_kwarg,
615629
question_kwarg=question_kwarg,
616630
)
631+
_apply_promote_kwargs(self._inputs, promote)
617632
self._trace_initialized = True
618633

619634
try:
@@ -709,6 +724,7 @@ def wrapper(*func_args, **func_kwargs):
709724
context_kwarg=context_kwarg,
710725
question_kwarg=question_kwarg,
711726
)
727+
_apply_promote_kwargs(original_inputs, promote)
712728

713729
# Apply input guardrails
714730
modified_inputs, input_guardrail_metadata = (
@@ -811,6 +827,7 @@ def trace_async(
811827
inference_pipeline_id: Optional[str] = None,
812828
context_kwarg: Optional[str] = None,
813829
question_kwarg: Optional[str] = None,
830+
promote: Optional[Union[List[str], Dict[str, str]]] = None,
814831
guardrails: Optional[List[Any]] = None,
815832
on_flush_failure: Optional[OnFlushFailureCallback] = None,
816833
**step_kwargs,
@@ -821,6 +838,19 @@ def trace_async(
821838
function
822839
or an async generator and handles both cases appropriately.
823840
841+
Parameters
842+
----------
843+
promote : list of str or dict mapping str to str, optional
844+
Kwarg names whose values should be surfaced as top-level columns in the
845+
trace data. Pass a list to use the original kwarg names as column names,
846+
or a dict to alias them::
847+
848+
# List form – uses original kwarg names
849+
@tracer.trace_async(promote=["job_id", "user_query"])
850+
851+
# Dict form – maps kwarg_name -> column_name
852+
@tracer.trace_async(promote={"user_query": "agent_input_query"})
853+
824854
Examples
825855
--------
826856
@@ -886,6 +916,7 @@ async def __anext__(self):
886916
context_kwarg=context_kwarg,
887917
question_kwarg=question_kwarg,
888918
)
919+
_apply_promote_kwargs(self._inputs, promote)
889920
self._trace_initialized = True
890921

891922
try:
@@ -939,8 +970,8 @@ async def async_function_wrapper(*func_args, **func_kwargs):
939970
guardrail_metadata = {}
940971

941972
try:
942-
# Apply input guardrails if provided
943-
if guardrails:
973+
# Apply promote / input guardrails if provided
974+
if promote or guardrails:
944975
try:
945976
inputs = _extract_function_inputs(
946977
func_signature=func_signature,
@@ -949,35 +980,43 @@ async def async_function_wrapper(*func_args, **func_kwargs):
949980
context_kwarg=context_kwarg,
950981
question_kwarg=question_kwarg,
951982
)
952-
953-
# Process inputs through guardrails
954-
modified_inputs, input_metadata = (
955-
_apply_input_guardrails(
956-
guardrails,
957-
inputs,
983+
_apply_promote_kwargs(inputs, promote)
984+
985+
if guardrails:
986+
# Process inputs through guardrails
987+
modified_inputs, input_metadata = (
988+
_apply_input_guardrails(
989+
guardrails,
990+
inputs,
991+
)
958992
)
959-
)
960-
guardrail_metadata.update(input_metadata)
961-
962-
# Execute function with potentially modified inputs
963-
if modified_inputs != inputs:
964-
# Reconstruct function arguments from modified inputs
965-
bound = func_signature.bind(
966-
*func_args, **func_kwargs
967-
)
968-
bound.apply_defaults()
969-
970-
# Update bound arguments with modified values
971-
for (
972-
param_name,
973-
modified_value,
974-
) in modified_inputs.items():
975-
if param_name in bound.arguments:
976-
bound.arguments[param_name] = (
977-
modified_value
978-
)
979-
980-
output = await func(*bound.args, **bound.kwargs)
993+
guardrail_metadata.update(input_metadata)
994+
995+
# Execute function with potentially modified inputs
996+
if modified_inputs != inputs:
997+
# Reconstruct function arguments from modified inputs
998+
bound = func_signature.bind(
999+
*func_args, **func_kwargs
1000+
)
1001+
bound.apply_defaults()
1002+
1003+
# Update bound arguments with modified values
1004+
for (
1005+
param_name,
1006+
modified_value,
1007+
) in modified_inputs.items():
1008+
if param_name in bound.arguments:
1009+
bound.arguments[param_name] = (
1010+
modified_value
1011+
)
1012+
1013+
output = await func(
1014+
*bound.args, **bound.kwargs
1015+
)
1016+
else:
1017+
output = await func(
1018+
*func_args, **func_kwargs
1019+
)
9811020
else:
9821021
output = await func(*func_args, **func_kwargs)
9831022
except Exception as e:
@@ -1042,8 +1081,8 @@ def sync_wrapper(*func_args, **func_kwargs):
10421081
output = exception = None
10431082
guardrail_metadata = {}
10441083
try:
1045-
# Apply input guardrails if provided
1046-
if guardrails:
1084+
# Apply promote / input guardrails if provided
1085+
if promote or guardrails:
10471086
try:
10481087
inputs = _extract_function_inputs(
10491088
func_signature=func_signature,
@@ -1052,33 +1091,39 @@ def sync_wrapper(*func_args, **func_kwargs):
10521091
context_kwarg=context_kwarg,
10531092
question_kwarg=question_kwarg,
10541093
)
1094+
_apply_promote_kwargs(inputs, promote)
10551095

1056-
# Process inputs through guardrails
1057-
modified_inputs, input_metadata = (
1058-
_apply_input_guardrails(
1059-
guardrails,
1060-
inputs,
1096+
if guardrails:
1097+
# Process inputs through guardrails
1098+
modified_inputs, input_metadata = (
1099+
_apply_input_guardrails(
1100+
guardrails,
1101+
inputs,
1102+
)
10611103
)
1062-
)
1063-
guardrail_metadata.update(input_metadata)
1104+
guardrail_metadata.update(input_metadata)
10641105

1065-
# Execute function with potentially modified inputs
1066-
if modified_inputs != inputs:
1067-
# Reconstruct function arguments from modified inputs
1068-
bound = func_signature.bind(
1069-
*func_args, **func_kwargs
1070-
)
1071-
bound.apply_defaults()
1106+
# Execute function with potentially modified inputs
1107+
if modified_inputs != inputs:
1108+
# Reconstruct function arguments from modified inputs
1109+
bound = func_signature.bind(
1110+
*func_args, **func_kwargs
1111+
)
1112+
bound.apply_defaults()
10721113

1073-
# Update bound arguments with modified values
1074-
for (
1075-
param_name,
1076-
modified_value,
1077-
) in modified_inputs.items():
1078-
if param_name in bound.arguments:
1079-
bound.arguments[param_name] = modified_value
1114+
# Update bound arguments with modified values
1115+
for (
1116+
param_name,
1117+
modified_value,
1118+
) in modified_inputs.items():
1119+
if param_name in bound.arguments:
1120+
bound.arguments[param_name] = (
1121+
modified_value
1122+
)
10801123

1081-
output = func(*bound.args, **bound.kwargs)
1124+
output = func(*bound.args, **bound.kwargs)
1125+
else:
1126+
output = func(*func_args, **func_kwargs)
10821127
else:
10831128
output = func(*func_args, **func_kwargs)
10841129
except Exception as e:
@@ -1818,6 +1863,29 @@ def _extract_function_inputs(
18181863
return inputs
18191864

18201865

1866+
def _apply_promote_kwargs(
1867+
inputs: dict,
1868+
promote: Optional[Union[List[str], Dict[str, str]]],
1869+
) -> None:
1870+
"""Promote selected function kwargs to trace-level columns."""
1871+
if not promote:
1872+
return
1873+
mapping: Dict[str, str] = (
1874+
{k: k for k in promote} if isinstance(promote, list) else promote
1875+
)
1876+
resolved: Dict[str, Any] = {}
1877+
for kwarg_name, column_name in mapping.items():
1878+
if kwarg_name in inputs:
1879+
resolved[column_name] = inputs[kwarg_name]
1880+
else:
1881+
logger.warning(
1882+
"promote: kwarg `%s` not found in inputs of the current function.",
1883+
kwarg_name,
1884+
)
1885+
if resolved:
1886+
update_current_trace(**resolved)
1887+
1888+
18211889
def _finalize_step_logging(
18221890
step: steps.Step,
18231891
inputs: dict,

0 commit comments

Comments
 (0)