@@ -236,6 +236,7 @@ def _get_client() -> Optional[Openlayer]:
236236_current_step = contextvars .ContextVar ("current_step" )
237237_current_trace = contextvars .ContextVar ("current_trace" )
238238_rag_context = contextvars .ContextVar ("rag_context" )
239+ _rag_question = contextvars .ContextVar ("rag_question" )
239240
240241# ----------------------------- Offline Buffer Implementation ----------------------------- #
241242
@@ -459,6 +460,11 @@ def get_rag_context() -> Optional[Dict[str, Any]]:
459460 return _rag_context .get (None )
460461
461462
463+ def get_rag_question () -> Optional [str ]:
464+ """Returns the current question."""
465+ return _rag_question .get (None )
466+
467+
462468@contextmanager
463469def create_step (
464470 name : str ,
@@ -515,6 +521,7 @@ def trace(
515521 * step_args ,
516522 inference_pipeline_id : Optional [str ] = None ,
517523 context_kwarg : Optional [str ] = None ,
524+ question_kwarg : Optional [str ] = None ,
518525 guardrails : Optional [List [Any ]] = None ,
519526 on_flush_failure : Optional [OnFlushFailureCallback ] = None ,
520527 ** step_kwargs ,
@@ -605,6 +612,7 @@ def __next__(self):
605612 func_args = func_args ,
606613 func_kwargs = func_kwargs ,
607614 context_kwarg = context_kwarg ,
615+ question_kwarg = question_kwarg ,
608616 )
609617 self ._trace_initialized = True
610618
@@ -699,6 +707,7 @@ def wrapper(*func_args, **func_kwargs):
699707 func_args = func_args ,
700708 func_kwargs = func_kwargs ,
701709 context_kwarg = context_kwarg ,
710+ question_kwarg = question_kwarg ,
702711 )
703712
704713 # Apply input guardrails
@@ -785,6 +794,7 @@ def wrapper(*func_args, **func_kwargs):
785794 context_kwarg = context_kwarg ,
786795 output = output ,
787796 guardrail_metadata = guardrail_metadata ,
797+ question_kwarg = question_kwarg ,
788798 )
789799
790800 if exception is not None :
@@ -800,6 +810,7 @@ def trace_async(
800810 * step_args ,
801811 inference_pipeline_id : Optional [str ] = None ,
802812 context_kwarg : Optional [str ] = None ,
813+ question_kwarg : Optional [str ] = None ,
803814 guardrails : Optional [List [Any ]] = None ,
804815 on_flush_failure : Optional [OnFlushFailureCallback ] = None ,
805816 ** step_kwargs ,
@@ -873,6 +884,7 @@ async def __anext__(self):
873884 func_args = func_args ,
874885 func_kwargs = func_kwargs ,
875886 context_kwarg = context_kwarg ,
887+ question_kwarg = question_kwarg ,
876888 )
877889 self ._trace_initialized = True
878890
@@ -935,6 +947,7 @@ async def async_function_wrapper(*func_args, **func_kwargs):
935947 func_args = func_args ,
936948 func_kwargs = func_kwargs ,
937949 context_kwarg = context_kwarg ,
950+ question_kwarg = question_kwarg ,
938951 )
939952
940953 # Process inputs through guardrails
@@ -990,6 +1003,7 @@ async def async_function_wrapper(*func_args, **func_kwargs):
9901003 func_args = func_args ,
9911004 func_kwargs = func_kwargs ,
9921005 context_kwarg = context_kwarg ,
1006+ question_kwarg = question_kwarg ,
9931007 ),
9941008 )
9951009 )
@@ -1010,6 +1024,7 @@ async def async_function_wrapper(*func_args, **func_kwargs):
10101024 context_kwarg = context_kwarg ,
10111025 output = output ,
10121026 guardrail_metadata = guardrail_metadata ,
1027+ question_kwarg = question_kwarg ,
10131028 )
10141029
10151030 return output
@@ -1035,6 +1050,7 @@ def sync_wrapper(*func_args, **func_kwargs):
10351050 func_args = func_args ,
10361051 func_kwargs = func_kwargs ,
10371052 context_kwarg = context_kwarg ,
1053+ question_kwarg = question_kwarg ,
10381054 )
10391055
10401056 # Process inputs through guardrails
@@ -1087,6 +1103,7 @@ def sync_wrapper(*func_args, **func_kwargs):
10871103 func_args = func_args ,
10881104 func_kwargs = func_kwargs ,
10891105 context_kwarg = context_kwarg ,
1106+ question_kwarg = question_kwarg ,
10901107 ),
10911108 )
10921109 guardrail_metadata .update (output_metadata )
@@ -1106,6 +1123,7 @@ def sync_wrapper(*func_args, **func_kwargs):
11061123 context_kwarg = context_kwarg ,
11071124 output = output ,
11081125 guardrail_metadata = guardrail_metadata ,
1126+ question_kwarg = question_kwarg ,
11091127 )
11101128
11111129 if exception is not None :
@@ -1147,6 +1165,18 @@ def log_context(context: List[str]) -> None:
11471165 logger .warning ("No current step found to log context." )
11481166
11491167
1168+ def log_question (question : str ) -> None :
1169+ """Logs the question to the current step of the trace.
1170+
1171+ The `question` parameter should be the user query string for RAG use cases."""
1172+ current_step = get_current_step ()
1173+ if current_step :
1174+ _rag_question .set (question )
1175+ current_step .log (metadata = {"_question" : question })
1176+ else :
1177+ logger .warning ("No current step found to log question." )
1178+
1179+
11501180def log_attachment (
11511181 data : Union [bytes , str , Path , Any ],
11521182 name : Optional [str ] = None ,
@@ -1630,6 +1660,8 @@ def _upload_and_publish_trace(
16301660 config .update ({"ground_truth_column_name" : "groundTruth" })
16311661 if "context" in trace_data :
16321662 config .update ({"context_column_name" : "context" })
1663+ if "_question" in trace_data :
1664+ config .update ({"question_column_name" : "_question" })
16331665
16341666 if prompt is not None :
16351667 config .update ({"prompt" : prompt })
@@ -1729,6 +1761,7 @@ def _process_wrapper_inputs_and_outputs(
17291761 context_kwarg : Optional [str ],
17301762 output : Any ,
17311763 guardrail_metadata : Optional [Dict [str , Any ]] = None ,
1764+ question_kwarg : Optional [str ] = None ,
17321765) -> None :
17331766 """Extract function inputs and finalize step logging - common pattern across
17341767 wrappers."""
@@ -1737,6 +1770,7 @@ def _process_wrapper_inputs_and_outputs(
17371770 func_args = func_args ,
17381771 func_kwargs = func_kwargs ,
17391772 context_kwarg = context_kwarg ,
1773+ question_kwarg = question_kwarg ,
17401774 )
17411775 _finalize_step_logging (
17421776 step = step ,
@@ -1752,6 +1786,7 @@ def _extract_function_inputs(
17521786 func_args : tuple ,
17531787 func_kwargs : dict ,
17541788 context_kwarg : Optional [str ] = None ,
1789+ question_kwarg : Optional [str ] = None ,
17551790) -> dict :
17561791 """Extract and clean function inputs for logging."""
17571792 bound = func_signature .bind (* func_args , ** func_kwargs )
@@ -1770,6 +1805,16 @@ def _extract_function_inputs(
17701805 context_kwarg ,
17711806 )
17721807
1808+ # Handle question kwarg if specified
1809+ if question_kwarg :
1810+ if question_kwarg in inputs :
1811+ log_question (inputs .get (question_kwarg ))
1812+ else :
1813+ logger .warning (
1814+ "Question kwarg `%s` not found in inputs of the current function." ,
1815+ question_kwarg ,
1816+ )
1817+
17731818 return inputs
17741819
17751820
@@ -1955,6 +2000,10 @@ def post_process_trace(
19552000 if context :
19562001 trace_data ["context" ] = context
19572002
2003+ question = get_rag_question ()
2004+ if question :
2005+ trace_data ["_question" ] = question
2006+
19582007 return trace_data , input_variable_names
19592008
19602009
0 commit comments