Skip to content

Commit 30bd893

Browse files
authored
Merge branch 'llm-430' into wip
2 parents 365ab39 + e1cab89 commit 30bd893

7 files changed

Lines changed: 1041 additions & 35 deletions

File tree

src/tool_classifier/agentic_loop.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Standalone agentic loop for multi-turn parameter collection."""
22

33
import asyncio
4-
from typing import Any, Dict, List
4+
from typing import Any, Dict, List, Optional
55

66
from loguru import logger
77

@@ -106,6 +106,7 @@ async def run_turn(
106106
awaiting_continuation: bool = False,
107107
continuation_turn: int = CONTINUATION_TURN,
108108
session_language: str = "en",
109+
continuation_language: Optional[str] = None,
109110
) -> AgenticLoopResult:
110111
"""Process one user turn of the parameter-collection loop.
111112
@@ -279,8 +280,9 @@ async def run_turn(
279280
turn_count,
280281
chat_id,
281282
)
283+
effective_continuation_lang = continuation_language or session_language
282284
continuation_q = _CONTINUATION_QUESTIONS.get(
283-
session_language, CONTINUATION_QUESTION
285+
effective_continuation_lang, CONTINUATION_QUESTION
284286
)
285287
await self._save_session(
286288
chat_id, merged_params, updated_turn_count, awaiting_continuation=True
@@ -314,6 +316,7 @@ async def stream_run_turn(
314316
awaiting_continuation: bool = False,
315317
continuation_turn: int = CONTINUATION_TURN,
316318
session_language: str = "en",
319+
continuation_language: Optional[str] = None,
317320
) -> tuple[AgenticLoopResult, List[str]]:
318321
"""Process one user turn like :meth:`run_turn` but stream clarifying_question tokens.
319322
@@ -454,8 +457,9 @@ async def stream_run_turn(
454457
turn_count,
455458
chat_id,
456459
)
460+
effective_continuation_lang = continuation_language or session_language
457461
continuation_q = _CONTINUATION_QUESTIONS.get(
458-
session_language, CONTINUATION_QUESTION
462+
effective_continuation_lang, CONTINUATION_QUESTION
459463
)
460464
await self._save_session(
461465
chat_id, merged_params, updated_turn_count, awaiting_continuation=True

src/tool_classifier/api_response_formatter.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class APIResponseFormatterSignature(dspy.Signature):
2424
- IGNORE the language of user_query for output language decisions — short follow-up
2525
messages are unreliable indicators. Always use response_language.
2626
27+
If custom_instructions is non-empty, follow those rules with HIGHEST PRIORITY —
28+
they override defaults (e.g. language policy, tone, formatting style).
29+
2730
Rules:
2831
- Format data in a readable way using bullet points, numbered lists, or natural prose.
2932
Do NOT return raw JSON or wrap content in code blocks.
@@ -69,6 +72,14 @@ class APIResponseFormatterSignature(dspy.Signature):
6972
"Always use this — do not infer language from api_response content."
7073
)
7174
)
75+
custom_instructions: str = dspy.InputField(
76+
desc=(
77+
"Optional system-level instructions configured by the organisation "
78+
"(e.g. 'Always respond in Estonian', 'Use structured format'). "
79+
"Empty string when no custom config is active. "
80+
"When non-empty, follow these rules with highest priority."
81+
)
82+
)
7283

7384
formatted_answer: str = dspy.OutputField(
7485
desc=(
@@ -95,10 +106,17 @@ class APIResponseFormatterSignature(dspy.Signature):
95106
class APIResponseFormatterModule(dspy.Module):
96107
"""DSPy Module that converts raw API JSON responses into natural-language answers."""
97108

98-
def __init__(self) -> None:
99-
"""Initialize formatter with a direct DSPy Predict."""
109+
def __init__(self, custom_instructions: str = "") -> None:
110+
"""Initialize formatter with a direct DSPy Predict.
111+
112+
Args:
113+
custom_instructions: Optional organisation-level prompt rules (e.g. language
114+
policy). Passed verbatim to the DSPy predictor on every call. Defaults
115+
to empty string (no custom config).
116+
"""
100117
super().__init__()
101118
self.formatter = dspy.Predict(APIResponseFormatterSignature)
119+
self._custom_instructions = custom_instructions
102120

103121
def forward(
104122
self,
@@ -131,6 +149,7 @@ def forward(
131149
api_response=normalized,
132150
endpoint_description=endpoint_description,
133151
response_language=response_language,
152+
custom_instructions=self._custom_instructions,
134153
)
135154
return result.formatted_answer # type: ignore[no-any-return]
136155

@@ -195,6 +214,7 @@ async def stream_forward(
195214
if detected_language in _FORMATTER_ERROR_MESSAGES
196215
else "en"
197216
)
217+
output_stream = None
198218
try:
199219
normalized = self._normalize_response(api_response)
200220
normalized = self._annotate_empty(normalized)
@@ -207,6 +227,7 @@ async def stream_forward(
207227
api_response=normalized,
208228
endpoint_description=endpoint_description,
209229
response_language=response_language,
230+
custom_instructions=self._custom_instructions,
210231
)
211232

212233
stream_started = False
@@ -255,6 +276,12 @@ async def stream_forward(
255276
f"APIResponseFormatterModule.stream_forward failed: {e}", exc_info=True
256277
)
257278
yield get_localized_message(_FORMATTER_ERROR_MESSAGES, safe_language)
279+
finally:
280+
if output_stream is not None:
281+
try:
282+
await output_stream.aclose()
283+
except Exception as cleanup_error:
284+
logger.debug(f"Error during stream cleanup: {cleanup_error}")
258285

259286
# ------------------------------------------------------------------
260287

src/tool_classifier/param_extractor.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
import re
56
from datetime import datetime, timezone
67
from typing import Any, Dict, List, Optional, TypedDict
78

@@ -15,6 +16,33 @@
1516

1617
_MAX_HISTORY_TURNS = 5
1718

19+
# Regex patterns to strip format hints from parameter descriptions before
20+
# they are fed to the question-generation prompt. This prevents the LLM
21+
# from including format instructions (e.g. "YYYY-MM-DD") in its questions.
22+
_FORMAT_HINT_PATTERNS: List[re.Pattern[str]] = [
23+
# Parenthesised format hints: (YYYY-MM-DD), (ISO 8601), (2-letter code), (HH:MM:SS)
24+
re.compile(
25+
r"\s*\([^)]*(?:YYYY|MM|DD|HH|SS|ISO\s*\d*|letter|format)[^)]*\)",
26+
re.IGNORECASE,
27+
),
28+
# Trailing phrases: "in format YYYY-MM-DD" or "in the format YYYY-MM-DD"
29+
re.compile(r"\s*,?\s*in\s+(?:the\s+)?format\s+\S+", re.IGNORECASE),
30+
]
31+
32+
33+
def _strip_format_hints(description: str) -> str:
34+
"""Remove format hints from a parameter description.
35+
36+
Strips patterns such as ``(YYYY-MM-DD)``, ``(ISO 8601)``,
37+
``(2-letter code)``, ``(HH:MM:SS)``, and trailing
38+
``in the format YYYY-MM-DD`` phrases. The sanitised description is used
39+
only for LLM question generation; the original description (with format
40+
hints intact) is still used for extraction context.
41+
"""
42+
for pattern in _FORMAT_HINT_PATTERNS:
43+
description = pattern.sub("", description)
44+
return description.strip()
45+
1846

1947
class ParamExtractionResult(TypedDict):
2048
"""Return contract for ParamExtractionModule.forward()."""
@@ -34,6 +62,9 @@ class ParamExtractionSignature(dspy.Signature):
3462
short follow-up messages ("I'm not sure", "2026-01-01") are unreliable indicators.
3563
Always use session_language.
3664
65+
If custom_instructions is non-empty, follow those rules with HIGHEST PRIORITY —
66+
they override defaults (e.g. language policy, tone) for the clarifying_question output.
67+
3768
Extraction rules:
3869
- Extract values for ALL parameters listed in params_schema that appear in user_message
3970
or conversation_history, regardless of whether they are already in already_collected
@@ -42,6 +73,12 @@ class ParamExtractionSignature(dspy.Signature):
4273
- Only skip extraction for a param if the user has NOT mentioned it at all in this turn
4374
- Validate types: dates must be ISO 8601 (YYYY-MM-DD), integers must be whole numbers,
4475
numbers must be numeric, booleans must be true or false
76+
- SINGLE-VALUE ASSIGNMENT RULE: When the user's message contains exactly ONE value of a
77+
given type (e.g. one date) and MULTIPLE required parameters of the same type are still
78+
missing (e.g. both startDate and endDate are missing), assign that single value to the
79+
FIRST such missing required parameter in the order they appear in params_schema — never
80+
to a later one. For example, if startDate appears before endDate in params_schema and
81+
both are missing, a lone date like "2026-04-01" must be assigned to startDate, not endDate.
4582
4683
missing_required rules:
4784
- List every required parameter (required=true in schema) whose value is absent
@@ -59,6 +96,11 @@ class ParamExtractionSignature(dspy.Signature):
5996
- Use each missing parameter's description field to phrase the question naturally
6097
(e.g., "Which country and date would you like to use?" not "Provide countryIsoCode and startDate")
6198
- Never expose raw parameter names (camelCase identifiers) to the user
99+
- NEVER include format requirements, expected formats, format examples, or
100+
structural hints (such as "YYYY-MM-DD", "ISO 8601", "2-letter code",
101+
"in the format...") in the question — only ask WHAT information is needed,
102+
not HOW it should be formatted. The system handles format conversion
103+
internally from any natural-language input the user provides.
62104
"""
63105

64106
user_message: str = dspy.InputField(
@@ -85,6 +127,14 @@ class ParamExtractionSignature(dspy.Signature):
85127
"still extract the new value — corrections are allowed."
86128
)
87129
)
130+
custom_instructions: str = dspy.InputField(
131+
desc=(
132+
"Optional system-level instructions configured by the organisation "
133+
"(e.g. 'Always respond in Estonian', 'Use formal tone'). "
134+
"Empty string when no custom config is active. "
135+
"When non-empty, follow these rules with highest priority for the clarifying_question."
136+
)
137+
)
88138

89139
extracted_params: str = dspy.OutputField(
90140
desc='Valid JSON object of newly extracted parameters only: {"param_name": value}. Empty object {} if nothing new found.'
@@ -93,17 +143,29 @@ class ParamExtractionSignature(dspy.Signature):
93143
desc='Valid JSON array of required parameter names still missing after extraction: ["param1", "param2"]. Empty array [] if all required params are satisfied.'
94144
)
95145
clarifying_question: str = dspy.OutputField(
96-
desc='A single natural-language question that asks for ALL missing parameters at once, or the literal string "none" if all required params are collected.'
146+
desc=(
147+
"A single natural-language question that asks for ALL missing parameters "
148+
'at once, or the literal string "none" if all required params are collected. '
149+
'Never include format instructions or examples (e.g. "YYYY-MM-DD", '
150+
'"ISO 8601", "2-letter code") — only ask what information is needed.'
151+
)
97152
)
98153

99154

100155
class ParamExtractionModule(dspy.Module):
101156
"""DSPy Module for API parameter extraction from natural language."""
102157

103-
def __init__(self) -> None:
104-
"""Initialize param extraction module with Predict (direct prediction)."""
158+
def __init__(self, custom_instructions: str = "") -> None:
159+
"""Initialize param extraction module with Predict (direct prediction).
160+
161+
Args:
162+
custom_instructions: Optional organisation-level prompt rules (e.g. language
163+
policy). Passed verbatim to the DSPy predictor on every call. Defaults
164+
to empty string (no custom config).
165+
"""
105166
super().__init__()
106167
self.extractor = dspy.Predict(ParamExtractionSignature)
168+
self._custom_instructions = custom_instructions
107169

108170
def forward(
109171
self,
@@ -130,7 +192,13 @@ def forward(
130192
already_collected = already_collected or {}
131193

132194
history_text = self._format_conversation_history(conversation_history)
133-
params_schema_json = json.dumps(params_schema, ensure_ascii=False)
195+
sanitized_schema = [
196+
{**p, "description": _strip_format_hints(p.get("description", ""))}
197+
if isinstance(p, dict)
198+
else p
199+
for p in params_schema
200+
]
201+
params_schema_json = json.dumps(sanitized_schema, ensure_ascii=False)
134202
already_collected_json = json.dumps(already_collected, ensure_ascii=False)
135203

136204
result = None
@@ -141,6 +209,7 @@ def forward(
141209
session_language=session_language,
142210
params_schema=params_schema_json,
143211
already_collected=already_collected_json,
212+
custom_instructions=self._custom_instructions,
144213
)
145214
return self._parse_prediction(result, params_schema, already_collected)
146215

@@ -206,9 +275,16 @@ async def stream_forward(
206275
already_collected = already_collected or {}
207276

208277
history_text = self._format_conversation_history(conversation_history)
209-
params_schema_json = json.dumps(params_schema, ensure_ascii=False)
278+
sanitized_schema = [
279+
{**p, "description": _strip_format_hints(p.get("description", ""))}
280+
if isinstance(p, dict)
281+
else p
282+
for p in params_schema
283+
]
284+
params_schema_json = json.dumps(sanitized_schema, ensure_ascii=False)
210285
already_collected_json = json.dumps(already_collected, ensure_ascii=False)
211286

287+
output_stream = None
212288
try:
213289
stream_predictor = self._get_stream_predictor()
214290
output_stream = stream_predictor(
@@ -217,6 +293,7 @@ async def stream_forward(
217293
session_language=session_language,
218294
params_schema=params_schema_json,
219295
already_collected=already_collected_json,
296+
custom_instructions=self._custom_instructions,
220297
)
221298

222299
tokens: List[str] = []
@@ -273,6 +350,15 @@ async def stream_forward(
273350
logger.exception(f"ParamExtractionModule.stream_forward failed: {e}")
274351
return [], self._safe_defaults(params_schema, already_collected)
275352

353+
finally:
354+
if output_stream is not None:
355+
try:
356+
await output_stream.aclose()
357+
except Exception as cleanup_error:
358+
logger.debug(
359+
f"Error during param extraction stream cleanup: {cleanup_error}"
360+
)
361+
276362
# ------------------------------------------------------------------
277363
# Private helpers
278364
# ------------------------------------------------------------------
@@ -448,6 +534,31 @@ def _parse_prediction(
448534
)
449535
type_invalid_params.append(param_name)
450536

537+
# SINGLE-VALUE REASSIGNMENT: if the LLM assigned a value to a later same-type
538+
# param while an earlier same-type param is still missing, move the value forward.
539+
# This fixes the common case where a lone date like "2026-04-01" is extracted as
540+
# endDate when startDate is still missing.
541+
combined_after_extraction = {**already_collected, **validated_params}
542+
required_schema_order = [
543+
p for p in params_schema if isinstance(p, dict) and p.get("required", False)
544+
]
545+
for idx, missing_entry in enumerate(required_schema_order):
546+
m_name = missing_entry["name"]
547+
m_type = missing_entry.get("type", "string")
548+
if m_name in combined_after_extraction:
549+
continue # already satisfied
550+
# Find the first later param with the same type that was just extracted
551+
for later_entry in required_schema_order[idx + 1 :]:
552+
l_name = later_entry["name"]
553+
l_type = later_entry.get("type", "string")
554+
if l_type == m_type and l_name in validated_params:
555+
logger.debug(
556+
f"ParamExtractor: reassigning '{l_name}' → '{m_name}' "
557+
f"(single {m_type} value assigned to wrong param by LLM)"
558+
)
559+
validated_params[m_name] = validated_params.pop(l_name)
560+
break
561+
451562
# Re-derive missing required params after type validation.
452563
# validated_params (current turn) takes precedence over already_collected
453564
# so that explicit user corrections override prior values.

0 commit comments

Comments
 (0)