Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ KEYWORD_TOP_K = 5
# Max entries returned by structured search (metadata filtering)
STRUCTURED_TOP_K = 5

# Convex Combination alpha for adaptive keyword boost
# S_final = α·S_semantic + (1-α)·S_keyword
CC_ALPHA = 0.7


# ============================================================================
# Database Configuration
Expand Down
81 changes: 73 additions & 8 deletions core/hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,25 @@ def __init__(
enable_reflection: bool = True,
max_reflection_rounds: int = 2,
enable_parallel_retrieval: bool = True,
max_retrieval_workers: int = 3
max_retrieval_workers: int = 3,
cc_alpha: float = None
):
self.llm_client = llm_client
self.vector_store = vector_store
self.semantic_top_k = semantic_top_k or config.SEMANTIC_TOP_K
self.keyword_top_k = keyword_top_k or config.KEYWORD_TOP_K
self.structured_top_k = structured_top_k or config.STRUCTURED_TOP_K

# Use config values as default if not explicitly provided
self.enable_planning = enable_planning if enable_planning is not None else getattr(config, 'ENABLE_PLANNING', True)
self.enable_reflection = enable_reflection if enable_reflection is not None else getattr(config, 'ENABLE_REFLECTION', True)
self.max_reflection_rounds = max_reflection_rounds if max_reflection_rounds is not None else getattr(config, 'MAX_REFLECTION_ROUNDS', 2)
self.enable_parallel_retrieval = enable_parallel_retrieval if enable_parallel_retrieval is not None else getattr(config, 'ENABLE_PARALLEL_RETRIEVAL', True)
self.max_retrieval_workers = max_retrieval_workers if max_retrieval_workers is not None else getattr(config, 'MAX_RETRIEVAL_WORKERS', 3)

# Convex Combination alpha for adaptive keyword boost
self.cc_alpha = cc_alpha if cc_alpha is not None else getattr(config, 'CC_ALPHA', 0.7)

def retrieve(self, query: str, enable_reflection: Optional[bool] = None) -> List[MemoryEntry]:
"""
Execute retrieval with planning and optional reflection
Expand Down Expand Up @@ -120,10 +124,10 @@ def _retrieve_with_planning(self, query: str, enable_reflection: Optional[bool]
# Step 5: Optional reflection-based additional retrieval
# Use override parameter if provided, otherwise use global setting
should_use_reflection = enable_reflection if enable_reflection is not None else self.enable_reflection

if should_use_reflection:
merged_results = self._retrieve_with_intelligent_reflection(query, merged_results, information_plan)

return merged_results

def _retrieve_with_reflection(self, query: str, initial_results: List[MemoryEntry]) -> List[MemoryEntry]:
Expand Down Expand Up @@ -419,7 +423,56 @@ def _merge_and_deduplicate_entries(self, entries: List[MemoryEntry]) -> List[Mem
merged.append(entry)

return merged


def _convex_combination_fusion(
self,
semantic_results: List[tuple],
keyword_results: List[tuple],
alpha: float = 0.7
) -> List[MemoryEntry]:
"""
Convex Combination (CC) fusion for hybrid retrieval.

Formula: S_final = α·S_sem + (1-α)·S_kw

Args:
semantic_results: List of (MemoryEntry, score) from semantic search
keyword_results: List of (MemoryEntry, score) from keyword search
alpha: Weight for semantic scores (default 0.7)

Returns:
List of MemoryEntry sorted by fused score
"""
semantic_scores: Dict[str, float] = {}
keyword_scores: Dict[str, float] = {}
entry_map: Dict[str, MemoryEntry] = {}

for entry, score in semantic_results:
semantic_scores[entry.entry_id] = score
entry_map[entry.entry_id] = entry

for entry, score in keyword_results:
keyword_scores[entry.entry_id] = score
if entry.entry_id not in entry_map:
entry_map[entry.entry_id] = entry

all_ids = set(semantic_scores.keys()) | set(keyword_scores.keys())

fused_scores: Dict[str, float] = {}
for entry_id in all_ids:
sem_score = semantic_scores.get(entry_id, 0.0)
kw_score = keyword_scores.get(entry_id, 0.0)

if entry_id in semantic_scores and entry_id in keyword_scores:
fused_scores[entry_id] = alpha * sem_score + (1 - alpha) * kw_score
elif entry_id in semantic_scores:
fused_scores[entry_id] = alpha * sem_score
else:
fused_scores[entry_id] = (1 - alpha) * kw_score

sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
return [entry_map[entry_id] for entry_id in sorted_ids]

def _check_answer_adequacy(self, query: str, contexts: List[MemoryEntry]) -> str:
"""
Check if current contexts are sufficient to answer the query
Expand Down Expand Up @@ -662,6 +715,7 @@ def _analyze_information_requirements(self, query: str) -> Dict[str, Any]:
2. What key entities, events, or concepts need to be identified?
3. What relationships or connections need to be established?
4. What minimal set of information pieces would be sufficient to answer this question?
5. Are there technical terms requiring exact lexical matching?

Return your analysis in JSON format:
```json
Expand All @@ -676,11 +730,20 @@ def _analyze_information_requirements(self, query: str) -> Dict[str, Any]:
}}
],
"relationships": ["relationship1", "relationship2", ...],
"minimal_queries_needed": 2
"minimal_queries_needed": 2,
"exact_match_terms": [],
"use_keyword_boost": false
}}
```

Focus on identifying the minimal essential information needed, not exhaustive details.
For exact_match_terms, include ONLY terms requiring exact lexical matching:
- Function/method names: parseJWT, get_user_id
- Error codes: ECONNREFUSED, CVE-2017-3156
- Version numbers: v2.1.0, Oracle 12c
- File names: config.yaml, .env

Set use_keyword_boost=true ONLY if exact_match_terms is non-empty.
For conversational queries about people/events, leave both fields as defaults.

Return ONLY the JSON, no other text.
"""
Expand Down Expand Up @@ -713,7 +776,9 @@ def _analyze_information_requirements(self, query: str) -> Dict[str, Any]:
"key_entities": [query],
"required_info": [{"info_type": "general", "description": "relevant information", "priority": "high"}],
"relationships": [],
"minimal_queries_needed": 1
"minimal_queries_needed": 1,
"exact_match_terms": [],
"use_keyword_boost": False
}

def _generate_targeted_queries(self, original_query: str, information_plan: Dict[str, Any]) -> List[str]:
Expand Down
52 changes: 52 additions & 0 deletions database/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,55 @@ def clear(self):
self._fts_initialized = False
self._init_table()
print("Database cleared")

def keyword_search_with_scores(self, keywords: List[str], top_k: int = 3) -> List[tuple]:
    """
    Keyword search using BM25 FTS that returns (MemoryEntry, score) tuples.

    Uses LanceDB native Full-Text Search with BM25 ranking.
    Scores are normalized to [0, 1] by dividing by the best score among
    the successfully parsed results, so the top returned entry is 1.0.

    Args:
        keywords: Search terms; joined with spaces into a single FTS query.
        top_k: Maximum number of results to return.

    Returns:
        List of (MemoryEntry, float) tuples sorted by score (highest first).
        Empty list on empty input, empty table, or any search error.
    """
    try:
        # Nothing to search for, or nothing to search in.
        if not keywords or self.table.count_rows() == 0:
            return []

        query = " ".join(keywords)
        results = self.table.search(query).limit(top_k).to_list()

        if not results:
            return []

        scored_entries = []
        for r in results:
            score = r.get("_score", 0.0)
            try:
                entry = MemoryEntry(
                    entry_id=r["entry_id"],
                    lossless_restatement=r["lossless_restatement"],
                    keywords=list(r.get("keywords") or []),
                    timestamp=r.get("timestamp") or None,
                    location=r.get("location") or None,
                    persons=list(r.get("persons") or []),
                    entities=list(r.get("entities") or []),
                    topic=r.get("topic") or None
                )
                scored_entries.append((entry, score))
            except Exception as e:
                # Skip malformed rows rather than failing the whole search.
                print(f"Warning: Failed to parse FTS result: {e}")
                continue

        # Normalize to [0, 1] using the best score among PARSED entries only.
        # (Previously the max was accumulated over all raw rows, so a dropped
        # top row could leave every surviving score strictly below 1.0.)
        max_score = max((s for _, s in scored_entries), default=0.0)
        if max_score > 0:
            scored_entries = [(entry, s / max_score) for entry, s in scored_entries]

        return scored_entries

    except Exception as e:
        # Best-effort contract: never raise out of retrieval, return empty.
        print(f"Error during keyword search with scores: {e}")
        return []