Skip to content

Commit 5490e12

Browse files
committed
insert llm abstraction layer
1 parent 176087c commit 5490e12

7 files changed

Lines changed: 928 additions & 148 deletions

File tree

mcp-server/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ dependencies = [
88
"neo4j>=5.0.0",
99
"pydantic>=2.0.0",
1010
"anthropic>=0.40.0",
11+
"openai>=1.0.0",
12+
"google-generativeai>=0.4.0",
1113
"numpy>=1.24.0",
1214
"flask>=3.0.0",
1315
"click>=8.0.0",
Lines changed: 85 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""LLM-based detection for context capture."""
22

33
import asyncio
4-
import json
54
import re
65
from dataclasses import dataclass
76
from typing import Optional
87

9-
from anthropic import AsyncAnthropic
10-
8+
from ccmemory.llmprovider import getLlmClient
119
from .prompts import (
1210
DECISION_PROMPT,
1311
CORRECTION_PROMPT,
@@ -16,18 +14,16 @@
1614
QUESTION_PROMPT,
1715
FAILED_APPROACH_PROMPT,
1816
)
17+
from .schemas import (
18+
DecisionResult,
19+
CorrectionResult,
20+
ExceptionResult,
21+
InsightResult,
22+
QuestionResult,
23+
FailedApproachResult,
24+
)
1925

2026
CONFIDENCE_THRESHOLD = 0.7
21-
DETECTION_MODEL = "claude-sonnet-4-20250514"
22-
23-
_client = None
24-
25-
26-
def _getClient():
27-
global _client
28-
if _client is None:
29-
_client = AsyncAnthropic()
30-
return _client
3127

3228

3329
@dataclass
@@ -37,8 +33,9 @@ class Detection:
3733
data: dict
3834

3935

40-
async def detectAll(user_message: str, claude_response: str,
41-
context: str) -> list[Detection]:
36+
async def detectAll(
37+
user_message: str, claude_response: str, context: str
38+
) -> list[Detection]:
4239
"""Run all detection prompts in parallel, filter by confidence."""
4340
if len(user_message.strip()) < 10:
4441
return []
@@ -63,144 +60,130 @@ async def detectAll(user_message: str, claude_response: str,
6360
return detections
6461

6562

66-
async def _callDetector(prompt: str) -> dict:
67-
"""Call LLM for classification."""
68-
client = _getClient()
69-
response = await client.messages.create(
70-
model=DETECTION_MODEL,
71-
max_tokens=500,
72-
messages=[{"role": "user", "content": prompt}]
73-
)
74-
text = response.content[0].text
75-
76-
try:
77-
start = text.find('{')
78-
end = text.rfind('}') + 1
79-
if start >= 0 and end > start:
80-
return json.loads(text[start:end])
81-
except json.JSONDecodeError:
82-
pass
83-
return {}
84-
85-
86-
async def detectDecision(user_message: str, claude_response: str,
87-
context: str) -> Optional[Detection]:
63+
async def detectDecision(
64+
user_message: str, claude_response: str, context: str
65+
) -> Optional[Detection]:
8866
prompt = DECISION_PROMPT.format(
8967
context=context[:500],
9068
claude_response=claude_response[:500],
91-
user_message=user_message
69+
user_message=user_message,
9270
)
93-
result = await _callDetector(prompt)
94-
if result.get("is_decision"):
71+
client = getLlmClient()
72+
result = await client.complete(prompt, DecisionResult, maxTokens=500)
73+
if result.isDecision:
9574
return Detection(
9675
type="decision",
97-
confidence=result.get("confidence", 0.5),
76+
confidence=result.confidence,
9877
data={
99-
"description": result.get("description", user_message[:200]),
100-
"rationale": result.get("rationale"),
101-
"revisit_trigger": result.get("revisit_trigger"),
102-
}
78+
"description": result.description or user_message[:200],
79+
"rationale": result.rationale,
80+
"revisit_trigger": result.revisitTrigger,
81+
},
10382
)
10483
return None
10584

10685

107-
async def detectCorrection(user_message: str, claude_response: str,
108-
context: str) -> Optional[Detection]:
86+
async def detectCorrection(
87+
user_message: str, claude_response: str, context: str
88+
) -> Optional[Detection]:
10989
prompt = CORRECTION_PROMPT.format(
110-
claude_response=claude_response[:500],
111-
user_message=user_message
90+
claude_response=claude_response[:500], user_message=user_message
11291
)
113-
result = await _callDetector(prompt)
114-
if result.get("is_correction"):
92+
client = getLlmClient()
93+
result = await client.complete(prompt, CorrectionResult, maxTokens=500)
94+
if result.isCorrection:
11595
return Detection(
11696
type="correction",
117-
confidence=result.get("confidence", 0.5),
97+
confidence=result.confidence,
11898
data={
119-
"wrong_belief": result.get("wrong_belief"),
120-
"right_belief": result.get("right_belief"),
121-
"severity": result.get("severity", "significant"),
122-
}
99+
"wrong_belief": result.wrongBelief,
100+
"right_belief": result.rightBelief,
101+
"severity": result.severity or "significant",
102+
},
123103
)
124104
return None
125105

126106

127-
async def detectException(user_message: str, claude_response: str,
128-
context: str) -> Optional[Detection]:
129-
prompt = EXCEPTION_PROMPT.format(
130-
context=context[:500],
131-
user_message=user_message
132-
)
133-
result = await _callDetector(prompt)
134-
if result.get("is_exception"):
107+
async def detectException(
108+
user_message: str, claude_response: str, context: str
109+
) -> Optional[Detection]:
110+
prompt = EXCEPTION_PROMPT.format(context=context[:500], user_message=user_message)
111+
client = getLlmClient()
112+
result = await client.complete(prompt, ExceptionResult, maxTokens=500)
113+
if result.isException:
135114
return Detection(
136115
type="exception",
137-
confidence=result.get("confidence", 0.5),
116+
confidence=result.confidence,
138117
data={
139-
"rule_broken": result.get("rule_broken"),
140-
"justification": result.get("justification"),
141-
"scope": result.get("scope", "one-time"),
142-
}
118+
"rule_broken": result.ruleBroken,
119+
"justification": result.justification,
120+
"scope": result.scope or "one-time",
121+
},
143122
)
144123
return None
145124

146125

147-
async def detectInsight(user_message: str, claude_response: str,
148-
context: str) -> Optional[Detection]:
126+
async def detectInsight(
127+
user_message: str, claude_response: str, context: str
128+
) -> Optional[Detection]:
149129
prompt = INSIGHT_PROMPT.format(
150130
context=context[:500],
151131
claude_response=claude_response[:500],
152-
user_message=user_message
132+
user_message=user_message,
153133
)
154-
result = await _callDetector(prompt)
155-
if result.get("is_insight"):
134+
client = getLlmClient()
135+
result = await client.complete(prompt, InsightResult, maxTokens=500)
136+
if result.isInsight:
156137
return Detection(
157138
type="insight",
158-
confidence=result.get("confidence", 0.5),
139+
confidence=result.confidence,
159140
data={
160-
"category": result.get("category", "realization"),
161-
"summary": result.get("summary"),
162-
"implications": result.get("implications"),
163-
}
141+
"category": result.category or "realization",
142+
"summary": result.summary,
143+
"implications": result.implications,
144+
},
164145
)
165146
return None
166147

167148

168-
async def detectQuestion(user_message: str, claude_response: str,
169-
context: str) -> Optional[Detection]:
149+
async def detectQuestion(
150+
user_message: str, claude_response: str, context: str
151+
) -> Optional[Detection]:
170152
prompt = QUESTION_PROMPT.format(
171-
claude_response=claude_response[:500],
172-
user_message=user_message
153+
claude_response=claude_response[:500], user_message=user_message
173154
)
174-
result = await _callDetector(prompt)
175-
if result.get("is_question"):
155+
client = getLlmClient()
156+
result = await client.complete(prompt, QuestionResult, maxTokens=500)
157+
if result.isQuestion:
176158
return Detection(
177159
type="question",
178-
confidence=result.get("confidence", 0.5),
160+
confidence=result.confidence,
179161
data={
180-
"question": result.get("question"),
181-
"answer": result.get("answer"),
182-
"context": result.get("context"),
183-
}
162+
"question": result.question,
163+
"answer": result.answer,
164+
"context": result.context,
165+
},
184166
)
185167
return None
186168

187169

188-
async def detectFailedApproach(user_message: str, claude_response: str,
189-
context: str) -> Optional[Detection]:
170+
async def detectFailedApproach(
171+
user_message: str, claude_response: str, context: str
172+
) -> Optional[Detection]:
190173
prompt = FAILED_APPROACH_PROMPT.format(
191-
context=context[:500],
192-
user_message=user_message
174+
context=context[:500], user_message=user_message
193175
)
194-
result = await _callDetector(prompt)
195-
if result.get("is_failed_approach"):
176+
client = getLlmClient()
177+
result = await client.complete(prompt, FailedApproachResult, maxTokens=500)
178+
if result.isFailedApproach:
196179
return Detection(
197180
type="failed_approach",
198-
confidence=result.get("confidence", 0.5),
181+
confidence=result.confidence,
199182
data={
200-
"approach": result.get("approach"),
201-
"outcome": result.get("outcome"),
202-
"lesson": result.get("lesson"),
203-
}
183+
"approach": result.approach,
184+
"outcome": result.outcome,
185+
"lesson": result.lesson,
186+
},
204187
)
205188
return None
206189

@@ -218,9 +201,5 @@ async def detectReference(user_message: str) -> Optional[Detection]:
218201
refs.append({"type": "file_path", "uri": path})
219202

220203
if refs:
221-
return Detection(
222-
type="reference",
223-
confidence=0.9,
224-
data={"references": refs}
225-
)
204+
return Detection(type="reference", confidence=0.9, data={"references": refs})
226205
return None
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Pydantic schemas for LLM detection outputs."""
2+
3+
from pydantic import BaseModel
4+
5+
6+
class DecisionResult(BaseModel):
7+
isDecision: bool
8+
confidence: float
9+
description: str | None = None
10+
rationale: str | None = None
11+
revisitTrigger: str | None = None
12+
13+
14+
class CorrectionResult(BaseModel):
15+
isCorrection: bool
16+
confidence: float
17+
wrongBelief: str | None = None
18+
rightBelief: str | None = None
19+
severity: str | None = None
20+
21+
22+
class ExceptionResult(BaseModel):
23+
isException: bool
24+
confidence: float
25+
ruleBroken: str | None = None
26+
justification: str | None = None
27+
scope: str | None = None
28+
29+
30+
class InsightResult(BaseModel):
31+
isInsight: bool
32+
confidence: float
33+
category: str | None = None
34+
summary: str | None = None
35+
implications: str | None = None
36+
37+
38+
class QuestionResult(BaseModel):
39+
isQuestion: bool
40+
confidence: float
41+
question: str | None = None
42+
answer: str | None = None
43+
context: str | None = None
44+
45+
46+
class FailedApproachResult(BaseModel):
47+
isFailedApproach: bool
48+
confidence: float
49+
approach: str | None = None
50+
outcome: str | None = None
51+
lesson: str | None = None
52+
53+
54+
class RerankResult(BaseModel):
55+
indices: list[int]

0 commit comments

Comments
 (0)