Skip to content

Commit b0726d1

Browse files
authored
Refactor AI wrapper callback helpers (#623)
1 parent 07a4c16 commit b0726d1

4 files changed

Lines changed: 79 additions & 162 deletions

File tree

posthog/ai/langchain/callbacks.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,9 @@ def on_chain_end(
173173
**kwargs: Any,
174174
):
175175
"""Capture a completed LangChain chain run as a trace or span."""
176-
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
177-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs)
176+
self._capture_trace_or_span_run(
177+
"on_chain_end", "outputs", outputs, run_id, parent_run_id
178+
)
178179

179180
def on_chain_error(
180181
self,
@@ -185,8 +186,9 @@ def on_chain_error(
185186
**kwargs: Any,
186187
):
187188
"""Capture a failed LangChain chain run as a trace or span."""
188-
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
189-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
189+
self._capture_trace_or_span_run(
190+
"on_chain_error", "error", error, run_id, parent_run_id
191+
)
190192

191193
def on_chat_model_start(
192194
self,
@@ -243,10 +245,9 @@ def on_llm_end(
243245
"""
244246
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
245247
"""
246-
self._log_debug_event(
247-
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
248+
self._capture_generation_run(
249+
"on_llm_end", "response", response, run_id, parent_run_id, kwargs=kwargs
248250
)
249-
self._pop_run_and_capture_generation(run_id, parent_run_id, response)
250251

251252
def on_llm_error(
252253
self,
@@ -257,8 +258,9 @@ def on_llm_error(
257258
**kwargs: Any,
258259
):
259260
"""Capture a failed LLM run as a PostHog AI generation event."""
260-
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
261-
self._pop_run_and_capture_generation(run_id, parent_run_id, error)
261+
self._capture_generation_run(
262+
"on_llm_error", "error", error, run_id, parent_run_id
263+
)
262264

263265
def on_tool_start(
264266
self,
@@ -288,8 +290,9 @@ def on_tool_end(
288290
**kwargs: Any,
289291
) -> Any:
290292
"""Capture a completed LangChain tool run as a span."""
291-
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
292-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output)
293+
self._capture_trace_or_span_run(
294+
"on_tool_end", "output", output, run_id, parent_run_id
295+
)
293296

294297
def on_tool_error(
295298
self,
@@ -301,8 +304,9 @@ def on_tool_error(
301304
**kwargs: Any,
302305
) -> Any:
303306
"""Capture a failed LangChain tool run as a span."""
304-
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
305-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
307+
self._capture_trace_or_span_run(
308+
"on_tool_error", "error", error, run_id, parent_run_id
309+
)
306310

307311
def on_retriever_start(
308312
self,
@@ -330,10 +334,9 @@ def on_retriever_end(
330334
**kwargs: Any,
331335
):
332336
"""Capture a completed LangChain retriever run as a span."""
333-
self._log_debug_event(
334-
"on_retriever_end", run_id, parent_run_id, documents=documents
337+
self._capture_trace_or_span_run(
338+
"on_retriever_end", "documents", documents, run_id, parent_run_id
335339
)
336-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)
337340

338341
def on_retriever_error(
339342
self,
@@ -345,8 +348,9 @@ def on_retriever_error(
345348
**kwargs: Any,
346349
) -> Any:
347350
"""Run when Retriever errors."""
348-
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
349-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
351+
self._capture_trace_or_span_run(
352+
"on_retriever_error", "error", error, run_id, parent_run_id
353+
)
350354

351355
def on_agent_action(
352356
self,
@@ -370,8 +374,36 @@ def on_agent_finish(
370374
**kwargs: Any,
371375
) -> Any:
372376
"""Capture a completed LangChain agent action as a span."""
373-
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
374-
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish)
377+
self._capture_trace_or_span_run(
378+
"on_agent_finish", "finish", finish, run_id, parent_run_id
379+
)
380+
381+
def _capture_trace_or_span_run(
382+
self,
383+
event_name: str,
384+
payload_name: str,
385+
payload: Any,
386+
run_id: UUID,
387+
parent_run_id: Optional[UUID],
388+
):
389+
self._log_debug_event(
390+
event_name, run_id, parent_run_id, **{payload_name: payload}
391+
)
392+
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, payload)
393+
394+
def _capture_generation_run(
395+
self,
396+
event_name: str,
397+
payload_name: str,
398+
payload: Any,
399+
run_id: UUID,
400+
parent_run_id: Optional[UUID],
401+
**extra: Any,
402+
):
403+
self._log_debug_event(
404+
event_name, run_id, parent_run_id, **{payload_name: payload}, **extra
405+
)
406+
self._pop_run_and_capture_generation(run_id, parent_run_id, payload)
375407

376408
def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
377409
"""

posthog/ai/openai/openai.py

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from posthog.ai.sanitization import sanitize_openai, sanitize_openai_response
2727
from posthog.client import Client as PostHogClient
2828
from posthog import setup
29-
from posthog.ai.openai.wrapper_utils import warn_on_fallback
29+
from posthog.ai.openai.wrapper_utils import _OpenAIWrapperResource
3030

3131

3232
class OpenAI(openai.OpenAI):
@@ -91,18 +91,9 @@ def _parse_and_track(
9191
)
9292

9393

94-
class WrappedResponses:
94+
class WrappedResponses(_OpenAIWrapperResource):
9595
"""Wrapper for OpenAI responses that tracks usage in PostHog."""
9696

97-
def __init__(self, client: OpenAI, original_responses):
98-
self._client = client
99-
self._original = original_responses
100-
101-
def __getattr__(self, name):
102-
"""Fallback to original responses object for any methods we don't explicitly handle."""
103-
warn_on_fallback(self.__class__.__name__, name)
104-
return getattr(self._original, name)
105-
10697
def create(
10798
self,
10899
posthog_distinct_id: Optional[str] = None,
@@ -312,36 +303,18 @@ def parse(
312303
)
313304

314305

315-
class WrappedChat:
306+
class WrappedChat(_OpenAIWrapperResource):
316307
"""Wrapper for OpenAI chat that tracks usage in PostHog."""
317308

318-
def __init__(self, client: OpenAI, original_chat):
319-
self._client = client
320-
self._original = original_chat
321-
322-
def __getattr__(self, name):
323-
"""Fallback to original chat object for any methods we don't explicitly handle."""
324-
warn_on_fallback(self.__class__.__name__, name)
325-
return getattr(self._original, name)
326-
327309
@property
328310
def completions(self):
329311
"""Access chat completions with PostHog usage tracking."""
330312
return WrappedCompletions(self._client, self._original.completions)
331313

332314

333-
class WrappedCompletions:
315+
class WrappedCompletions(_OpenAIWrapperResource):
334316
"""Wrapper for OpenAI chat completions that tracks usage in PostHog."""
335317

336-
def __init__(self, client: OpenAI, original_completions):
337-
self._client = client
338-
self._original = original_completions
339-
340-
def __getattr__(self, name):
341-
"""Fallback to original completions object for any methods we don't explicitly handle."""
342-
warn_on_fallback(self.__class__.__name__, name)
343-
return getattr(self._original, name)
344-
345318
def parse(
346319
self,
347320
posthog_distinct_id: Optional[str] = None,
@@ -566,18 +539,9 @@ def _capture_streaming_event(
566539
capture_streaming_event(self._client._ph_client, event_data)
567540

568541

569-
class WrappedEmbeddings:
542+
class WrappedEmbeddings(_OpenAIWrapperResource):
570543
"""Wrapper for OpenAI embeddings that tracks usage in PostHog."""
571544

572-
def __init__(self, client: OpenAI, original_embeddings):
573-
self._client = client
574-
self._original = original_embeddings
575-
576-
def __getattr__(self, name):
577-
"""Fallback to original embeddings object for any methods we don't explicitly handle."""
578-
warn_on_fallback(self.__class__.__name__, name)
579-
return getattr(self._original, name)
580-
581545
def create(
582546
self,
583547
posthog_distinct_id: Optional[str] = None,
@@ -651,54 +615,27 @@ def create(
651615
return response
652616

653617

654-
class WrappedBeta:
618+
class WrappedBeta(_OpenAIWrapperResource):
655619
"""Wrapper for OpenAI beta features that tracks usage in PostHog."""
656620

657-
def __init__(self, client: OpenAI, original_beta):
658-
self._client = client
659-
self._original = original_beta
660-
661-
def __getattr__(self, name):
662-
"""Fallback to original beta object for any methods we don't explicitly handle."""
663-
warn_on_fallback(self.__class__.__name__, name)
664-
return getattr(self._original, name)
665-
666621
@property
667622
def chat(self):
668623
"""Access beta chat APIs with PostHog usage tracking."""
669624
return WrappedBetaChat(self._client, self._original.chat)
670625

671626

672-
class WrappedBetaChat:
627+
class WrappedBetaChat(_OpenAIWrapperResource):
673628
"""Wrapper for OpenAI beta chat that tracks usage in PostHog."""
674629

675-
def __init__(self, client: OpenAI, original_beta_chat):
676-
self._client = client
677-
self._original = original_beta_chat
678-
679-
def __getattr__(self, name):
680-
"""Fallback to original beta chat object for any methods we don't explicitly handle."""
681-
warn_on_fallback(self.__class__.__name__, name)
682-
return getattr(self._original, name)
683-
684630
@property
685631
def completions(self):
686632
"""Access beta chat completions with PostHog usage tracking."""
687633
return WrappedBetaCompletions(self._client, self._original.completions)
688634

689635

690-
class WrappedBetaCompletions:
636+
class WrappedBetaCompletions(_OpenAIWrapperResource):
691637
"""Wrapper for OpenAI beta chat completions that tracks usage in PostHog."""
692638

693-
def __init__(self, client: OpenAI, original_beta_completions):
694-
self._client = client
695-
self._original = original_beta_completions
696-
697-
def __getattr__(self, name):
698-
"""Fallback to original beta completions object for any methods we don't explicitly handle."""
699-
warn_on_fallback(self.__class__.__name__, name)
700-
return getattr(self._original, name)
701-
702639
def parse(
703640
self,
704641
posthog_distinct_id: Optional[str] = None,

0 commit comments

Comments
 (0)