Skip to content

Commit d3e1494

Browse files
committed
fixed issues
1 parent d647f86 commit d3e1494

6 files changed

Lines changed: 2728 additions & 24 deletions

File tree

src/tool_classifier/context_analyzer.py

Lines changed: 153 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def __init__(self, llm_manager: Any) -> None: # noqa: ANN401
207207
# Phase 1 & 2 modules for two-phase detection+generation flow
208208
self._detection_module: Optional[dspy.Module] = None
209209
self._response_generation_module: Optional[dspy.Module] = None
210-
self._stream_predictor: Optional[Any] = None
211210
logger.info("Context analyzer initialized")
212211

213212
def _format_conversation_history(
@@ -357,6 +356,111 @@ async def detect_context(
357356
)
358357
return result, cost_dict
359358

359+
async def detect_context_with_summary_fallback(
360+
self,
361+
query: str,
362+
conversation_history: List[Dict[str, Any]],
363+
) -> tuple[ContextDetectionResult, Dict[str, Any]]:
364+
"""
365+
Phase 1 with summary fallback: detect if query can be answered from history.
366+
367+
Implements a 3-step flow:
368+
1. Check the last 10 turns via detect_context().
369+
2. If cannot answer AND total history > 10 turns:
370+
- Generate a concise summary of the older turns (everything before the last 10).
371+
- Check whether the query can be answered from that summary.
372+
3. If still cannot answer, return can_answer=False (workflow falls back to RAG).
373+
374+
When the summary path succeeds, the returned ContextDetectionResult has:
375+
- can_answer_from_context=True
376+
- answered_from_summary=True
377+
- context_snippet set to the answer extracted from the summary, so that
378+
Phase 2 (stream_context_response / generate_context_response) can use it
379+
directly as the context for response generation.
380+
381+
Args:
382+
query: User query to classify
383+
conversation_history: Full conversation history
384+
385+
Returns:
386+
Tuple of (ContextDetectionResult, cost_dict)
387+
"""
388+
total_turns = len(conversation_history)
389+
390+
# Step 1: check the most recent 10 turns
391+
result, cost_dict = await self.detect_context(
392+
query=query, conversation_history=conversation_history
393+
)
394+
395+
# If already answered or it's a greeting, return immediately
396+
if result.is_greeting or result.can_answer_from_context:
397+
return result, cost_dict
398+
399+
# Step 2 & 3: if history exceeds 10 turns, try summary-based detection
400+
if total_turns > 10:
401+
logger.info(
402+
f"History has {total_turns} turns (> 10) | "
403+
f"Cannot answer from recent 10 | Attempting summary-based detection"
404+
)
405+
older_history = conversation_history[:-10]
406+
logger.info(f"Summarizing {len(older_history)} older turns")
407+
408+
try:
409+
summary, summary_cost = await self._generate_conversation_summary(
410+
older_history
411+
)
412+
cost_dict = self._merge_cost_dicts(cost_dict, summary_cost)
413+
414+
if summary:
415+
summary_result, analysis_cost = await self._analyze_from_summary(
416+
query=query, summary=summary
417+
)
418+
cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost)
419+
420+
if summary_result.can_answer_from_context and summary_result.answer:
421+
logger.info(
422+
f"DETECTION: Can answer from summary | "
423+
f"Reasoning: {summary_result.reasoning}"
424+
)
425+
# Surface the summary-derived answer as context_snippet so
426+
# Phase 2 can generate a polished response from it.
427+
return ContextDetectionResult(
428+
is_greeting=False,
429+
can_answer_from_context=True,
430+
reasoning=summary_result.reasoning,
431+
context_snippet=summary_result.answer,
432+
answered_from_summary=True,
433+
), cost_dict
434+
435+
logger.info(
436+
"Cannot answer from summary either | Falling back to RAG"
437+
)
438+
else:
439+
logger.warning(
440+
"Summary generation returned empty | Falling back to RAG"
441+
)
442+
443+
except Exception as e:
444+
logger.error(f"Summary-based detection failed: {e}", exc_info=True)
445+
else:
446+
logger.info(
447+
f"History has {total_turns} turns (<= 10) | "
448+
f"No summary needed | Falling back to RAG"
449+
)
450+
451+
return result, cost_dict
452+
453+
@staticmethod
454+
def _yield_in_chunks(text: str, chunk_size: int = 5) -> list[str]:
455+
"""Split text into word-group chunks for simulated streaming."""
456+
words = text.split()
457+
chunks = []
458+
for i in range(0, len(words), chunk_size):
459+
group = words[i : i + chunk_size]
460+
trailing = " " if i + chunk_size < len(words) else ""
461+
chunks.append(" ".join(group) + trailing)
462+
return chunks
463+
360464
async def stream_context_response(
361465
self,
362466
query: str,
@@ -365,30 +469,39 @@ async def stream_context_response(
365469
"""
366470
Phase 2 (streaming): Stream a generated answer using DSPy native streaming.
367471
368-
Uses ContextResponseGenerationSignature with DSPy's streamify() so tokens
369-
are yielded in real time — same mechanism as ResponseGeneratorAgent.stream_response().
472+
Creates a fresh streamify predictor per call (avoids stale StreamListener
473+
issues that occur when the cached predictor is reused across calls).
474+
475+
Fallback chain:
476+
1. DSPy streamify → yield StreamResponse tokens as they arrive.
477+
2. If no stream tokens received but final Prediction has an answer,
478+
yield it in word-group chunks.
479+
3. If that is also empty, call generate_context_response() directly
480+
and yield the result in word-group chunks.
370481
371482
Args:
372483
query: The user query to answer
373484
context_snippet: Relevant context extracted during Phase 1 detection
374485
375486
Yields:
376-
Token strings as they arrive from the LLM
487+
Token strings as they arrive from the LLM (or simulated chunks)
377488
"""
378489
logger.info(f"CONTEXT GENERATOR: Phase 2 streaming | Query: '{query[:100]}'")
379490

380491
self.llm_manager.ensure_global_config()
381492
output_stream = None
382493
stream_started = False
494+
prediction_answer: Optional[str] = None
383495
try:
384496
with self.llm_manager.use_task_local():
385-
if self._stream_predictor is None:
386-
answer_listener = StreamListener(signature_field_name="answer")
387-
self._stream_predictor = dspy.streamify(
388-
dspy.Predict(ContextResponseGenerationSignature),
389-
stream_listeners=[answer_listener],
390-
)
391-
output_stream = self._stream_predictor(
497+
# Always create a fresh StreamListener + streamified predictor so that
498+
# the listener's internal state is clean for this call.
499+
answer_listener = StreamListener(signature_field_name="answer")
500+
stream_predictor: Any = dspy.streamify(
501+
dspy.Predict(ContextResponseGenerationSignature),
502+
stream_listeners=[answer_listener],
503+
)
504+
output_stream = stream_predictor(
392505
context_snippet=context_snippet,
393506
user_query=query,
394507
)
@@ -402,11 +515,11 @@ async def stream_context_response(
402515
logger.info(
403516
"Context response streaming complete (final Prediction received)"
404517
)
518+
if not stream_started:
519+
# Tokens didn't stream — extract answer from the Prediction
520+
# directly as first fallback before leaving the LM context.
521+
prediction_answer = getattr(chunk, "answer", "") or ""
405522

406-
if not stream_started:
407-
logger.warning(
408-
"Context streaming finished but no 'answer' tokens received."
409-
)
410523
except GeneratorExit:
411524
raise
412525
except Exception as e:
@@ -421,6 +534,31 @@ async def stream_context_response(
421534
f"Error during context stream cleanup: {cleanup_error}"
422535
)
423536

537+
if stream_started:
538+
return
539+
540+
# Fallback 1: answer was in the final Prediction but didn't stream as tokens
541+
if prediction_answer:
542+
logger.warning(
543+
"Stream tokens not received — yielding answer from final Prediction in chunks."
544+
)
545+
for text_chunk in self._yield_in_chunks(prediction_answer):
546+
yield text_chunk
547+
return
548+
549+
# Fallback 2: Prediction had no answer either — call generate_context_response
550+
logger.warning(
551+
"No answer from streamify — falling back to generate_context_response."
552+
)
553+
fallback_answer, _ = await self.generate_context_response(
554+
query=query, context_snippet=context_snippet
555+
)
556+
if fallback_answer:
557+
for text_chunk in self._yield_in_chunks(fallback_answer):
558+
yield text_chunk
559+
else:
560+
logger.error("All Phase 2 streaming fallbacks exhausted — empty response.")
561+
424562
async def generate_context_response(
425563
self,
426564
query: str,

src/tool_classifier/workflows/context_workflow.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Context workflow executor - Layer 2: Conversation history and greetings."""
22

3-
from typing import Any, AsyncIterator, Dict, Optional
3+
from typing import Any, AsyncIterator, Dict, Optional, cast
44
import time
55
import dspy
66
from loguru import logger
@@ -77,10 +77,19 @@ async def _detect(
7777
time_metric: Dict[str, float],
7878
costs_metric: Dict[str, Dict[str, Any]],
7979
) -> Optional[ContextDetectionResult]:
80-
"""Phase 1: run context detection. Returns ContextDetectionResult or None on error."""
80+
"""Phase 1: run context detection with summary fallback.
81+
82+
Checks the last 10 conversation turns first. If the query cannot be
83+
answered from those and the history exceeds 10 turns, falls back to a
84+
summary-based check over the older turns. Returns None on error so the
85+
caller falls through to RAG.
86+
"""
8187
try:
8288
start = time.time()
83-
result, cost = await self.context_analyzer.detect_context(
89+
(
90+
result,
91+
cost,
92+
) = await self.context_analyzer.detect_context_with_summary_fallback(
8493
query=message, conversation_history=history
8594
)
8695
time_metric["context.detection"] = time.time() - start
@@ -267,12 +276,29 @@ async def execute_async(
267276
language = detect_language(request.message)
268277
history = self._build_history(request)
269278

270-
detection_result = await self._detect(
271-
request.message, history, time_metric, costs_metric
272-
)
273-
if detection_result is None:
274-
self._log_costs(costs_metric)
275-
return None
279+
# Check if analysis is pre-computed (e.g. from classifier classify step)
280+
pre_computed = context.get("analysis_result")
281+
if (
282+
pre_computed is not None
283+
and hasattr(pre_computed, "is_greeting")
284+
and hasattr(pre_computed, "can_answer_from_context")
285+
):
286+
detection_result: ContextDetectionResult = cast(
287+
ContextDetectionResult, pre_computed
288+
)
289+
costs_metric.setdefault(
290+
"context_detection",
291+
{"total_cost": 0.0, "total_tokens": 0, "num_calls": 0},
292+
)
293+
else:
294+
_detected = await self._detect(
295+
request.message, history, time_metric, costs_metric
296+
)
297+
if _detected is None:
298+
self._log_costs(costs_metric)
299+
context["costs_dict"] = costs_metric
300+
return None
301+
detection_result = _detected
276302

277303
logger.info(
278304
f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} "
@@ -286,6 +312,7 @@ async def execute_async(
286312
greeting_type=detection_result.greeting_type, language=language
287313
)
288314
self._log_costs(costs_metric)
315+
context["costs_dict"] = costs_metric
289316
return OrchestrationResponse(
290317
chatId=request.chatId,
291318
llmServiceActive=True,
@@ -298,6 +325,7 @@ async def execute_async(
298325
detection_result.can_answer_from_context
299326
and detection_result.context_snippet
300327
):
328+
context["costs_dict"] = costs_metric
301329
return await self._generate_response_async(
302330
request, detection_result.context_snippet, time_metric, costs_metric
303331
)
@@ -306,6 +334,7 @@ async def execute_async(
306334
f"[{request.chatId}] Cannot answer from context — falling back to RAG"
307335
)
308336
self._log_costs(costs_metric)
337+
context["costs_dict"] = costs_metric
309338
return None
310339

311340
async def execute_streaming(

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,12 @@
66
# Add the project root to Python path so tests can import from src
77
project_root = Path(__file__).parent.parent
88
sys.path.insert(0, str(project_root))
9+
10+
# Add src directory to Python path for direct module imports
11+
src_dir = project_root / "src"
12+
sys.path.insert(0, str(src_dir))
13+
14+
# Add models directory (sibling to src) for backward compatibility
15+
models_dir = project_root / "models"
16+
if models_dir.exists():
17+
sys.path.insert(0, str(models_dir.parent))

0 commit comments

Comments
 (0)