Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/graph/workflows/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def run_analysis_workflow(
import mlflow

from src.graph.state import AnalysisType
from src.tracing.context import session_context
from src.tracing.context import get_session_id, session_context

# Map string to enum
try:
Expand All @@ -146,8 +146,11 @@ async def run_analysis_workflow(
graph = build_analysis_graph(ha_client=ha_client, session=session)
compiled = graph.compile()

# Run with tracing
with session_context() as session_id, start_experiment_run("analysis_workflow") as run:
# Run with tracing (inherit parent session if one exists)
with (
session_context(get_session_id()) as session_id,
start_experiment_run("analysis_workflow") as run,
):
if run:
initial_state.mlflow_run_id = run.info.run_id if hasattr(run, "info") else None

Expand Down
15 changes: 9 additions & 6 deletions src/graph/workflows/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def run_conversation_workflow(
"""
from langchain_core.messages import HumanMessage

from src.tracing.context import session_context
from src.tracing.context import get_session_id, session_context

# Compile graph
compiled = compile_conversation_graph(session=session, thread_id=thread_id)
Expand All @@ -195,10 +195,13 @@ async def run_conversation_workflow(
messages=[HumanMessage(content=user_message)],
)

# Run with MLflow tracking and session context
# Run with MLflow tracking and session context (inherit parent session if one exists)
import mlflow

with session_context() as session_id, start_experiment_run("conversation_workflow"):
with (
session_context(get_session_id()) as session_id,
start_experiment_run("conversation_workflow"),
):
mlflow.set_tag("workflow", "conversation")
mlflow.set_tag("thread_id", thread_id or state.conversation_id)
mlflow.set_tag("session.id", session_id)
Expand Down Expand Up @@ -246,7 +249,7 @@ async def resume_after_approval(
Updated conversation state after resumption
"""
from src.dal import ProposalRepository
from src.tracing.context import session_context
from src.tracing.context import get_session_id, session_context

# Compile graph with same checkpointer
compiled = compile_conversation_graph(session=session, thread_id=thread_id)
Expand Down Expand Up @@ -279,10 +282,10 @@ async def resume_after_approval(
# Update the state in the graph
compiled.update_state(config, current_state.model_dump()) # type: ignore[attr-defined]

# Resume execution with session context
# Resume execution with session context (inherit parent session if one exists)
import mlflow

with session_context() as session_id:
with session_context(get_session_id()) as session_id:
with start_experiment_run("conversation_workflow_resume"):
mlflow.set_tag("workflow", "conversation_resume")
mlflow.set_tag("thread_id", thread_id)
Expand Down
9 changes: 6 additions & 3 deletions src/graph/workflows/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ async def run_discovery_workflow(
Returns:
Final discovery state
"""
# Start a trace session for this workflow
from src.tracing.context import session_context
# Start a trace session for this workflow (inherit parent session if one exists)
from src.tracing.context import get_session_id, session_context

# Build the graph with injected dependencies
graph = build_discovery_graph(ha_client=ha_client, session=session)
Expand All @@ -152,7 +152,10 @@ async def run_discovery_workflow(
# Run with MLflow tracking and session context
import mlflow

with session_context() as session_id, start_experiment_run("discovery_workflow"):
with (
session_context(get_session_id()) as session_id,
start_experiment_run("discovery_workflow"),
):
mlflow.set_tag("workflow", "discovery")
mlflow.set_tag("session.id", session_id)

Expand Down
15 changes: 11 additions & 4 deletions src/graph/workflows/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ async def run_optimization_workflow(
Returns:
Final analysis state
"""
from src.tracing import log_metric, log_param
from src.tracing.context import session_context
import mlflow

from src.tracing import log_metric, log_param, start_experiment_run
from src.tracing.context import get_session_id, session_context

# Map string to enum
type_map = {
Expand All @@ -147,8 +149,13 @@ async def run_optimization_workflow(
}
analysis_enum = type_map.get(analysis_type, AnalysisType.BEHAVIOR_ANALYSIS)

with session_context():
log_param("workflow", "optimization")
# Run with tracing (inherit parent session if one exists)
with (
session_context(get_session_id()) as session_id,
start_experiment_run("optimization_workflow"),
):
mlflow.set_tag("workflow", "optimization")
mlflow.set_tag("session.id", session_id)
log_param("analysis_type", analysis_type)
log_param("hours", hours)

Expand Down
19 changes: 17 additions & 2 deletions src/tracing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,17 @@ def _get_traced(mlflow: Any) -> Callable[..., Any]:
)
return traced_func

def _tag_trace_session(mlflow: Any) -> None:
"""Tag the current trace with session ID for grouping."""
try:
from src.tracing.context import get_session_id

session_id = get_session_id()
if session_id:
mlflow.update_current_trace(tags={"mlflow.trace.session": session_id})
except Exception:
_logger.debug("Failed to tag trace with session ID", exc_info=True)

@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
if not _ensure_mlflow_initialized() or not _traces_available:
Expand All @@ -767,7 +778,9 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore

try:
traced = _get_traced(mlflow)
return await traced(*args, **kwargs) # type: ignore[misc, no-any-return]
result = await traced(*args, **kwargs) # type: ignore[misc, no-any-return]
_tag_trace_session(mlflow)
return result # type: ignore[no-any-return]
except Exception as e:
_disable_traces("span creation failed; backend rejected traces")
_logger.debug(f"Span creation failed, running without trace: {e}")
Expand All @@ -784,7 +797,9 @@ def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:

try:
traced = _get_traced(mlflow)
return traced(*args, **kwargs)
result = traced(*args, **kwargs)
_tag_trace_session(mlflow)
return result
except Exception as e:
_disable_traces("span creation failed; backend rejected traces")
_logger.debug(f"Span creation failed, running without trace: {e}")
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_tracing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,38 @@ def test_restores_on_exception(self):
pass
assert get_session_id() == "outer"
clear_session()

def test_inherits_parent_session_via_get_session_id(self):
"""Nested session_context(get_session_id()) reuses parent session."""
clear_session()
with session_context(session_id="parent-session") as outer_sid:
assert outer_sid == "parent-session"
# Simulate what nested workflows should do: pass get_session_id()
with session_context(session_id=get_session_id()) as inner_sid:
assert inner_sid == "parent-session"
assert get_session_id() == "parent-session"
# After inner exits, outer is restored
assert get_session_id() == "parent-session"
clear_session()

def test_creates_new_session_when_no_parent_exists(self):
"""session_context(get_session_id()) creates new session when no parent."""
clear_session()
assert get_session_id() is None
# get_session_id() returns None, so session_context(None) creates new
with session_context(session_id=get_session_id()) as sid:
assert isinstance(sid, str)
assert len(sid) == 36 # UUID format
assert get_session_id() == sid
clear_session()

def test_deeply_nested_sessions_all_share_parent(self):
"""Triple-nested session_context all share the same parent session."""
clear_session()
with session_context(session_id="root") as root_sid:
with session_context(session_id=get_session_id()) as mid_sid:
with session_context(session_id=get_session_id()) as inner_sid:
assert inner_sid == "root"
assert mid_sid == "root"
assert root_sid == "root"
clear_session()
82 changes: 82 additions & 0 deletions tests/unit/test_tracing_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,88 @@ async def my_async_func():
result = await my_async_func()
assert result == 99

def test_sync_decorator_tags_session(self):
"""trace_with_uri sets mlflow.trace.session tag when session is active."""
from src.tracing.context import clear_session, set_session_id
from src.tracing.mlflow import trace_with_uri

@trace_with_uri(name="test_session_tag")
def my_func():
return 42

mock_mlflow = MagicMock()
mock_traced = MagicMock(return_value=42)
mock_mlflow.trace.return_value = mock_traced

set_session_id("test-session-123")
try:
with (
patch("src.tracing.mlflow._ensure_mlflow_initialized", return_value=True),
patch("src.tracing.mlflow._traces_available", True),
patch("src.tracing.mlflow._safe_import_mlflow", return_value=mock_mlflow),
):
result = my_func()
assert result == 42
mock_mlflow.update_current_trace.assert_called_once_with(
tags={"mlflow.trace.session": "test-session-123"}
)
finally:
clear_session()

def test_sync_decorator_skips_session_tag_when_no_session(self):
"""trace_with_uri does not call update_current_trace when no session is active."""
from src.tracing.context import clear_session
from src.tracing.mlflow import trace_with_uri

@trace_with_uri(name="test_no_session")
def my_func():
return 42

mock_mlflow = MagicMock()
mock_traced = MagicMock(return_value=42)
mock_mlflow.trace.return_value = mock_traced

clear_session()
with (
patch("src.tracing.mlflow._ensure_mlflow_initialized", return_value=True),
patch("src.tracing.mlflow._traces_available", True),
patch("src.tracing.mlflow._safe_import_mlflow", return_value=mock_mlflow),
):
result = my_func()
assert result == 42
mock_mlflow.update_current_trace.assert_not_called()

async def test_async_decorator_tags_session(self):
"""trace_with_uri sets mlflow.trace.session tag on async functions."""
from src.tracing.context import clear_session, set_session_id
from src.tracing.mlflow import trace_with_uri

@trace_with_uri(name="test_async_session")
async def my_async_func():
return 99

mock_mlflow = MagicMock()

async def mock_traced_coro(*args, **kwargs):
return 99

mock_mlflow.trace.return_value = mock_traced_coro

set_session_id("async-session-456")
try:
with (
patch("src.tracing.mlflow._ensure_mlflow_initialized", return_value=True),
patch("src.tracing.mlflow._traces_available", True),
patch("src.tracing.mlflow._safe_import_mlflow", return_value=mock_mlflow),
):
result = await my_async_func()
assert result == 99
mock_mlflow.update_current_trace.assert_called_once_with(
tags={"mlflow.trace.session": "async-session-456"}
)
finally:
clear_session()


class TestEnableAutolog:
def test_skips_when_not_initialized(self):
Expand Down
2 changes: 2 additions & 0 deletions ui/src/api/client/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export type StreamChunk =
export async function* streamChat(
model: string,
messages: import("@/lib/types").ChatMessage[],
conversationId?: string,
): AsyncGenerator<StreamChunk> {
const url = `${env.API_URL}/v1/chat/completions`;
const response = await fetch(url, {
Expand All @@ -87,6 +88,7 @@ export async function* streamChat(
model,
messages,
stream: true,
...(conversationId && { conversation_id: conversationId }),
}),
});

Expand Down
2 changes: 1 addition & 1 deletion ui/src/pages/chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ export function ChatPage() {
try {
let traceId: string | undefined;

for await (const chunk of streamChat(selectedModel, chatHistory)) {
for await (const chunk of streamChat(selectedModel, chatHistory, activeSessionId ?? undefined)) {
if (typeof chunk === "object" && "type" in chunk) {
if (chunk.type === "metadata") {
if (chunk.trace_id) {
Expand Down
Loading