188 changes: 134 additions & 54 deletions code_review_graph/embeddings.py
@@ -1,55 +1,127 @@
"""Vector embedding support for semantic code search.

Optional module — requires `pip install code-review-graph[embeddings]`.
Falls back gracefully to keyword search when not installed.
Supports multiple providers:
1. Local (sentence-transformers) - Private, fast, offline.
2. Google Gemini - High-quality, multimodal (PDF/Audio/Video), cloud-based.
"""

from __future__ import annotations

import sqlite3
import struct
import os
import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

from .graph import GraphNode, GraphStore, node_to_dict

- # Lazy imports for optional dependencies
- _model = None
- _HAS_EMBEDDINGS = None
# ---------------------------------------------------------------------------
# Provider Interface and Implementations
# ---------------------------------------------------------------------------

class EmbeddingProvider(ABC):
@abstractmethod
def embed(self, texts: list[str]) -> list[list[float]]:
pass

@property
@abstractmethod
def dimension(self) -> int:
pass

- def _check_available() -> bool:
- """Check if sentence-transformers is installed."""
- global _HAS_EMBEDDINGS
- if _HAS_EMBEDDINGS is None:
@property
@abstractmethod
def name(self) -> str:
pass


class LocalEmbeddingProvider(EmbeddingProvider):
def __init__(self):
try:
- import numpy # noqa: F401
- import sentence_transformers # noqa: F401
- _HAS_EMBEDDINGS = True
from sentence_transformers import SentenceTransformer
self._model = SentenceTransformer("all-MiniLM-L6-v2")
except ImportError:
- _HAS_EMBEDDINGS = False
- return _HAS_EMBEDDINGS
raise ImportError("sentence-transformers not installed. Run: pip install code-review-graph[embeddings]")

def embed(self, texts: list[str]) -> list[list[float]]:
vectors = self._model.encode(texts, show_progress_bar=False)
return [v.tolist() for v in vectors]

@property
def dimension(self) -> int:
return 384

@property
def name(self) -> str:
return "local:all-MiniLM-L6-v2"


class GoogleEmbeddingProvider(EmbeddingProvider):
def __init__(self, api_key: str | None = None, model: str = "gemini-embedding-001"):
try:
from google import genai
self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
if not self.api_key:
raise ValueError("GOOGLE_API_KEY environment variable not set")
self._client = genai.Client(api_key=self.api_key)
self.model = model
except ImportError:
raise ImportError("google-generativeai not installed. Run: pip install google-generativeai")

def embed(self, texts: list[str]) -> list[list[float]]:
# Google allows batching
batch_size = 100
results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
response = self._client.models.embed_content(
model=self.model,
contents=batch,
# Task type RETRIEVAL_DOCUMENT is best for indexing code chunks
config={"task_type": "RETRIEVAL_DOCUMENT"}
)
# The API returns a list of embeddings
results.extend([e.values for e in response.embeddings])
return results

- def _get_model():
- """Lazy-load the embedding model."""
- global _model
- if _model is None:
- from sentence_transformers import SentenceTransformer
- # all-MiniLM-L6-v2: fast, 384-dim, good for code/text similarity
- _model = SentenceTransformer("all-MiniLM-L6-v2")
- return _model
@property
def dimension(self) -> int:
# gemini-embedding-001 is 768 by default
return 768

@property
def name(self) -> str:
return f"google:{self.model}"


def get_default_provider() -> EmbeddingProvider | None:
"""Auto-detect the best available provider."""
# Priority 1: Google (if API Key is set)
if os.environ.get("GOOGLE_API_KEY"):
try:
return GoogleEmbeddingProvider()
except Exception:
pass

# Priority 2: Local
try:
return LocalEmbeddingProvider()
except Exception:
return None


# ---------------------------------------------------------------------------
- # SQLite vector storage (simple blob-based, no external vector DB)
# SQLite vector storage
# ---------------------------------------------------------------------------

_EMBEDDINGS_SCHEMA = """
CREATE TABLE IF NOT EXISTS embeddings (
qualified_name TEXT PRIMARY KEY,
vector BLOB NOT NULL,
- text_hash TEXT NOT NULL
text_hash TEXT NOT NULL,
provider TEXT NOT NULL DEFAULT 'unknown'
);
"""

@@ -67,6 +139,9 @@ def _decode_vector(blob: bytes) -> list[float]:

def _cosine_similarity(a: list[float], b: list[float]) -> float:
"""Compute cosine similarity between two vectors."""
# Ensure same dimension
if len(a) != len(b):
return 0.0
dot = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
@@ -95,77 +170,82 @@ class EmbeddingStore:
"""Manages vector embeddings for graph nodes in SQLite."""

def __init__(self, db_path: str | Path) -> None:
- self.available = _check_available()
self.provider = get_default_provider()
self.available = self.provider is not None
self.db_path = Path(db_path)
self._conn = sqlite3.connect(str(self.db_path), timeout=30)
self._conn.row_factory = sqlite3.Row
self._conn.executescript(_EMBEDDINGS_SCHEMA)

# Migration for existing DBs missing the provider column
try:
self._conn.execute("SELECT provider FROM embeddings LIMIT 1")
except sqlite3.OperationalError:
self._conn.execute("ALTER TABLE embeddings ADD COLUMN provider TEXT NOT NULL DEFAULT 'unknown'")

self._conn.commit()

def close(self) -> None:
self._conn.close()

def embed_nodes(self, nodes: list[GraphNode], batch_size: int = 64) -> int:
"""Compute and store embeddings for a list of nodes.

Skips nodes that already have up-to-date embeddings (based on text hash).
Returns the number of newly embedded nodes.
"""
if not self.available:
"""Compute and store embeddings for a list of nodes."""
if not self.provider:
return 0

- import hashlib
- model = _get_model()

# Filter to nodes that need embedding
to_embed: list[tuple[GraphNode, str, str]] = []
provider_name = self.provider.name

for node in nodes:
if node.kind == "File":
- continue # Skip file nodes, they don't have meaningful names
continue
text = _node_to_text(node)
text_hash = hashlib.sha256(text.encode()).hexdigest()

existing = self._conn.execute(
"SELECT text_hash FROM embeddings WHERE qualified_name = ?",
"SELECT text_hash, provider FROM embeddings WHERE qualified_name = ?",
(node.qualified_name,),
).fetchone()
if existing and existing["text_hash"] == text_hash:

# Re-embed if text changed OR provider changed
if existing and existing["text_hash"] == text_hash and existing["provider"] == provider_name:
continue
to_embed.append((node, text, text_hash))

if not to_embed:
return 0

- # Batch encode
# Encode in batches
texts = [t for _, t, _ in to_embed]
- vectors = model.encode(texts, batch_size=batch_size, show_progress_bar=False)
vectors = self.provider.embed(texts)

for (node, _text, text_hash), vec in zip(to_embed, vectors):
- blob = _encode_vector(vec.tolist())
blob = _encode_vector(vec)
self._conn.execute(
"""INSERT OR REPLACE INTO embeddings (qualified_name, vector, text_hash)
VALUES (?, ?, ?)""",
(node.qualified_name, blob, text_hash),
"""INSERT OR REPLACE INTO embeddings (qualified_name, vector, text_hash, provider)
VALUES (?, ?, ?, ?)""",
(node.qualified_name, blob, text_hash, provider_name),
)

self._conn.commit()
return len(to_embed)

def search(self, query: str, limit: int = 20) -> list[tuple[str, float]]:
"""Search for nodes by semantic similarity.

Returns list of (qualified_name, similarity_score) sorted by score descending.
Uses chunked processing to limit peak memory usage on large graphs.
"""
if not self.available:
"""Search for nodes by semantic similarity."""
if not self.provider:
return []

- model = _get_model()
- query_vec = model.encode([query], show_progress_bar=False)[0].tolist()
provider_name = self.provider.name
query_vec = self.provider.embed([query])[0]

- # Process in chunks to limit peak memory for large codebases
# Process in chunks
scored: list[tuple[str, float]] = []
cursor = self._conn.execute("SELECT qualified_name, vector FROM embeddings")
# Only search embeddings created with the current provider to ensure dimension match
cursor = self._conn.execute(
"SELECT qualified_name, vector FROM embeddings WHERE provider = ?",
(provider_name,)
)
chunk_size = 500
while True:
rows = cursor.fetchmany(chunk_size)
@@ -190,7 +270,7 @@ def count(self) -> int:


def embed_all_nodes(graph_store: GraphStore, embedding_store: EmbeddingStore) -> int:
"""Embed all non-file nodes in the graph. Returns count of newly embedded nodes."""
"""Embed all non-file nodes in the graph."""
if not embedding_store.available:
return 0

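A minimal sketch of how the provider fallback introduced here is expected to behave from the caller's side (names taken from this diff; the database path is illustrative):

from code_review_graph.embeddings import EmbeddingStore, get_default_provider

# Prefers Google Gemini when GOOGLE_API_KEY is set, then the local model, else None.
provider = get_default_provider()
if provider is None:
    print("no embedding backend available; falling back to keyword search")
else:
    print(f"embedding with {provider.name} ({provider.dimension}-dim vectors)")

# EmbeddingStore performs the same auto-detection internally.
store = EmbeddingStore("review_graph.db")  # illustrative path
print(store.available)  # True only when a provider could be constructed
store.close()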
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ code-review-graph = "code_review_graph.cli:main"
[project.optional-dependencies]
embeddings = [
"sentence-transformers>=3.0.0",
"google-generativeai>=0.8.0",
"numpy>=1.26",
]
dev = [